Piotr Kozakowski
joint work with Łukasz Kaiser and Afroz Mohiuddin
at Google Brain
Hyperparameters in deep learning:
Tuning is important, but hard.
Done manually, takes a lot of time.
Needs to be re-done for every new architecture and task.
Some require scheduling, which takes even more work.
Zagoruyko et al. - Wide Residual Networks, 2016
Existing methods:
Problems:
Learn a policy for controlling hyperparameters based on the observed validation metrics.
Can train on a wide range of architectures and tasks.
Zero-shot or few-shot transfer to new architectures and tasks.
Long-term goal: a general system that can automatically tune any model.
Open-source the policies, so all ML practitioners can use them.
Diagram: agent interacting with the environment.
observations: validation metrics
rewards: changes in a chosen metric
actions: hyperparameter changes (discrete)
Partially observable: observing all parameter values is intractable.
Nondeterministic: random weight initialization and dataset permutation.
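A minimal sketch of this interface, assuming a toy stand-in for real network training. The class `ToyTuningEnv`, its learning-rate multipliers, and its dynamics are hypothetical and only illustrate the shape of observations, rewards, and discrete actions described above.

```python
import random


class ToyTuningEnv:
    """Hypothetical stand-in for a hyperparameter-tuning environment.

    The real environment trains a neural network. Here a toy "model" has a
    hidden loss that shrinks faster when the learning rate is near a hidden
    optimum. None of these names or numbers come from the talk.
    """

    # Discrete actions: multiply the learning rate by one of these factors.
    ACTIONS = (0.5, 1.0, 2.0)

    def __init__(self, seed=0):
        self.rng = random.Random(seed)

    def reset(self):
        self.learning_rate = 0.1
        self.hidden_loss = 1.0  # Internal state the agent never observes.
        self.accuracy = self._measure()
        return self.accuracy  # Observation: a validation metric only.

    def _measure(self):
        # Noisy validation accuracy: the source of nondeterminism.
        noise = self.rng.gauss(0.0, 0.01)
        return min(1.0, max(0.0, 1.0 - self.hidden_loss + noise))

    def step(self, action):
        self.learning_rate *= self.ACTIONS[action]
        # Toy dynamics: progress is best when the rate is close to 0.05.
        progress = 0.2 / (1.0 + abs(self.learning_rate - 0.05) * 20.0)
        self.hidden_loss *= 1.0 - progress
        new_accuracy = self._measure()
        reward = new_accuracy - self.accuracy  # Change in the chosen metric.
        self.accuracy = new_accuracy
        return new_accuracy, reward


env = ToyTuningEnv(seed=0)
observation = env.reset()
total_reward = 0.0
for _ in range(10):
    observation, reward = env.step(0)  # Always halve the learning rate.
    total_reward += reward
```

Because each reward is a difference of consecutive metric values, the episode return equals the final metric minus the initial one.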
Language modeling: "I'm going to" -> "eat" / "school" / "France"
Translation: "it's windy today" -> "heute ist es windig"
Image classification: image -> "frog"
Vaswani et al. - Attention Is All You Need, 2017
For Transformers:
For Wide ResNet:
PPO: Proximal Policy Optimization
Use the Transformer language model without input embedding as a policy.
Schulman et al. - Proximal Policy Optimization Algorithms, 2017
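The core of PPO is the clipped surrogate objective from Schulman et al. A minimal pure-Python sketch is given here; the function name is illustrative, and the clipping range `epsilon=0.2` is a commonly used default, not a value stated in this talk.

```python
import math


def ppo_clipped_objective(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """Mean clipped surrogate objective (to be maximized)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio between the updated and the data-collecting policy.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        # Pessimistic bound: take the smaller of the two surrogate terms.
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)
```

The clipping removes the incentive to move the policy far from the one that collected the data within a single update.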
Experiment setup:
SimPLe: Simulated Policy Learning
Elements:
Train the world model on data collected in the environment.
Train the policy using PPO in the environment simulated by the world model.
Much more sample-efficient than model-free PPO.
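The two training steps above alternate in a loop. A minimal skeleton is sketched here, assuming the caller supplies the three procedures as functions; `simple_loop` and the toy stubs are hypothetical names used only to make the control flow concrete.

```python
def simple_loop(collect_episodes, train_world_model, train_policy_in_model,
                policy, world_model, iterations=3):
    """Alternate real data collection, world model training, and policy training."""
    dataset = []
    for _ in range(iterations):
        # 1. Run the current policy in the real (expensive) environment.
        dataset.extend(collect_episodes(policy))
        # 2. Fit the world model to all data gathered so far.
        world_model = train_world_model(world_model, dataset)
        # 3. Improve the policy inside the simulated environment only.
        policy = train_policy_in_model(policy, world_model)
    return policy, world_model, dataset


# Toy stubs that only track counts, so the loop can be run end to end.
def collect(policy):
    return [("episode", policy["version"])]


def fit_model(model, dataset):
    return {"trained_on": len(dataset)}


def improve(policy, model):
    return {"version": policy["version"] + 1}


policy, model, data = simple_loop(collect, fit_model, improve,
                                  {"version": 0}, {"trained_on": 0})
```

Only step 1 touches the real environment, which is where the sample efficiency comes from.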
Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019
The metric curves are stochastic.
Predict the next point in the sequence:
Common approach: use a parametric distribution,
e.g. Gaussian:
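One standard way to write the Gaussian case, with notation assumed here rather than taken from the slides: $x_t$ is the metric value at step $t$, and $\mu_\theta$, $\sigma_\theta$ are outputs of a model with parameters $\theta$.

```latex
p_\theta(x_{t+1} \mid x_1, \dots, x_t)
  = \mathcal{N}\!\left(x_{t+1};\ \mu_\theta(x_{1:t}),\ \sigma_\theta^2(x_{1:t})\right)
```

A single Gaussian is unimodal and symmetric, which limits the shapes of noise it can represent.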
Our approach: discretize to a fixed-point representation
and predict consecutive digits (symbols).
This way we can model any distribution within a set precision.
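A minimal sketch of such a fixed-point encoding, assuming metric values in [0, 1). The base, the number of digits, and the function names are illustrative assumptions, not values from the talk.

```python
def encode_fixed_point(value, num_digits=4, base=10):
    """Encode a value in [0, 1) as `num_digits` symbols in the given base."""
    scale = base ** num_digits
    # Round to the nearest representable level and clamp into range.
    level = min(max(int(round(value * scale)), 0), scale - 1)
    digits = []
    for _ in range(num_digits):
        digits.append(level % base)
        level //= base
    return digits[::-1]  # Most significant digit first.


def decode_fixed_point(digits, base=10):
    """Invert encode_fixed_point up to the chosen precision."""
    level = 0
    for digit in digits:
        level = level * base + digit
    return level / base ** len(digits)


symbols = encode_fixed_point(0.3591)
recovered = decode_fixed_point(symbols)
```

A sequence model can then predict each symbol with a categorical distribution, so the product over digits can represent multimodal or skewed distributions up to a precision of `base ** -num_digits`.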
[Figure labels: dataset, prediction, discretization, Gaussian distribution]
Modeled sequence:
Input: both observations and actions.
Predict only observations.
Calculate rewards based on the last two observations.
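A minimal sketch of these two points. It assumes that observation and action symbols are interleaved step by step; the exact layout of the modeled sequence is not given here, so `build_sequence` is a hypothetical illustration.

```python
def rewards_from_observations(metric_values):
    """Reward at each step is the change in the chosen metric."""
    return [curr - prev for prev, curr in zip(metric_values, metric_values[1:])]


def build_sequence(observations, actions):
    """Interleave observation and action symbols and mark which to predict.

    `observations` is a list of symbol lists (one per step); `actions` has
    one entry per transition. The mask is True only for observation symbols,
    so actions are fed as input but never predicted.
    """
    symbols, predict_mask = [], []
    for step, obs_symbols in enumerate(observations):
        symbols.extend(obs_symbols)
        predict_mask.extend([True] * len(obs_symbols))
        if step < len(actions):
            symbols.append(("action", actions[step]))
            predict_mask.append(False)
    return symbols, predict_mask


rewards = rewards_from_observations([0.10, 0.30, 0.25])
sequence, mask = build_sequence([[3, 5], [3, 6]], [1])
```

Since rewards are derived from consecutive observations, the world model does not need a separate reward output.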
Inference speed: < 1 minute to sample 128 episodes.
In comparison, > 1 hour to train one real architecture.
World model is at least 128 * 60 = 7680 times faster!
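The arithmetic behind this bound, spelled out under the assumption that one sampled episode corresponds to one full training run of a real architecture:

```python
episodes_sampled = 128
world_model_minutes = 1       # Upper bound for sampling all 128 episodes.
real_training_minutes = 60    # Lower bound for one real training run.

# Time the real environment would need for the same number of episodes.
real_minutes_for_same_episodes = episodes_sampled * real_training_minutes

# Both bounds are conservative, so this is a lower bound on the speedup.
speedup_lower_bound = real_minutes_for_same_episodes / world_model_minutes
```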
Share the architecture with the world model.
Input same as for the world model.
Output: action distribution, value estimate.
Preinitialize from the parameters of a trained world model.
This leads to much faster learning.
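A minimal sketch of the preinitialization, assuming parameters are stored in a dictionary and that the shared part of the network is identifiable by a name prefix. The prefix `body/`, the head names, and the initialization scale are hypothetical.

```python
import random


def init_policy_from_world_model(world_model_params, head_sizes, seed=0):
    """Copy shared parameters from a trained world model into a new policy.

    Parameters whose names start with "body/" are copied. The policy-specific
    output heads (action distribution, value estimate) start from small
    random values, since the world model has no counterpart for them.
    """
    rng = random.Random(seed)
    policy_params = {}
    for name, weights in world_model_params.items():
        if name.startswith("body/"):
            policy_params[name] = list(weights)  # Copy, do not alias.
    for name, size in head_sizes.items():
        policy_params[name] = [rng.gauss(0.0, 0.01) for _ in range(size)]
    return policy_params


world_model_params = {
    "body/layer_0": [0.5, -0.2, 0.1],
    "head/observation": [0.3, 0.7],
}
policy_params = init_policy_from_world_model(
    world_model_params, {"head/action": 3, "head/value": 1})
```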
Experiment setup:
Final test accuracies:
| task | SimPLe | PPO | human |
|---|---|---|---|
| LM1B | 35.9% | 30.2% | 35% |
| WMT EN -> DE | 59.9% | 49.5% | 60% |
| Penn Treebank | 23.4% | 19.2% | 23.2% |
| CIFAR-10 | 91.6% | 91.2% | 90% |
Using world models allows faster training of better policies.
One of the first successful practical applications of model-based RL.
Amount of data needed currently: ~11K model trainings.
Comparable to the first work on Neural Architecture Search.
Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016
Transfer: training general policies or world models to enable wide use.
Planning using the model (model predictive control).
Evaluation in settings that are notoriously unstable (unsupervised/reinforcement learning); adaptive tuning should help.
Speaker: Piotr Kozakowski (p.kozakowski@mimuw.edu.pl)
Paper under review for ICLR 2020:
Forecasting Deep Learning Dynamics with Applications to Hyperparameter Tuning
References:
Zagoruyko et al. - Wide Residual Networks, 2016
Vaswani et al. - Attention Is All You Need, 2017
Schulman et al. - Proximal Policy Optimization Algorithms, 2017
Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019
Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016