Using Transformers to teach Transformers how to train Transformers
Piotr Kozakowski
joint work with Łukasz Kaiser and Afroz Mohiuddin
at Google Brain
Hyperparameter tuning
Hyperparameters in deep learning:
- learning rate
- dropout rate
- moment decay in adaptive optimizers
- ...
Tuning is important, but hard.
Done manually, it takes a lot of time.
It needs to be redone for every new architecture and task.
Hyperparameter tuning
Some hyperparameters require scheduling, which takes even more work.
Zagoruyko et al. - Wide Residual Networks, 2016
Automatic methods
Existing methods:
- grid/random search
- Bayesian optimization
- evolutionary algorithms
Problems:
- not learnable
- hyperparameters typically fixed during training
- typically not adaptive
Solution: reinforcement learning
Learn a policy for controlling hyperparameters based on the observed validation metrics.
The policy can be trained on a wide range of architectures and tasks.
Zero-shot or few-shot transfer to new architectures and tasks.
Long-term goal: a general system that can automatically tune any model.
Open-source the policies, so all ML practitioners can use them.
Tuning as an RL problem
[Diagram: the agent-environment interaction loop]
- observations: validation metrics
- rewards: changes in a chosen metric
- actions: hyperparameter changes (discrete)
A code sketch of this loop follows below.
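A minimal sketch of the interaction loop in plain Python (the class, the start_training_run helper and the methods on the training run are hypothetical names for illustration; the reward is the change in the chosen metric, as above):

# Hypothetical sketch of the tuning environment; one episode = one training run.
class TuningEnvironment:
    def __init__(self, start_training_run, chosen_metric="holdout_accuracy"):
        self.start_training_run = start_training_run  # assumed helper launching a fresh run
        self.chosen_metric = chosen_metric
        self.run = None
        self.last_metrics = None

    def reset(self):
        # A fresh run: random weight initialization and dataset permutation
        # make the environment nondeterministic.
        self.run = self.start_training_run()
        self.last_metrics = self.run.evaluate()  # observation: validation metrics
        return self.last_metrics

    def step(self, action):
        # action: a discrete hyperparameter change, e.g. "double the learning rate".
        self.run.apply_hyperparameter_change(action)
        self.run.train_for_a_while()
        metrics = self.run.evaluate()
        # reward: change in the chosen metric between consecutive observations
        reward = metrics[self.chosen_metric] - self.last_metrics[self.chosen_metric]
        self.last_metrics = metrics
        return metrics, reward, self.run.is_finished()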
Tuning as an RL problem
Partially observable: observing all of the trained model's parameter values is intractable.
Nondeterministic: random weight initialization and dataset permutation.
Tasks
Language modeling:
- Transformer on LM1B
- Transformer on Penn Treebank
Translation:
- Transformer on WMT EN -> DE
Image classification:
- Wide ResNet on CIFAR-10
[Figure: example inputs and outputs - next-word prediction ("I'm going to" -> eat / school / France), EN -> DE translation ("it's windy today" -> "heute ist es windig"), image classification (class "frog")]
Transformer language model
Vaswani et al. - Attention Is All You Need, 2017
Tuned hyperparameters
For Transformers:
- learning rate
- weight decay rate
- dropouts, separately for each layer
For Wide ResNet:
- learning rate
- weight decay rate
- momentum mass in SGD
Model-free approach: PPO
PPO: Proximal Policy Optimization
- more sample-efficient than REINFORCE
- stable
- widely used
Use the Transformer language model (without the input embedding) as the policy.
Schulman et al. - Proximal Policy Optimization Algorithms, 2017
Model-free approach: PPO
Experiment setup:
- optimizing for accuracy on a holdout set
- 20 PPO epochs
- 128 parallel model trainings in each epoch
- ~3h per model training, ~60h for the entire experiment
Model-free approach: PPO
Model-based approach: SimPLe
SimPLe: Simulated Policy Learning
Elements:
- policy
- "world model"
Train the world model on data collected in the environment.
Train the policy using PPO in the environment simulated by the world model.
Much more sample-efficient than model-free PPO.
Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019
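A sketch of the SimPLe outer loop as applied here (collect_trajectories, SimulatedEnvironment and ppo are placeholder names; the epoch counts match the experiment setup reported later):

# Sketch of the SimPLe outer loop with placeholder helpers.
def simple(real_env, policy, world_model, trajectories,
           num_epochs=10, simulated_ppo_epochs=50):
    for _ in range(num_epochs):
        # 1. Gather real data with the current policy (parallel model trainings).
        trajectories = trajectories + collect_trajectories(real_env, policy)

        # 2. Train the world model on all data collected in the environment so far.
        world_model.train(trajectories)

        # 3. Train the policy with PPO inside the environment simulated by the world model.
        simulated_env = SimulatedEnvironment(world_model)
        policy = ppo(policy, simulated_env, num_epochs=simulated_ppo_epochs)

    return policy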
Model-based approach: SimPLe
Time series forecasting
The metric curves are stochastic.
Predict the next point in the sequence: p(x_{t+1} | x_1, ..., x_t).
Common approach: use a parametric distribution,
e.g. a Gaussian: x_{t+1} ~ N(mu(x_1, ..., x_t), sigma(x_1, ..., x_t)^2).
Time series forecasting
Our approach: discretize to a fixed-point representation
and predict consecutive digits (symbols).
This way we can model any distribution within a set precision.
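For instance, with base 10, two integer digits and two fractional digits, an accuracy of 35.94 becomes the symbol sequence (3, 5, 9, 4). A minimal sketch of this encoding (the base and digit counts are assumptions for illustration):

# Fixed-point discretization of a metric value into a sequence of digit symbols.
BASE = 10
INT_DIGITS = 2   # digits before the point (assumed)
FRAC_DIGITS = 2  # digits after the point -> precision of 0.01 (assumed)

def encode(value):
    """Turn e.g. 35.94 into the digit sequence (3, 5, 9, 4)."""
    scaled = round(value * BASE ** FRAC_DIGITS)
    digits = []
    for _ in range(INT_DIGITS + FRAC_DIGITS):
        digits.append(scaled % BASE)
        scaled //= BASE
    return tuple(reversed(digits))

def decode(digits):
    """Inverse of encode, up to the set precision."""
    scaled = 0
    for d in digits:
        scaled = scaled * BASE + d
    return scaled / BASE ** FRAC_DIGITS

The model predicts these digits autoregressively, one symbol at a time, so the distribution over the next metric value is not restricted to any parametric family.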
Experiment on synthetic data
[Plot: predictions on a synthetic dataset - discretization vs. Gaussian distribution]
Transformer as a world model
Modeled sequence: o_1, a_1, o_2, a_2, ..., o_T (observations interleaved with actions).
Input: both observations and actions.
Predict only observations.
Calculate rewards based on the last two observations.
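A sketch of how one episode could be flattened into a single symbol sequence, reusing encode, decode and BASE from the discretization sketch above (the token layout, a single scalar metric per observation, and the helper names are assumptions):

# Flatten an episode into one symbol sequence: o_1, a_1, o_2, a_2, ..., o_T.
def episode_to_tokens(observations, actions, action_offset=BASE):
    # Observations become digit symbols; actions get ids above the digit vocabulary.
    tokens = []
    for obs, act in zip(observations, actions):  # len(observations) == len(actions) + 1
        tokens.extend(encode(obs))
        tokens.append(action_offset + act)
    tokens.extend(encode(observations[-1]))      # final observation
    return tokens

# Only observation tokens are predicted; action tokens come from the policy.
# Rewards are not predicted - they are computed from the last two observations:
def reward_from_observations(prev_obs_digits, next_obs_digits):
    return decode(next_obs_digits) - decode(prev_obs_digits)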
Transformer as a world model
Transformer as a world model
Inference speed: < 1 minute to sample 128 episodes.
In comparison, > 1 hour to train one real architecture.
The world model is thus at least 128 * 60 = 7680 times faster per episode!
Transformer as a policy
Share the architecture with the world model.
Input same as for the world model.
Output: action distribution, value estimate.
Preinitialize from the parameters of a trained world model.
This leads to much faster learning.
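A rough sketch of this weight sharing in plain NumPy (transformer_body, the head shapes and the initialization scales are assumptions, not the exact implementation):

import numpy as np

class PolicyFromWorldModel:
    """Policy that shares the Transformer trunk with a trained world model."""

    def __init__(self, world_model, num_actions, d_model):
        # Preinitialize the shared body from the trained world model's parameters.
        self.body = world_model.transformer_body  # assumed attribute
        rng = np.random.default_rng(0)
        self.action_head = rng.normal(scale=0.01, size=(d_model, num_actions))
        self.value_head = rng.normal(scale=0.01, size=(d_model,))

    def __call__(self, tokens):
        hidden = self.body(tokens)      # same input format as the world model
        last = hidden[-1]               # representation of the latest timestep
        action_logits = last @ self.action_head  # action distribution (logits)
        value = last @ self.value_head           # value estimate for PPO
        return action_logits, value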
SimPLe results
Experiment setup:
- optimizing for accuracy on a holdout set
- starting from a dataset of 4 * 20 * 128 = 10240 trajectories collected by PPO
- 10 SimPLe epochs, 50 simulated PPO epochs each
- 128 parallel model trainings in each data gathering phase
- ~3h for data gathering, ~1h for world model training, ~2h for policy training, ~60h for the entire experiment
SimPLe results
SimPLe vs PPO vs human
Final test accuracies:

task          | SimPLe | PPO   | human
--------------|--------|-------|------
LM1B          | 35.9%  | 30.2% | 35%
WMT EN -> DE  | 59.9%  | 49.5% | 60%
Penn Treebank | 23.4%  | 19.2% | 23.2%
CIFAR-10      | 91.6%  | 91.2% | 90%
Learned schedules - LM1B
Summary
Using world models allows faster training of better policies.
One of the first successful practical applications of model-based RL.
Amount of data needed currently: ~11K model trainings.
Comparable to the first work on Neural Architecture Search.
Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016
Future work
Transfer: training general policies or world models to enable wide use.
Planning using the model (model predictive control).
Evaluation in settings that are notoriously unstable (unsupervised/reinforcement learning); adaptive tuning should help.
Speaker: Piotr Kozakowski (p.kozakowski@mimuw.edu.pl)
Paper under review for ICLR 2020:
Forecasting Deep Learning Dynamics with Applications to Hyperparameter Tuning
References:
Zagoruyko et al. - Wide Residual Networks, 2016
Vaswani et al. - Attention Is All You Need, 2017
Schulman et al. - Proximal Policy Optimization Algorithms, 2017
Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019
Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016