Lecture 09:

Memory in RL

Artyom Sorokin | 06 Apr

Memory is Important

in many tasks, but we start from Reinforcement Learning...

Working Memory

Markov Decision Process  

Basic Theoretical Results in Model-Free Reinforcement Learning are proved for Markov Decision Processes.

 

Markovian property:

\(p(s_{t+1}, r_{t} \mid s_t, a_t, s_{t-1}, a_{t-1}, \dots, s_{1}, a_{1}, s_{0}) = p(s_{t+1}, r_{t} \mid s_t, a_t)\)

In other words: "The future is independent of the past given the present."

When does the agent observe the state?

Partially Observable MDP

Definition

Graphical Model for POMDP:

POMDP is a 6-tuple \(<S,A,R,T,\Omega, O>\):

  • \(S\) is a set of states
  • \(A\) is a set of actions
  • \(R: S \times A \to \mathbb{R}\) is a reward function
  • \(T: S \times A \times S \to [0,1]\) is a transition function \(T(s,a,s\prime) =  P(S_{t+1}=s\prime|S_t=s, A_t=a)\)
  • \(O\) is a set of observations.
  • \(\Omega\) is a set of conditional observation distributions \(P(o|s)\), one for each state \(s \in S\)

Partially Observable MDP

 Exact Solution

A proper belief state allows a POMDP to be formulated as an MDP over belief states (Astrom, 1965)

 

General Idea:

Belief Update → "Belief" MDP → Plan with Value Iteration → Policy

Belief State update:

\(b_0(s) = P(S_0=s)\)
\(b_{t+1}(s) = p(s \mid o_{t+1}, a_{t}, b_{t}) = \dfrac{p(s, o_{t+1} \mid a_t, b_t)}{p(o_{t+1} \mid a_t, b_t)}\)
\(\propto p(o_{t+1} \mid s, a_t, b_t)\, p(s \mid a_t, b_t)\)
\(= p(o_{t+1} \mid s) \sum_{s_i} p(s \mid a_t, s_i)\, b_t(s_i)\)
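The belief update above can be sketched for a small discrete POMDP as follows. This is a minimal illustration, not a reference implementation: the function name `belief_update`, the nested-list layout of `trans` and `obs_model`, and the two-state example are all assumptions made for this sketch.

```python
def belief_update(belief, action, observation, trans, obs_model):
    """One exact belief update: b'(s') is proportional to p(o|s') * sum_s p(s'|a,s) * b(s).

    belief:    list of probabilities over states, b_t(s)
    trans:     trans[a][s][s_next] = p(s_next | a, s)
    obs_model: obs_model[s][o] = p(o | s)
    """
    n = len(belief)
    # Prediction step: p(s' | a_t, b_t) = sum_s p(s' | a_t, s) b_t(s)
    predicted = [sum(trans[action][s][s2] * belief[s] for s in range(n))
                 for s2 in range(n)]
    # Correction step: weight by the observation likelihood p(o_{t+1} | s')
    unnorm = [obs_model[s2][observation] * predicted[s2] for s2 in range(n)]
    evidence = sum(unnorm)  # p(o_{t+1} | a_t, b_t), the normalizer
    if evidence == 0.0:
        raise ValueError("observation has zero probability under the model")
    return [u / evidence for u in unnorm]


# Hypothetical two-state example: action 0 leaves the state unchanged and the
# observation matches the true state with probability 0.85.
trans = [[[1.0, 0.0], [0.0, 1.0]]]
obs_model = [[0.85, 0.15], [0.15, 0.85]]
b0 = [0.5, 0.5]
b1 = belief_update(b0, action=0, observation=0, trans=trans, obs_model=obs_model)
b2 = belief_update(b1, action=0, observation=0, trans=trans, obs_model=obs_model)
```

Both loops run over all states, which is exactly the sum that becomes intractable for large or continuous state spaces.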

Problems: 

  • Need a model: \(p(o|s)\) and \(p(s'| a, s)\)
  • Can compute exact belief update only for small/simple MDP

 

Huge sum or even integral over all states

Learning in POMDP

Choose your fighter

Learning in POMDP

Don't Give Up and Approximate

Approximate belief states:

  • Deep Variational Belief Filters (Karl et al, ICLR 2017)
  • Deep Variational Reinforcement Learning for POMDPs (Igl et al, ICML 2018)
  • Discriminative Particle Filter Reinforcement Learning(Ma et al, ICLR 2020)

(Ma et al, ICLR 2020)

Learning in POMDP

Look into the Future

Predictive State Representations:

  • Predictive State Representations (Singh et al, 2004) 
  • Predictive-State Decoders: Encoding the Future into RNN (Venkatraman et al, NIPS 2017)
  • Recurrent Predictive State Policy Networks (Hefny et al, 2018)

Learning in POMDP

Relax and Use Memory

 

 

Window-based Memory:

  • Control of Memory, Active Perception, and Action in Minecraft (Oh et al, 2016) 
  • Stabilizing Transformers For Reinforcement Learning (Parisotto et al, 2019)
  • Obstacle Tower Challenge winner solution (Nichols, 2019)* 

 

Memory as RL problem:

  • Learning Policies with External Memory (Peshkin et al, 2001)
  • Reinforcement Learning Neural Turing Machines (Zaremba et al, 2015)
  • Learning Deep NN Policies with Continuous Memory States (Zhang et al, 2015)

Recurrent Neural Networks:

  • DRQN (Hausknecht et al, 2015)
  • A3C-LSTM (Mnih et al, 2016)
  • Neural Map (Parisotto et al, 2017)
  • MERLIN (Wayne et al, 2018)
  • Relational Recurrent Neural Networks (Santoro et al, 2018)
  • Aggregated Memory for Reinforcement Learning (Beck et al, 2020)

Why Do We Need Working Memory?

[Figure: a sequence of observations \(obs_{t=10}, obs_{t=12}, \dots, obs_{t=20}\) and actions \(act_{t=10}, act_{t=12}, act_{t=13}, \dots\) with a Memory block; annotations: "We need this information", "At this moment!"]

Recurrent Memory:

[Figure: observations \(obs_{t=10}, obs_{t=12}, \dots, obs_{t=20}\) and actions \(act_{t=10}, act_{t=12}, \dots, act_{t=20}\) connected through hidden states \(h_9, h_{10}, \dots, h_{19}\); arrows labeled "Information" and "Gradients"]

Problem with RNNs:

Solutions for Vanishing Gradients

Fight RNN problems by building more complex RNNs

Long Short-Term Memory: LSTM

Differentiable Neural Computer: DNC

Window Based Memory:

[Figure: observations \(obs_{t=10}, obs_{t=12}, \dots, obs_{t=20}\) and actions \(act_{t=10}, act_{t=12}, \dots, act_{t=20}\) inside a "Memory Window"; arrows labeled "Information" and "Gradients"]
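A window-based memory can be sketched as a fixed-length buffer of recent steps. This is only an illustration of the idea: the class name `WindowMemory` and the string placeholders for observations and actions are hypothetical.

```python
from collections import deque


class WindowMemory:
    """Keeps only the last `window` (observation, action) pairs."""

    def __init__(self, window):
        self.buffer = deque(maxlen=window)  # old entries fall out automatically

    def add(self, obs, act):
        self.buffer.append((obs, act))

    def context(self):
        # Everything the policy is allowed to look at on this step.
        return list(self.buffer)


memory = WindowMemory(window=3)
for t in range(10, 21):
    memory.add(f"obs_{t}", f"act_{t}")
```

Anything older than the window is simply gone, so the window length bounds how far back both information and gradients can reach.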

Soft-Attention:

[Figure: observations \(obs_{t=10}, obs_{t=12}, \dots, obs_{t=20}\) are mapped to embeddings \(e_{10}, e_{12}, \dots\); each embedding is multiplied (\(\times\)) with the query \(q\) to give attention weights \(a_{10}, a_{12}, \dots\)]

Attention weight: \(a_t = e_t^T q / \sum_i e_i^T q\)

Context vector: \(c_{20} = \sum_t a_t e_t\)
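The attention weights and the context vector can be computed as in the sketch below. It follows the plain normalization written on the slide; the `use_softmax` flag is an addition for this sketch (exponentiating the scores first is the more common choice and keeps the weights positive when dot products are negative). The function names and the toy vectors are hypothetical.

```python
import math


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def soft_attention(embeddings, query, use_softmax=False):
    """Returns (weights, context) for a list of embedding vectors and a query."""
    scores = [dot(e, query) for e in embeddings]
    if use_softmax:
        # Common variant (an assumption, not on the slide): exponentiate first.
        m = max(scores)
        scores = [math.exp(s - m) for s in scores]
    total = sum(scores)
    weights = [s / total for s in scores]
    dim = len(embeddings[0])
    context = [sum(w * e[i] for w, e in zip(weights, embeddings)) for i in range(dim)]
    return weights, context


embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy e_10, e_12, e_20
query = [2.0, 1.0]
weights, context = soft_attention(embeddings, query)
```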

Transformer Basics: Self-Attention

Transformer Basics: Self-Attention

Self-Attention computation:

Each \(z_t\) contains relevant information about \(o_t\) collected over all steps in the memory window.

Full Transformer

This is how a real Transformer looks (kind of...):

Transformer:

  • All time-steps in a memory window attend to all other time-steps
  • For each time-step you have: Query \(q_t\), Key \(k_t\), Value \(v_t\)
  • Transformer has N attention heads!
  • Transformer uses positional encoding to relate different time steps temporally
  • You repeat this process for several layers!
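A single attention head over one memory window can be sketched as below. This is a minimal version of scaled dot-product attention under simplifying assumptions: one head, no positional encoding, no learned projections, and in the toy example the queries, keys, and values are the same vectors. All names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(queries, keys, values):
    """Single-head scaled dot-product attention over one memory window."""
    d = len(keys[0])
    outputs = []
    for q in queries:  # every time step attends to all time steps
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        z = [sum(w * v[j] for w, v in zip(weights, values))
             for j in range(len(values[0]))]
        outputs.append(z)
    return outputs


# Toy window of three steps; here q_t = k_t = v_t for brevity (an assumption).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
z = self_attention(x, x, x)
```

A full Transformer would run several such heads in parallel on learned projections of the inputs and stack the result over several layers.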

Transformer

Attention is All You Need!

AutoEncoders

 

Variational AutoEncoder: VAE

AutoEncoder

Just add LSTM to everything

Off-Policy Learning (DRQN):

  • Add LSTM before last 1-2 layers
  • Sample sequences of steps from Experience Replay

 

On-Policy Learning (A3C/PPO):

  • Add LSTM before last 1-2 layers
  • Keep LSTM hidden state \(h_t\) between rollouts 
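Keeping the hidden state between rollouts can be sketched as follows. The recurrent cell here is a toy stand-in, not an LSTM, and the names `run_rollout` and `toy_cell` as well as the reset-to-zero convention are assumptions made for this sketch.

```python
def run_rollout(steps, hidden, step_fn, initial_hidden=0.0):
    """Processes one rollout and returns the hidden state for the next rollout."""
    for obs, done in steps:
        hidden = step_fn(obs, hidden)
        if done:
            hidden = initial_hidden  # a new episode starts with fresh memory
    return hidden


def toy_cell(obs, hidden):
    # Stand-in for an LSTM step (hypothetical): a leaky running sum.
    return 0.5 * hidden + obs


h = run_rollout([(1.0, False), (1.0, False)], hidden=0.0, step_fn=toy_cell)
h_next = run_rollout([(1.0, False)], hidden=h, step_fn=toy_cell)
h_reset = run_rollout([(1.0, True)], hidden=h_next, step_fn=toy_cell)
```

The second rollout starts from the state the first one ended in, so memory is not cut at rollout boundaries, only at episode boundaries.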

 

Asynchronous Methods for Deep Reinforcement Learning (Mnih et al,  2016) | DeepMind, ICML, 3113 citations

 

Deep Recurrent Q-Learning for Partially Observable MDPs (Hausknecht et al, 2015) | AAAI, 582 citations

Just add LSTM to everything

Default choice for memory in big projects

"To deal with partial observability, the temporal sequence of observations is processed by a deep long short-term memory (LSTM) system"

AlphaStar: Grandmaster level in StarCraft II using multi-agent reinforcement learning (Vinyals et al, 2019) | DeepMind, Nature, 16 citations

Just add LSTM to everything

Default choice for memory in big projects

"The LSTM composes 84% of the model’s total parameter count."

Dota 2 with Large Scale Deep Reinforcement Learning (Berner et al,  2019) | OpenAI, 17 Citations

R2D2: We can do better

DRQN tests two sampling methods:

  • Sample full episode sequences
    • Problem: sample correlation in mini-batch is proportional to the sequence length 
  • Sample random sub-sequences of length k (10 steps in the paper)
    • Problem: initial hidden state is zero at the start of a rollout

Recurrent Experience Replay in Distributed Reinforcement  Learning (Kapturowski et al,  2019) | DeepMind, ICLR, 49 citations

R2D2: We can do better

R2D2 is a DRQN built on top of Ape-X (Horgan et al, 2018) with the addition of two heuristics:

  • Stored state: Storing the recurrent state in replay and using it to initialize the network at training time
  • Burn-in: Use a portion of the replay sequence only for unrolling the network and producing a start state, and update the network only on the remaining part of the sequence

 

 

 

 

 

 

 

 

 

Burn-in - 40 steps, full rollout - 80 steps
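The two heuristics can be sketched on the replay side as follows. This only shows how a sampled sequence is split and where the stored state comes from; the episode layout (a list of dicts with a `"hidden"` field) and the function names are assumptions for this sketch, and no network is actually trained.

```python
import random


def sample_sequence(episode, seq_len, rng=random):
    """Picks a random contiguous sub-sequence and the hidden state stored at its start."""
    start = rng.randrange(0, len(episode) - seq_len + 1)
    steps = episode[start:start + seq_len]
    stored_state = steps[0]["hidden"]  # recurrent state saved by the actor
    return steps, stored_state


def split_burn_in(steps, burn_in):
    """First `burn_in` steps only refresh the hidden state; the rest are trained on."""
    return steps[:burn_in], steps[burn_in:]


# Hypothetical episode of 200 steps; "hidden" stands for the actor's recurrent
# state just before it processed that step.
episode = [{"obs": t, "hidden": f"h_{t - 1}"} for t in range(200)]
steps, h0 = sample_sequence(episode, seq_len=80, rng=random.Random(0))
warmup, train = split_burn_in(steps, burn_in=40)
```

In training, the network would be initialized with `h0`, unrolled over `warmup` without computing a loss, and updated only on `train`.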


R2D2

Results: Atari-57

R2D2

Results: DMLab-30


AMRL

Motivation

Recurrent Neural Networks:

  • good at tracking order of observations

  • susceptible to noise in observations

  • bad at long-term dependencies

RL Tasks:

  • order often doesn't matter

  • high variability in observation sequences

  • long-term dependencies

AMRL: Aggregated Memory For Reinforcement Learning (2020) | MS Research, ICLR

AMRL:

Robust Aggregators

Add aggregators that ignore order of observations: 













Aggregators also act as residual skip connections across time.

Instead of true gradients, a straight-through estimator (Bengio et al., 2013) is used.
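Order invariance of an aggregator is easy to check numerically. The sketch below uses sum, average, and max as example aggregators; these choices, the function name `aggregate`, and the toy feature vectors are assumptions for illustration, and the straight-through gradient estimator is not shown.

```python
def aggregate(history, kind):
    """Order-invariant summaries of a list of feature vectors."""
    dim = len(history[0])
    columns = [[h[i] for h in history] for i in range(dim)]
    if kind == "sum":
        return [sum(c) for c in columns]
    if kind == "avg":
        return [sum(c) / len(c) for c in columns]
    if kind == "max":
        return [max(c) for c in columns]
    raise ValueError(f"unknown aggregator: {kind}")


features = [[0.5, 2.0], [0.25, -1.0], [1.0, 0.5]]
shuffled = [features[2], features[0], features[1]]
```

Calling `aggregate(features, kind)` and `aggregate(shuffled, kind)` gives the same result for each of the three kinds, which is the property that makes such summaries robust to reordered observations.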


 


AMRL

Architecture and Baselines


AMRL

Experiments


AMRL

Experiments

AMRL

Experiments


Are rewards enough to learn memory?

  • Learning only by optimizing future rewards

    • A3C/PPO + LSTM

    • DRQN, R2D2
    • AMRL

 

  • What if we know a little bit more?

    • Neural Map (Parisotto et al, 2017)
    • Working Memory Graphs (Loynd et al, 2020)

 

  • Learning with rich self-supervised sensory signals

    • World Models (Ha et al, 2018, Mar 27*)
    • MERLIN (Wayne et al, 2018, Mar 28*)

MERLIN

Unsupervised Predictive Memory in a Goal-Directed Agent (2018) | DeepMind, 67 citations

MERLIN has two basic components:

  • Model-Based Predictor

    • a monstrous combination of VAE and Q-function estimator
    • uses simplified DNC under the hood
  • Policy

    • no gradients flow between policy and MBR
    • trained with Policy Gradients and GAE 

MERLIN is trained on-policy in A3C-like manner:

  • 192 parallel workers, 1 parameter server
  • rollout length is 20-24 steps

 

VAE:

DNC:

MERLIN


Model-Based Predictor:

  • MBR is optimized to be a "world model", i.e. to produce predictions that are consistent with observed trajectories: \(p(o_{0:T}, R_{0:T})\)

  • Compresses observations \(o_t\) into low-dimensional state representations \(z_t\) (inspired by compressive bottleneck theory)

  • Stores all previous representations \(z_{0:t}\) in memory

  • Uses its memory to predict the future: \(o_{t+1}\) and \(R_{t+1}\)

[Figure annotations: "We do this with VAE" (\(o_t \to z_t \to o_t\)); "we can just predict \(z_{t+1}\)"; \(R_t\); "DNC lives here"]

MERLIN

MBR Loss

Model-Based Predictor has a loss function based on the variational lower bound:

 

 

 

 

 

Reconstruction Loss: 





KL Loss:

Forces MBR to make good compressed representations \(z_t\); recall VAE

Train memory to make good predictions

We want MBR to make good predictions of observations and Q-values

\(z_{t}\) prediction from memory (steps: 0 .. t-1)

actual compressed representation \(z_t\) given \(o_t\) and memory from (steps: 0 .. t-1)
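For orientation, a generic per-step variational lower bound with a memory-conditioned prior has the shape below. This is a sketch in the notation of these slides, not the exact MERLIN objective: \(q\) stands for the posterior over \(z_t\), \(p(z_t \mid m_{t-1})\) for the prior predicted from memory, and conditioning only on \(m_{t-1}\) is a simplifying assumption.

```latex
\log p(o_t, R_t \mid m_{t-1}) \;\ge\;
\underbrace{\mathbb{E}_{q(z_t \mid o_t, m_{t-1})}\left[\log p(o_t, R_t \mid z_t, m_{t-1})\right]}_{\text{reconstruction term}}
\;-\;
\underbrace{D_{KL}\left(q(z_t \mid o_t, m_{t-1}) \,\|\, p(z_t \mid m_{t-1})\right)}_{\text{KL term}}
```

The KL term compares the posterior (the actual compressed representation given \(o_t\) and memory) with the prior (the prediction of \(z_t\) from memory alone).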

MERLIN

Memory-Based Predictor

Prior Distribution

The module \(f^{prior}\) takes all memory from the previous step and produces the parameters of a diagonal Gaussian distribution:

\(\textcolor{blue}{[\mu_{t}^{prior}, \log \Sigma^{prior}_{t}]} = f^{prior}(h_{t-1}, m_{t-1})\)

Posterior Distribution

Another MLP \(f^{post}\) takes:

\(n_t = [e_t, h_{t-1}, m_{t-1}, \textcolor{blue}{\mu_{t}^{prior}, \log \Sigma^{prior}_{t}}]\)

and generates a correction for the prior:

\(\textcolor{red}{[\mu^{post}_{t}, \log \Sigma^{post}_{t}]} = f^{post}(n_t) + \textcolor{blue}{[\mu^{prior}_{t}, \log \Sigma^{prior}_{t}]}\)

At the end, the latent state variable \(z_t\) is sampled from the posterior distribution.
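The prior-plus-correction construction, the sampling of \(z_t\), and the KL term between the two diagonal Gaussians can be sketched as below. Two assumptions are made here: \(\log \Sigma\) is read as a vector of per-dimension log-variances, and the outputs of \(f^{prior}\) and \(f^{post}\) are passed in as plain lists instead of being computed by networks. The function names are hypothetical.

```python
import math
import random


def posterior_params(prior_mu, prior_logvar, correction_mu, correction_logvar):
    """Posterior = prior parameters plus the correction produced by f_post."""
    mu = [p + c for p, c in zip(prior_mu, correction_mu)]
    logvar = [p + c for p, c in zip(prior_logvar, correction_logvar)]
    return mu, logvar


def sample_diag_gaussian(mu, logvar, rng=random):
    # Reparameterization: z = mu + sigma * eps, eps ~ N(0, 1)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
        for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p)
    )


prior_mu, prior_logvar = [0.0, 0.0], [0.0, 0.0]
post_mu, post_logvar = posterior_params(prior_mu, prior_logvar, [1.0, 0.0], [0.0, 0.0])
z = sample_diag_gaussian(post_mu, post_logvar, rng=random.Random(0))
kl = kl_diag_gaussians(post_mu, post_logvar, prior_mu, prior_logvar)
```

A zero correction makes the posterior equal to the prior and the KL term vanish, so the correction measures how much the new observation surprised the memory.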

MERLIN

Architecture


MERLIN

Experiments

MERLIN is compared against two baselines: A3C-LSTM, A3C-DNC

 

MERLIN

Experiments


Stabilizing Transformers for RL

Stabilizing Transformers For Reinforcement Learning (2019) | DeepMind

Stabilizing Transformers for RL

Gating Layer


Stabilizing Transformers for RL

Experiments


Stabilizing Transformers for RL

Ablation Study


Episodic Memory

Memory for MDP?

Neural Networks can't adapt fast:

  • Catastrophic interference, i.e. knowledge in neural networks is non-local

  • Nature of  Gradient Descent

Episodic Memory

Motivation

Semantic memory makes better use of experiences (i.e. better generalization)

Episodic memory requires fewer experiences (i.e. more accurate)

Episodic Memory

Experiment with Tree MDP

"We will show that in general, just as model-free control is better than model-based control after substantial experience, episodic control is better than model-based control after only very limited experience."

                                          (Lengyel, Dayan, NIPS 2008)

 

A Tree MDP is just an MDP without cycles.

 

Episodic Memory

Experiment with Tree MDP


A) Tree MDP with branching factor = 2

B) Tree MDP with branching factor = 3

C) Tree MDP with branching factor = 4

Model-Free Episodic Control

Let's store all past experiences in \(|A|\) dictionaries \(Q_{a}^{EC} \)


\(s_t, a_t \) are keys and discounted future rewards \(R_t\) are values.

 

Dictionary update:

 

 

 

If a state space has a meaningful distance, then we can use k-nearest neighbours to estimate new \((s,a)\) pairs:

 

Model-Free Episodic Control (2016) | DeepMind, 100 citations


Two possible feature compressors for \(s_t\): Random Network, VAE
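The per-action dictionaries and the k-nearest-neighbours estimate can be sketched as below. The update formula is not shown on the slide; this sketch assumes the rule of keeping the highest return seen so far for a stored key. States are assumed to be already compressed into small tuples, and the class name, the default `k`, and the value returned for an empty table are hypothetical.

```python
class EpisodicTable:
    """One Q^EC_a dictionary per action, keyed by an (embedded) state."""

    def __init__(self, n_actions, k=3):
        self.tables = [dict() for _ in range(n_actions)]
        self.k = k

    def update(self, state, action, ret):
        # Assumed rule (missing on the slide): keep the best return seen so far.
        table = self.tables[action]
        table[state] = max(table.get(state, float("-inf")), ret)

    def estimate(self, state, action):
        table = self.tables[action]
        if state in table:
            return table[state]
        if not table:
            return 0.0  # hypothetical default for an empty table
        # Average over the k nearest stored states (squared Euclidean distance).
        nearest = sorted(table, key=lambda s: sum((x - y) ** 2 for x, y in zip(s, state)))
        nearest = nearest[:self.k]
        return sum(table[s] for s in nearest) / len(nearest)


memory = EpisodicTable(n_actions=2, k=2)
memory.update((0.0, 0.0), 0, 1.0)
memory.update((0.0, 0.0), 0, 0.5)   # lower return, stored value stays 1.0
memory.update((1.0, 0.0), 0, 3.0)
memory.update((5.0, 5.0), 0, 10.0)
q_new = memory.estimate((0.9, 0.0), 0)
```

The sorted scan is linear in the table size; a practical version would need an approximate nearest-neighbour index and a limit on how many entries are kept.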

 

Model-Free Episodic Control

Results

Test environments:

  • Some games from Atari 57
  • 3D Mazes in DMLab

Neural Episodic Control

Deep RL + Semantic Memory

Differences with Model-Free Episodic Control:
  • CNN instead of VAE/RandomNet
  • Differentiable Neural Dictionaries (DND)
  • Replay Memory like in DQN, but small...
  • CNN and DND learn with gradient descent

Neural Episodic Control (2017) | DeepMind, 115 citations

Neural Episodic Control

Differentiable Neural Dictionaries

For each action  \(a \in A \), NEC has a dictionary \(M_a = (K_a , V_a )\).

Keys and Queries are generated by CNN

 

Neural Episodic Control

Differentiable Neural Dictionaries

To estimate \(Q(s_t, a)\) we sample p-nearest neighbors from \(M_a\)

\(k\) is a kernel for distance estimate. In experiments:
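A lookup in one dictionary \(M_a\) can be sketched as a kernel-weighted average over the nearest keys. The kernel used in the experiments is not shown on the slide; this sketch assumes an inverse squared distance kernel with a small constant `delta`. The function name, the exhaustive nearest-neighbour search, and the toy keys are also assumptions.

```python
def dnd_lookup(query, keys, values, p=2, delta=1e-3):
    """Kernel-weighted average of the values of the p nearest keys."""
    def sq_dist(u, v):
        return sum((x - y) ** 2 for x, y in zip(u, v))

    nearest = sorted(range(len(keys)), key=lambda i: sq_dist(query, keys[i]))[:p]
    # Assumed kernel: inverse squared distance, k(h, h_i) = 1 / (||h - h_i||^2 + delta)
    kernels = [1.0 / (sq_dist(query, keys[i]) + delta) for i in nearest]
    total = sum(kernels)
    return sum(k / total * values[i] for k, i in zip(kernels, nearest))


keys = [(0.0, 0.0), (1.0, 0.0), (5.0, 5.0)]
values = [1.0, 3.0, 10.0]
q_near = dnd_lookup((1.0, 0.0), keys, values, p=2)
q_single = dnd_lookup((1.0, 0.0), keys, values, p=1)
```

Because the output is a smooth function of the query and of the stored keys and values, gradients can flow through the lookup into the encoder.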

 

Neural Episodic Control

DND Update

Once a key \(h_t\) is queried from a DND, that key and its corresponding output are appended to the DND. If \(h_t \not\in M_a\) then we just store it with N-step Q-value estimate:

 

 

 

 

otherwise, we update stored value with tabular Q-learning rule:

 

Learn DND and CNN-encoder:

Sample mini-batches from replay buffer that stores triplets \((s_t,a_t, R_t)\) and use \(R_t\) as a target.
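The DND write rule described above (append a new key, otherwise take a tabular Q-learning step toward the N-step estimate) can be sketched as below. The formulas are missing from the slide, so the exact forms used here are assumptions: the target is the discounted sum of N rewards plus a discounted bootstrap value, and the update is a step of size `alpha`. Keys are matched by exact equality, and the function names are hypothetical.

```python
def n_step_target(rewards, bootstrap_value, gamma=0.99):
    """Discounted sum of the given rewards plus a discounted bootstrap value."""
    target = bootstrap_value
    for r in reversed(rewards):
        target = r + gamma * target
    return target


def dnd_write(keys, values, key, q_target, alpha=0.1):
    """Appends a new key, or moves the stored value toward the N-step target."""
    if key in keys:
        i = keys.index(key)
        values[i] += alpha * (q_target - values[i])  # tabular Q-learning style step
    else:
        keys.append(key)
        values.append(q_target)


keys, values = [], []
dnd_write(keys, values, (0.0, 0.0), 1.0)             # new key: stored as is
dnd_write(keys, values, (0.0, 0.0), 2.0, alpha=0.5)  # existing key: 1.0 -> 1.5
target = n_step_target([1.0, 1.0], bootstrap_value=10.0, gamma=0.5)
```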


Replay Buffer here \(\ne\) DND

Neural Episodic Control

Experiments

Neural Episodic Control

Experiments


Summary

Working Memory

  • Can Generalize
  • Hard to increase capacity

 

Episodic Memory

  • Hard to generalize
  • Can store a lot of information

Thank you for your attention!