Using Transformers to teach Transformers how to train Transformers

Piotr Kozakowski

joint work with Łukasz Kaiser and Afroz Mohiuddin

at Google Brain

Hyperparameter tuning

Hyperparameters in deep learning:

  • learning rate
  • dropout rate
  • moment decay in adaptive optimizers
  • ...

Tuning is important, but hard.

Done manually, takes a lot of time.

Needs to be re-done for every new architecture and task.

Hyperparameter tuning

Some require scheduling, which takes even more work.

Example: a step schedule for the learning rate, with \eta = 0.1, 0.02, 0.004, 0.0008 in consecutive phases of training.

Zagoruyko et al. - Wide Residual Networks, 2016
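A minimal sketch of such a step schedule. The decay factor of 0.2 is inferred from the ratios of the values listed above; the epoch boundaries and the function name are hypothetical placeholders chosen for illustration only.

```python
def step_learning_rate(epoch, base_rate=0.1, decay=0.2, boundaries=(60, 120, 160)):
    """Return the learning rate for a given epoch under a step schedule.

    The boundaries are illustrative assumptions, not values from the slides.
    """
    # Count how many boundaries have been reached so far.
    drops = sum(epoch >= boundary for boundary in boundaries)
    return base_rate * decay ** drops


for epoch in (0, 60, 120, 160):
    print(epoch, round(step_learning_rate(epoch), 6))
```

Even this simple schedule has five numbers to tune by hand (base rate, decay, three boundaries), which is the extra work the slide refers to.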

Automatic methods

Existing methods:

  • grid/random search
  • Bayesian optimization
  • evolutionary algorithms

Problems:

  • not learnable
  • hyperparameters typically fixed during training
  • typically not adaptive

Solution: reinforcement learning

Learn a policy for controlling hyperparameters based on the observed validation metrics.

Can train on a wide range of architectures and tasks.

Zero-shot or few-shot transfer to new architectures and tasks.

Long-term goal: a general system that can automatically tune any model.

Open-source the policies, so all ML practitioners can use them.

Tuning as an RL problem

agent

environment

observations: validation metrics

rewards: changes in a chosen metric

actions: hyperparameter changes (discrete)
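A toy sketch of this agent/environment interface. The class, its made-up training dynamics, and the choice of multiplicative learning-rate changes as the discrete actions are all assumptions for illustration; in the real setting each step would run part of an actual model training.

```python
import random


class ToyTuningEnv:
    """Hypothetical stand-in for a real training run (not the authors' environment)."""

    # Discrete actions: multiply the learning rate by one of these factors.
    ACTIONS = (0.5, 1.0, 2.0)

    def __init__(self, horizon=10, seed=0):
        self.horizon = horizon
        self.rng = random.Random(seed)

    def reset(self):
        self.step_count = 0
        self.learning_rate = 0.1
        self.accuracy = 0.0
        return self._observe()

    def _observe(self):
        # The agent sees only the validation metric, never the model weights.
        return (self.accuracy,)

    def step(self, action):
        self.learning_rate *= self.ACTIONS[action]
        previous = self.accuracy
        # Made-up dynamics: progress is fastest near a learning rate of 0.05, plus noise.
        gain = 0.1 / (1.0 + abs(self.learning_rate - 0.05) * 20.0)
        noise = self.rng.gauss(0.0, 0.005)
        self.accuracy = min(1.0, max(0.0, previous + gain + noise))
        self.step_count += 1
        reward = self.accuracy - previous  # reward: change in the chosen metric
        done = self.step_count >= self.horizon
        return self._observe(), reward, done


# Roll out a random policy for one episode.
env = ToyTuningEnv(horizon=5, seed=0)
observation = env.reset()
policy_rng = random.Random(1)
total_reward, done = 0.0, False
while not done:
    action = policy_rng.randrange(len(ToyTuningEnv.ACTIONS))
    observation, reward, done = env.step(action)
    total_reward += reward
```

Because each reward is a difference of consecutive metric values, the episode return equals the final metric minus the initial one.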

Tuning as an RL problem

Partially observable: observing all parameter values is intractable.

 

Nondeterministic: random weight initialization and dataset permutation.

Tasks

Language modeling:

  • Transformer on LM1B
  • Transformer on Penn Treebank

Translation:

  • Transformer on WMT EN -> DE

Image classification:

  • Wide ResNet on CIFAR-10

Examples: "I'm going to" -> "eat" / "school" / "France" (language modeling); "it's windy today" -> "heute ist es windig" (translation); image -> "frog" (image classification).

Transformer language model

Vaswani et al. - Attention Is All You Need, 2017

Tuned hyperparameters

For Transformers:

  • learning rate
  • weight decay rate
  • dropouts, separately for each layer                              

For Wide ResNet:

  • learning rate
  • weight decay rate
  • momentum mass in SGD                                               

Model-free approach: PPO

PPO: Proximal Policy Optimization

  • more sample-efficient than REINFORCE
  • stable
  • widely used

Use the Transformer language model without input embedding as a policy.

Schulman et al. - Proximal Policy Optimization Algorithms, 2017
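For reference, a minimal sketch of the clipped surrogate objective from the PPO paper, written in plain Python. The function name and the default epsilon = 0.2 are choices made for this example; the slides do not state the clipping range used in the experiments.

```python
import math


def ppo_clipped_objective(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """Average of min(r * A, clip(r, 1 - eps, 1 + eps) * A) over samples,
    where r is the probability ratio between the new and the old policy."""
    terms = []
    for new_lp, old_lp, advantage in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        terms.append(min(ratio * advantage, clipped_ratio * advantage))
    return sum(terms) / len(terms)
```

The clipping removes the incentive to move the policy far from the one that collected the data, which is one reason PPO tends to be stable.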

Model-free approach: PPO

Experiment setup:

  • optimizing for accuracy on a holdout set
  • 20 PPO epochs
  • 128 parallel model trainings in each epoch
  • ~3h per model training, ~60h for the entire experiment

Model-based approach: SimPLe

SimPLe: Simulated Policy Learning

Elements:

  • policy: \pi : O^* \rightarrow P(A)
  • "world model": \epsilon : O^* \times A \rightarrow P(O)

Train the world model on data collected in the environment.

Train the policy using PPO in the environment simulated by the world model.

Much more sample-efficient than model-free PPO.

Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019
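The alternation between real data collection, world model training, and simulated policy training can be sketched as a loop. This is only a skeleton under stated assumptions: the three callables and all names are hypothetical placeholders, and the actual procedure in the SimPLe paper has more detail.

```python
def simple_loop(collect_real_data, train_world_model, train_policy_in_simulation,
                policy, world_model, num_epochs=10):
    """Skeleton of a SimPLe-style loop; every argument is a caller-supplied placeholder."""
    dataset = []
    for _ in range(num_epochs):
        # 1. Gather trajectories in the real environment with the current policy.
        dataset.extend(collect_real_data(policy))
        # 2. Fit the world model to all data collected so far.
        world_model = train_world_model(world_model, dataset)
        # 3. Improve the policy (e.g. with PPO) inside the simulated environment.
        policy = train_policy_in_simulation(policy, world_model)
    return policy, world_model
```

Only step 1 touches the expensive real environment, which is where the sample efficiency comes from.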

Time series forecasting

The metric curves are stochastic.

Predict the next point in the sequence:

P(x_1, ..., x_n) = \prod_{i=1}^n P(x_i | x_1, ..., x_{i - 1})

Common approach: use a parametric distribution, e.g. Gaussian:

P(x_i | x_1, ..., x_{i - 1}) = \mathcal{N}(f(x_1, ..., x_{i - 1}), \sigma^2)
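As an illustration of the Gaussian factorization, the log-likelihood of a sequence is a sum of per-step Gaussian log-densities. The sketch below assumes a fixed \sigma and uses a hypothetical "repeat the last value" predictor in place of f; a real model would learn f.

```python
import math


def gaussian_log_likelihood(series, predict_mean, sigma):
    """log P(x_1, ..., x_n) under the autoregressive Gaussian model."""
    total = 0.0
    for i, x in enumerate(series):
        mean = predict_mean(series[:i])  # f(x_1, ..., x_{i-1})
        total += (-0.5 * math.log(2.0 * math.pi * sigma ** 2)
                  - (x - mean) ** 2 / (2.0 * sigma ** 2))
    return total


def last_value(history):
    """Hypothetical predictor: repeat the previous point (0.0 at the start)."""
    return history[-1] if history else 0.0


print(gaussian_log_likelihood([0.1, 0.15, 0.2], last_value, sigma=0.05))
```

A single Gaussian per step is unimodal and symmetric, so it cannot represent, for example, a metric that either jumps or stays flat.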

Time series forecasting

Our approach: discretize to a fixed-point representation and predict consecutive digits (symbols).

This way we can model any distribution within a set precision.
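A minimal sketch of such a fixed-point encoding. The base (10), the number of digits (4), the value range [0, 1), and the function names are assumptions for this example; the slides do not specify them.

```python
def encode_fixed_point(value, num_digits=4, base=10):
    """Encode a value in [0, 1) as digits, most significant first."""
    if not 0.0 <= value < 1.0:
        raise ValueError("value must be in [0, 1)")
    levels = base ** num_digits
    # Round to the nearest representable level, staying below 1.0.
    scaled = min(round(value * levels), levels - 1)
    digits = []
    for _ in range(num_digits):
        scaled, digit = divmod(scaled, base)
        digits.append(digit)
    return digits[::-1]


def decode_fixed_point(digits, base=10):
    """Invert encode_fixed_point, up to the rounding error."""
    value = 0
    for digit in digits:
        value = value * base + digit
    return value / base ** len(digits)


print(encode_fixed_point(0.3591))  # [3, 5, 9, 1]
```

A sequence model that predicts these digits one at a time outputs a categorical distribution per digit, so the implied distribution over values can be multimodal or skewed, limited only by the chosen precision.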

Experiment on synthetic data

[Figure: dataset and prediction, comparing discretization with a Gaussian distribution]

Transformer as a world model

Modeled sequence:

o_1 a_1 o_2 a_2 \dots o_n

Input: both observations and actions.

Predict only observations.

Calculate rewards based on the last two observations.
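A small sketch of how the interleaved sequence o_1 a_1 o_2 a_2 ... o_n could be built and how a reward could be derived from two consecutive observations. Representing observations as dictionaries of metrics, the ("o", ...) / ("a", ...) tags, and the function names are assumptions for illustration.

```python
def build_sequence(observations, actions):
    """Interleave observations and actions: o_1 a_1 o_2 a_2 ... o_n."""
    if len(observations) != len(actions) + 1:
        raise ValueError("expected one more observation than actions")
    sequence = []
    for observation, action in zip(observations, actions):
        sequence.append(("o", observation))
        sequence.append(("a", action))
    sequence.append(("o", observations[-1]))
    return sequence


def reward_from_observations(previous, current, metric="accuracy"):
    """Reward as the change in the chosen metric between two observations."""
    return current[metric] - previous[metric]


observations = [{"accuracy": 0.10}, {"accuracy": 0.25}, {"accuracy": 0.30}]
actions = ["lr*0.5", "lr*1.0"]
sequence = build_sequence(observations, actions)
rewards = [reward_from_observations(a, b)
           for a, b in zip(observations, observations[1:])]
```

Since rewards follow from the predicted observations, the world model does not need a separate reward head.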

Transformer as a world model

Inference speed: < 1 minute to sample 128 episodes.

In comparison, > 1 hour to train one real architecture.

Per episode, the world model is at least 128 * 60 = 7680 times faster!

Transformer as a policy

Share the architecture with the world model.

Input same as for the world model.

Output: action distribution, value estimate.

 

Preinitialize from the parameters of a trained world model.

This leads to much faster learning.

SimPLe results

Experiment setup:

  • optimizing for accuracy on a holdout set
  • starting from a dataset of 4 * 20 * 128 = 10240 trajectories collected by PPO
  • 10 SimPLe epochs, 50 simulated PPO epochs each
  • 128 parallel model trainings in each data gathering phase
  • ~3h for data gathering, ~1h for world model training, ~2h for policy training, ~60h for the entire experiment
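The budget figures in this setup can be checked with a few lines of arithmetic. The sketch assumes that each of the 10 SimPLe epochs adds 128 new real trajectories on top of the initial dataset; that reading is consistent with the ~11K model trainings quoted in the summary, but it is an inference rather than something stated explicitly.

```python
# Wall-clock budget per SimPLe epoch, in hours (figures from the list above).
data_gathering_hours = 3
world_model_hours = 1
policy_hours = 2
simple_epochs = 10
total_hours = simple_epochs * (data_gathering_hours + world_model_hours + policy_hours)

# Real model trainings: the initial PPO dataset plus (assumed) 128 per SimPLe epoch.
initial_trajectories = 4 * 20 * 128
trajectories_per_epoch = 128
total_trajectories = initial_trajectories + simple_epochs * trajectories_per_epoch

print(total_hours)         # 60
print(total_trajectories)  # 11520
```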

SimPLe vs PPO vs human

Final test accuracies:

task            SimPLe   PPO     human
LM1B            35.9%    30.2%   35%
WMT EN -> DE    59.9%    49.5%   60%
Penn Treebank   23.4%    19.2%   23.2%
CIFAR-10        91.6%    91.2%   90%

Learned schedules - LM1B

Summary

Using world models allows faster training of better policies.

 

One of the first successful practical applications of model-based RL.

 

Amount of data needed currently: ~11K model trainings.

Comparable to the first work on Neural Architecture Search.

Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016

Future work

Transfer: training general policies or world models to enable wide use.

 

Planning using the model (model predictive control).

 

Evaluation in settings that are notoriously unstable (unsupervised/reinforcement learning); adaptive tuning should help.

Speaker: Piotr Kozakowski (p.kozakowski@mimuw.edu.pl)

Paper under review for ICLR 2020:

Forecasting Deep Learning Dynamics with Applications to Hyperparameter Tuning

References:

Zagoruyko et al. - Wide Residual Networks, 2016

Vaswani et al. - Attention Is All You Need, 2017

Schulman et al. - Proximal Policy Optimization Algorithms, 2017

Kaiser et al. - Model-Based Reinforcement Learning for Atari, 2019

Zoph et al. - Neural Architecture Search with Reinforcement Learning, 2016
