Artyom Sorokin | 08 Apr
in many tasks, but we start with Reinforcement Learning...
Basic Theoretical Results in Model-Free Reinforcement Learning are proved for Markov Decision Processes.
Markov property:
In other words: "The future is independent of the past given the present."
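In symbols, this is one common way to write it for states \(s_t\) and actions \(a_t\). The distribution of the next state depends only on the current state and action, not on the rest of the history:

```latex
P(s_{t+1} \mid s_t, a_t) = P(s_{t+1} \mid s_0, a_0, s_1, a_1, \ldots, s_t, a_t)
```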
Graphical Model for POMDP:
POMDP is a 6-tuple \(\langle S, A, R, T, \Omega, O \rangle\):
A proper belief state allows a POMDP to be formulated as an MDP over belief states (Åström, 1965).
Belief State update:
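For a discrete POMDP, the belief update is a Bayes filter: \(b'(s') = \eta \, O(o \mid s', a) \sum_s T(s' \mid s, a) \, b(s)\), where \(\eta\) normalizes the result. Below is a minimal sketch. It assumes \(T\) and \(O\) are stored as nested dictionaries, which is a hypothetical layout chosen for illustration, and the two-state "listening" example is made up. The inner sum over all states \(s\) is the part that becomes intractable for large or continuous state spaces.

```python
def belief_update(belief, action, obs, T, O):
    """One step of the discrete POMDP belief update (a Bayes filter).

    belief: dict state -> probability (every state must appear as a key)
    T: dict (s, a) -> dict s' -> P(s' | s, a)     (hypothetical layout)
    O: dict (s', a) -> dict o -> P(o | s', a)      (hypothetical layout)
    """
    states = list(belief)
    unnormalized = {}
    for s_next in states:
        # Predict: probability of reaching s_next under the current belief.
        predicted = sum(belief[s] * T[(s, action)].get(s_next, 0.0) for s in states)
        # Correct: weight by how likely the observation is in s_next.
        unnormalized[s_next] = O[(s_next, action)].get(obs, 0.0) * predicted
    eta = sum(unnormalized.values())
    if eta == 0.0:
        raise ValueError("observation has zero probability under the model")
    return {s: p / eta for s, p in unnormalized.items()}


# Made-up two-state example: listening does not move the agent,
# and the observation is correct 85% of the time.
T = {(s, "listen"): {s: 1.0} for s in ("left", "right")}
O = {
    ("left", "listen"): {"hear_left": 0.85, "hear_right": 0.15},
    ("right", "listen"): {"hear_left": 0.15, "hear_right": 0.85},
}
b = belief_update({"left": 0.5, "right": 0.5}, "listen", "hear_left", T, O)
```

Starting from a uniform belief, hearing "left" once moves the belief in "left" to about 0.85; hearing it again pushes it higher.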
General Idea:
Belief update → "belief" MDP → plan with value iteration → policy
Problems:
Huge sum or even integral over all states
(Ma et al, ICLR 2020)
[Figure: a timeline of observations \(obs_t\) and actions \(act_t\). Information from early steps (\(t=10\), \(t=12\)) is needed at a much later step: "We need this information at this moment!" Carrying it forward is the job of memory.]
[Figure: recurrent memory. Hidden states \(h_9, h_{10}, \ldots, h_{19}\) carry information from \(obs_{t=10}\), \(act_{t=10}\) forward to \(t=20\); gradients flow backward through the same chain of steps.]
Long Short-Term Memory: LSTM
Differentiable Neural Computer: DNC
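A single LSTM step can be sketched in plain Python, with no deep learning library, to show how the cell state carries information across time. The DNC is not sketched here. The sizes, the random initialization and the input layout (current observation concatenated with the previous hidden state) are assumptions for illustration. In an RL agent the input would often also include the previous action.

```python
import math
import random


def matvec(W, x):
    """Matrix (list of rows) times vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def init_params(n_in, n_hidden, seed=0):
    """Small random weights and zero biases for the four LSTM gates (illustrative init)."""
    rng = random.Random(seed)

    def mat():
        return [[rng.uniform(-0.1, 0.1) for _ in range(n_in + n_hidden)] for _ in range(n_hidden)]

    return {g: (mat(), [0.0] * n_hidden) for g in ("input", "forget", "output", "cell")}


def lstm_step(params, x, h, c):
    """One LSTM step: returns the new hidden state h and cell state c."""
    z = x + h  # concatenate input and previous hidden state (Python lists)

    def gate(name, act):
        W, b = params[name]
        return [act(v + bj) for v, bj in zip(matvec(W, z), b)]

    i = gate("input", sigmoid)   # how much new content to write
    f = gate("forget", sigmoid)  # how much of the old cell state to keep
    o = gate("output", sigmoid)  # how much of the cell state to expose
    g = gate("cell", math.tanh)  # candidate content
    c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
    h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
    return h_new, c_new


params = init_params(n_in=3, n_hidden=4)
h, c = [0.0] * 4, [0.0] * 4
for obs in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
    h, c = lstm_step(params, obs, h, c)
```

The forget gate is what lets the cell state persist over many steps; information and gradients still have to pass through every intermediate step, which is the weakness the memory-window approaches address.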
[Figure: a memory window over observations and actions from \(t=10\) to \(t=20\); information and gradients flow between steps inside the window.]
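[Figure: attention over a memory window. Observations \(obs_{t=10}, obs_{t=12}, \ldots, obs_{t=20}\) are mapped to embeddings \(e_{10}, e_{12}, \ldots, e_{20}\); a query \(q\) is compared with every embedding to produce attention weights \(a_{10}, a_{12}, \ldots\)]
Attention weight: \(a_t = \exp(e_t^\top q) / \sum_i \exp(e_i^\top q)\)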
[Figure: the same attention diagram, with the weighted embeddings summed into a context vector.]
Context vector: \(c_{20} = \sum_t a_t e_t\)
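The attention read above can be sketched in a few lines. This sketch normalizes the scores \(e_t^\top q\) with a softmax, the form used in Transformers, which keeps every weight positive. The embedding and query values are made up.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(embeddings, query):
    """Attention weights a_t over stored embeddings e_t, and the context vector sum_t a_t e_t."""
    scores = [sum(e_d * q_d for e_d, q_d in zip(e, query)) for e in embeddings]  # e_t^T q
    weights = softmax(scores)
    context = [sum(w * e[d] for w, e in zip(weights, embeddings)) for d in range(len(query))]
    return weights, context


memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # e_10, e_12, e_20 (made-up values)
q = [2.0, 0.0]
weights, c20 = attend(memory, q)
```

Embeddings that align with the query (here \(e_{10}\) and \(e_{20}\)) get equal, larger weights, so the context vector is dominated by them.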
Self-Attention computation:
Each \(z_t\) contains information relevant to \(o_t\), collected over all steps in the Memory Window.
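A minimal single-head sketch of self-attention over a memory window follows. Every step produces a query, a key and a value through projection matrices \(W_q, W_k, W_v\), and \(z_t\) is a softmax-weighted sum of values with scores scaled by \(\sqrt{d_k}\). The causal mask (each step sees only itself and earlier steps) is an assumption that matches online RL use. The identity projections and the inputs are illustrative only.

```python
import math


def matvec(W, x):
    """Matrix (list of rows) times vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(xs, Wq, Wk, Wv, causal=True):
    """Single-head scaled dot-product self-attention over a memory window.

    xs: input vectors x_1..x_T, one per step in the window.
    Returns z_1..z_T; with causal=True, step t only attends to steps <= t.
    """
    qs = [matvec(Wq, x) for x in xs]
    ks = [matvec(Wk, x) for x in xs]
    vs = [matvec(Wv, x) for x in xs]
    scale = math.sqrt(len(ks[0]))
    zs = []
    for t, q in enumerate(qs):
        visible = list(range(t + 1)) if causal else list(range(len(xs)))
        weights = softmax([sum(a * b for a, b in zip(q, ks[i])) / scale for i in visible])
        zs.append([sum(w * vs[i][d] for w, i in zip(weights, visible)) for d in range(len(vs[0]))])
    return zs


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # made-up inputs for three steps
eye = [[1.0, 0.0], [0.0, 1.0]]              # identity projections, for illustration only
zs = self_attention(xs, eye, eye, eye)
```

Unlike a recurrent core, every \(z_t\) reads the earlier steps directly, so neither information nor gradients have to pass through a long chain of hidden states.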
This is how a real Transformer looks (kind of...):
Variational AutoEncoder: VAE
AutoEncoder
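The two ingredients of a VAE that the later slides rely on can be sketched as follows: the reparameterization trick \(z = \mu + \sigma \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\), and the closed-form KL divergence between a diagonal Gaussian and the standard normal prior. The encoder and decoder networks are omitted; `mu` and `log_var` stand in for hypothetical encoder outputs, and the squared-error reconstruction term assumes a Gaussian decoder (it matches the negative log-likelihood only up to constants and a scale factor).

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (the reparameterization trick)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


def vae_loss(x, x_recon, mu, log_var):
    """Negative ELBO for a Gaussian decoder, up to constants: squared error + KL."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    return recon + kl_to_standard_normal(mu, log_var)


rng = random.Random(0)
mu, log_var = [0.5, -0.5], [0.0, -1.0]  # would come from the encoder (made-up values)
z = reparameterize(mu, log_var, rng)
```

Sampling through `mu` and `log_var` instead of sampling `z` directly is what lets gradients reach the encoder.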
Asynchronous Methods for Deep Reinforcement Learning (Mnih et al.)
Deep Recurrent Q-Learning for Partially Observable MDPs (Hausknecht et al.)
"To deal with partial observability, the temporal sequence of observations is processed by a deep long short-term memory (LSTM) system"
AlphaStar: Grandmaster level in StarCraft II using multi-agent reinforcement learning (Vinyals et al.)
"The LSTM composes 84% of the model’s total parameter count."
Dota 2 with Large Scale Deep Reinforcement Learning (Berner et al)
Recurrent Experience Replay in Distributed Reinforcement Learning (Kapturowski et al.)
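R2D2 is a DRQN built on top of Ape-X (Horgan et al., 2018) with two additional heuristics:
Stored state: the recurrent state is saved in the replay buffer and used to initialize the network when the sequence is replayed
Burn-in: the first part of each replayed sequence is used only to unroll the network and refresh its recurrent state; the loss is computed on the rest
Burn-in: 40 steps, full rollout: 80 steps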
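How a replayed sequence is split can be sketched as follows. The stored state seeds the recurrent core, the first 40 steps only refresh that state, and the remaining 40 steps are the ones the loss would be computed on. The toy recurrent core (an exponential moving average) is a hypothetical stand-in for the LSTM. In R2D2 the burn-in unroll uses the current network and no gradient flows through it, which plain Python cannot show.

```python
def rnn_step(h, x):
    """Toy recurrent core (hypothetical stand-in for the LSTM): an exponential moving average."""
    return 0.9 * h + 0.1 * x


def prepare_sequence(stored_state, observations, burn_in=40):
    """Split a replayed sequence into a burn-in part and a training part.

    stored_state: recurrent state saved in the replay buffer at the start of the sequence
    observations: the replayed sequence (80 steps in the setup above)
    Returns the refreshed recurrent state and the steps used for the loss.
    """
    h = stored_state
    for x in observations[:burn_in]:
        h = rnn_step(h, x)  # unrolled only to produce a better start state; no loss here
    return h, observations[burn_in:]


h_start, train_steps = prepare_sequence(1.0, [0.0] * 80)
```

The burn-in compensates for the stored state being stale: it was produced by an older version of the network than the one being trained.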
LSTM:
good at tracking order of observations
susceptible to noise in observations
bad at long-term dependencies
RL tasks:
order often doesn't matter
high variability in observation sequences
long-term dependencies
AMRL: Aggregated Memory For Reinforcement Learning
Add aggregators that ignore the order of observations. Aggregators also act as residual skip connections across time. Instead of the true gradients, a straight-through estimator (Bengio et al., 2013) is used.
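Order-invariant running aggregators (sum, average, max) can be sketched as a small class that folds in one step at a time. The straight-through estimator only changes the backward pass, so it is not shown, and how the aggregate is combined with the recurrent core is left out; the feature vectors are made up.

```python
class RunningAggregator:
    """Order-invariant running aggregate of per-step feature vectors (sum, avg or max)."""

    def __init__(self, kind):
        if kind not in ("sum", "avg", "max"):
            raise ValueError(f"unknown aggregator: {kind}")
        self.kind = kind
        self.total = None  # running element-wise sum (used by "sum" and "avg")
        self.best = None   # running element-wise max
        self.count = 0

    def update(self, x):
        """Fold in the features of one step and return the current aggregate."""
        self.count += 1
        if self.total is None:
            self.total, self.best = list(x), list(x)
        else:
            self.total = [a + b for a, b in zip(self.total, x)]
            self.best = [max(a, b) for a, b in zip(self.best, x)]
        if self.kind == "sum":
            return list(self.total)
        if self.kind == "avg":
            return [a / self.count for a in self.total]
        return list(self.best)


def run(kind, sequence):
    """Feed a whole sequence through a fresh aggregator and return the final aggregate."""
    agg = RunningAggregator(kind)
    out = None
    for x in sequence:
        out = agg.update(x)
    return out


seq = [[1.0, 2.0], [3.0, 0.0], [2.0, 5.0]]
```

Reversing or shuffling `seq` gives the same final aggregate, which is exactly the property an LSTM lacks and what makes these aggregators robust when order does not matter.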
A3C/PPO + LSTM
Unsupervised Predictive Memory in a Goal-Directed Agent
a monstrous combination of a VAE and a Q-function estimator
uses a simplified DNC under the hood
no gradients flow between the policy and the MBR
the policy is trained with Policy Gradients and GAE
VAE:
DNC:
MBR is optimized to be a "world model", i.e. to produce predictions that are consistent with observed trajectories: \(p(o_{0:T}, R_{0:T})\)
Compresses observations \(o_t\) into low-dimensional state representations \(z_t\) (inspired by compressive bottleneck theory)
Stores all previous representations \(z_{0:t}\) in memory
Uses its memory to predict the future: \(o_{t+1}\) and \(R_{t+1}\)
We do this with VAE
we can just predict \(z_{t+1}\)
DNC lives here
The MBR (memory-based predictor) is trained with a loss based on the variational lower bound:
Reconstruction Loss: KL Loss:
Forces MBR to make good compressed representations \(z_t\); recall VAE
Train memory to make good predictions
We want MBR to make good predictions of observations and Q-values
\(z_{t}\) prediction from memory (steps: 0 .. t-1)
actual compressed representation \(z_t\) given \(o_t\) and memory from (steps: 0 .. t-1)
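Schematically, the per-step loss is the negative variational lower bound: a reconstruction term plus a KL term between the posterior \(q\) (the actual \(z_t\) given \(o_t\) and memory) and the prior \(p\) (the \(z_t\) predicted from memory alone). Here \(M_{<t}\) is our shorthand for memory from steps \(0 .. t-1\). This sketches the structure only; the paper's exact loss may split and weight the reconstruction terms differently.

```latex
\mathcal{L}_t =
\underbrace{-\,\mathbb{E}_{z_t \sim q(z_t \mid o_t, M_{<t})}\left[\log p(o_t, R_t \mid z_t)\right]}_{\text{reconstruction loss}}
+ \underbrace{D_{\mathrm{KL}}\!\left(q(z_t \mid o_t, M_{<t}) \,\middle\|\, p(z_t \mid M_{<t})\right)}_{\text{KL loss}}
```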
The prior module takes all memory from the previous step and produces the parameters of a diagonal Gaussian distribution:
Another MLP \(f^{post}\) takes:
and generates a correction to the prior:
Finally, the latent state variable \(z_t\) is sampled from the posterior distribution.
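The prior/posterior mechanics can be sketched without the networks. The posterior parameters are the prior parameters plus a correction, \(z_t\) is sampled from the posterior, and the KL loss is the closed-form KL between two diagonal Gaussians. The numeric "outputs" of the prior module and \(f^{post}\) are made up, and parameterizing by log standard deviation is an assumption.

```python
import math
import random


def posterior_from_prior(prior_mu, prior_log_sigma, corr_mu, corr_log_sigma):
    """Posterior parameters = prior parameters + correction (the correction would come from f^post)."""
    mu = [p + c for p, c in zip(prior_mu, corr_mu)]
    log_sigma = [p + c for p, c in zip(prior_log_sigma, corr_log_sigma)]
    return mu, log_sigma


def kl_diag_gaussians(mu_q, log_sigma_q, mu_p, log_sigma_p):
    """KL(q || p) for diagonal Gaussians given means and log standard deviations."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_sigma_q, mu_p, log_sigma_p):
        var_q, var_p = math.exp(2.0 * lq), math.exp(2.0 * lp)
        kl += lp - lq + (var_q + (mq - mp) ** 2) / (2.0 * var_p) - 0.5
    return kl


def sample(mu, log_sigma, rng):
    """Draw z ~ N(mu, diag(sigma^2)) via the reparameterization trick."""
    return [m + math.exp(ls) * rng.gauss(0.0, 1.0) for m, ls in zip(mu, log_sigma)]


# Made-up outputs of the prior module and of f^post.
prior_mu, prior_log_sigma = [0.0, 1.0], [0.0, 0.0]
post_mu, post_log_sigma = posterior_from_prior(prior_mu, prior_log_sigma, [0.5, -0.5], [-1.0, 0.0])
z_t = sample(post_mu, post_log_sigma, random.Random(0))
kl_loss = kl_diag_gaussians(post_mu, post_log_sigma, prior_mu, prior_log_sigma)
```

Writing the posterior as a correction to the prior means a zero correction gives zero KL, so the KL term directly measures how much the new observation surprised the memory.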
MERLIN is compared against two baselines: A3C-LSTM and A3C-DNC.
Problem:
HINT: cast the RL problem as a supervised learning problem!
Solution:
Real Solution:
Just treat RL as a sequence modeling problem
Solution:
Encode each step with three tokens: reward-to-go \(R_t\), state \(s_t\), and action \(a_t\)
Reward-to-Go:
\(R_t=\sum^T_{i=t} r_i\)
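The reward-to-go and the three-tokens-per-step layout can be sketched as data preparation. \(R_t\) is computed right to left in one pass, and each step contributes a (return, state, action) triple. The tuple-based token format is a hypothetical simplification; a real model embeds each token type separately.

```python
def rewards_to_go(rewards):
    """R_t = sum_{i=t}^{T} r_i (undiscounted), computed right to left in one pass."""
    rtg, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg


def to_tokens(states, actions, rewards):
    """Interleave (R_t, s_t, a_t) for every step: three tokens per step."""
    tokens = []
    for R, s, a in zip(rewards_to_go(rewards), states, actions):
        tokens += [("return", R), ("state", s), ("action", a)]
    return tokens


tokens = to_tokens(states=["s0", "s1", "s2"], actions=["a0", "a1", "a2"], rewards=[1.0, 0.0, 2.0])
```

Training then reduces to supervised learning: predict each action token from the tokens before it.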
Decision Transformer: Reinforcement Learning via Sequence Modeling
Training Decision Transformer:
Inference with Decision Transformer:
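At inference time the model is conditioned on a desired return, and after each step the reward actually received is subtracted from it. Below is a minimal sketch of that loop. The `model` and environment interfaces are hypothetical, the "model" is a dummy that always returns action 0, and `CountdownEnv` is a made-up toy environment. A real model would also truncate the history to its context length.

```python
def run_episode(model, env_reset, env_step, target_return, max_steps=1000):
    """Return-conditioned rollout in the style of Decision Transformer inference.

    model(returns, states, actions) -> action                    (hypothetical interface)
    env_reset() -> first state; env_step(a) -> (state, reward, done)   (hypothetical)
    """
    returns, states, actions = [target_return], [env_reset()], []
    total = 0.0
    for _ in range(max_steps):
        action = model(returns, states, actions)  # condition on the history so far
        state, reward, done = env_step(action)
        actions.append(action)
        total += reward
        if done:
            break
        # Lower the target by the reward just collected: the return still "to go".
        returns.append(returns[-1] - reward)
        states.append(state)
    return total, returns


class CountdownEnv:
    """Toy environment (made up): reward 1 per step, episode ends after 5 steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 5


env = CountdownEnv()
total, rtgs = run_episode(lambda R, S, A: 0, env.reset, env.step, target_return=5.0)
```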
Differences from Decision Transformer:
Solution:
Offline Reinforcement Learning as One Big Sequence Modeling Problem
Append goal state here
Decision Transformer:
Solution:
Trajectory Transformer:
Solution:
Trajectory Transformer and Decision Transformer learn to generate trajectories that match some statistics about the future.
Solution:
DT: feature \(r(s_t, a_t)\), aggregated by summation
TT: feature \(s_t\), aggregated by selecting the last one
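Both choices fit one pattern: apply a feature function to every future step, then aggregate. Below is a sketch of that pattern, assuming each step is a (state, action, reward) triple; the trajectory values are made up. With reward plus summation it yields the reward-to-go; with state plus "select last" it yields the final (goal) state.

```python
def hindsight_statistic(trajectory, t, feature, aggregate):
    """Statistic of the future of `trajectory` from step t on: aggregate(feature(step) ...)."""
    return aggregate([feature(step) for step in trajectory[t:]])


# Each step is a (state, action, reward) triple; the values are made up.
traj = [((0, 0), "right", 0.0), ((1, 0), "up", 1.0), ((1, 1), "up", 2.0)]

# DT: feature = reward, aggregator = summation -> reward-to-go.
dt_stat = hindsight_statistic(traj, 0, feature=lambda step: step[2], aggregate=sum)

# TT: feature = state, aggregator = select last -> final (goal) state.
tt_stat = hindsight_statistic(traj, 0, feature=lambda step: step[0], aggregate=lambda xs: xs[-1])
```

Swapping in other feature and aggregator functions gives other conditioning targets without changing the sequence model itself.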
Generalized Decision Transformer for offline Hindsight Information Matching
Future statistics: discretized x-velocity
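A distribution-valued future statistic can be sketched as a normalized histogram of the discretized x-velocities from step \(t\) onward. This is how we read the setup above, not a confirmed reproduction of the paper. The bin edges and velocities are made up, and out-of-range values are clamped into the outer bins.

```python
import bisect


def future_velocity_histogram(x_velocities, t, bin_edges):
    """Normalized histogram of discretized future x-velocities from step t on."""
    counts = [0] * (len(bin_edges) - 1)
    future = x_velocities[t:]
    for v in future:
        idx = bisect.bisect_right(bin_edges, v) - 1
        idx = min(max(idx, 0), len(counts) - 1)  # clamp into the outer bins
        counts[idx] += 1
    return [c / len(future) for c in counts]


edges = [-1.0, 0.0, 1.0, 2.0]  # three bins: [-1, 0), [0, 1), [1, 2] (made up)
hist = future_velocity_histogram([0.5, 1.5, 1.2, -0.3], 0, edges)
```

Conditioning on a whole distribution rather than a single number is what allows prompting with a target distribution, for example one mixing behaviours from two datasets.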
Solution:
Backflipping Dataset
Running Dataset
Results for the combined distribution prompt:
Stabilizing Transformers For Reinforcement Learning
Towards mental time travel: a hierarchical memory for reinforcement learning agents
Long delay == 30 seconds
Extra materials
Catastrophic interference, i.e. knowledge in neural networks is non-local
Nature of Gradient Descent
Semantic memory makes better use of experiences (i.e. better generalization)
Episodic memory requires fewer experiences (i.e. more accurate)
"We will show that in general, just as model-free control is better than model-based control after substantial experience, episodic control is better than model-based control after only very limited experience."
A tree MDP is just an MDP without cycles.
A) Tree MDP with branching factor = 2 B) Tree MDP with branching factor = 3 C) Tree MDP with branching factor = 4
Let's store all past experiences in \(|A|\) dictionaries \(Q_{a}^{EC}\):
\(s_t, a_t\) are keys and discounted future rewards \(R_t\) are values.
Dictionary update:
If the state space has a meaningful distance, we can use k-nearest neighbours to estimate values for new \((s,a)\) pairs:
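Below is a sketch of one per-action episodic table. An update keeps the best return ever seen for a state, \(Q^{EC}(s,a) \leftarrow \max(Q^{EC}(s,a), R_t)\). A lookup returns the stored value for a known state, or else the average over the \(k\) nearest stored states. The states are assumed to be already-compressed feature tuples compared with Euclidean distance. The default of 0.0 for an empty table, the value of \(k\) and the example data are assumptions.

```python
import math


class EpisodicTable:
    """One table per action: maps state features (tuples) to the best return seen."""

    def __init__(self, k=3):
        self.k = k
        self.values = {}

    def update(self, state, ret):
        # Keep the highest return ever observed after taking this action in this state.
        self.values[state] = max(self.values.get(state, -math.inf), ret)

    def estimate(self, state):
        if state in self.values:
            return self.values[state]
        if not self.values:
            return 0.0  # assumption: an empty table estimates 0
        nearest = sorted(self.values, key=lambda key: math.dist(key, state))[: self.k]
        return sum(self.values[key] for key in nearest) / len(nearest)


def greedy_action(tables, state):
    """Pick the action whose table gives the highest value estimate."""
    return max(tables, key=lambda a: tables[a].estimate(state))


tables = {a: EpisodicTable(k=2) for a in ("left", "right")}
tables["left"].update((0.0, 0.0), 1.0)
tables["left"].update((0.0, 0.0), 0.5)   # lower return: the stored value stays 1.0
tables["left"].update((1.0, 0.0), 3.0)
tables["right"].update((0.0, 0.0), 1.5)
```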
Model-Free Episodic Control
Two possible feature compressors for \(s_t\): Random Network, VAE
Test environments:
Differences with Model-Free Episodic Control:
CNN instead of VAE/RandomNet
Differentiable Neural Dictionaries (DND)
Replay Memory like in DQN, but small...
CNN and DND learn with gradient descent
Neural Episodic Control
For each action \(a \in A \), NEC has a dictionary \(M_a = (K_a , V_a )\).
Keys and queries are generated by the CNN.
To estimate \(Q(s_t, a)\), we look up the \(p\) nearest neighbors of the query in \(M_a\).
\(k\) is a distance kernel. In the experiments:
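Below is a sketch of the DND lookup. Take the \(p\) keys nearest to the query \(h\), weight their values with an inverse-distance kernel \(k(h, h_i) = 1 / (\lVert h - h_i \rVert^2 + \delta)\), and normalize, giving \(Q = \sum_i w_i v_i\) with \(w_i = k(h, h_i) / \sum_j k(h, h_j)\). The defaults \(\delta = 10^{-3}\) and \(p = 50\) are assumptions here, and the keys and values are made up.

```python
import math


def inverse_distance_kernel(h, h_i, delta=1e-3):
    """k(h, h_i) = 1 / (||h - h_i||^2 + delta)."""
    return 1.0 / (sum((a - b) ** 2 for a, b in zip(h, h_i)) + delta)


def dnd_lookup(keys, values, h, p=50):
    """Q(s, a) = sum_i w_i v_i over the p nearest keys, with kernel-normalized weights w_i."""
    nearest = sorted(range(len(keys)), key=lambda i: math.dist(keys[i], h))[:p]
    ks = [inverse_distance_kernel(h, keys[i]) for i in nearest]
    total = sum(ks)
    return sum(k / total * values[i] for k, i in zip(ks, nearest))


keys = [[0.0, 0.0], [1.0, 0.0]]   # stored keys h_i in M_a (made up)
values = [1.0, 3.0]               # stored values v_i
q_mid = dnd_lookup(keys, values, [0.5, 0.0])
```

Because every weight is a differentiable function of the keys and the query, gradients can flow into both the DND and the CNN encoder.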
Once a key \(h_t\) has been queried from a DND, that key and its corresponding output are appended to the DND. If \(h_t \notin M_a\), we just store it with the N-step Q-value estimate:
otherwise, we update the stored value with the tabular Q-learning rule:
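Below is a sketch of the write path. The N-step estimate is \(Q^{(N)}(s_t, a) = \sum_{j=0}^{N-1} \gamma^j r_{t+j} + \gamma^N \max_{a'} Q(s_{t+N}, a')\). A new key is appended with that value, and an existing key has its value moved toward it, \(Q_i \leftarrow Q_i + \alpha (Q^{(N)} - Q_i)\). The bootstrap value is passed in precomputed, and \(\gamma\), \(\alpha\) and the example numbers are assumptions.

```python
def n_step_q(rewards, bootstrap_q, gamma=0.99):
    """Q^(N) = sum_{j<N} gamma^j r_{t+j} + gamma^N * max_a' Q(s_{t+N}, a')."""
    n = len(rewards)
    return sum(gamma ** j * r for j, r in enumerate(rewards)) + gamma ** n * bootstrap_q


def dnd_write(keys, values, h, q_n, alpha=0.1):
    """Append a new key, or move an existing key's value toward Q^(N)."""
    for i, key in enumerate(keys):
        if key == h:
            values[i] += alpha * (q_n - values[i])  # tabular Q-learning-style update
            return
    keys.append(h)
    values.append(q_n)


keys, values = [], []
q_n = n_step_q([1.0, 1.0], bootstrap_q=10.0, gamma=0.5)  # 1 + 0.5 + 0.25 * 10
dnd_write(keys, values, (0.0,), q_n)             # new key: stored as is
dnd_write(keys, values, (0.0,), 2.0, alpha=0.5)  # existing key: moved halfway to 2.0
```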
Learning the DND and the CNN encoder:
Sample mini-batches from the replay buffer, which stores triplets \((s_t, a_t, R_t)\), and use \(R_t\) as the target.
Replay Buffer here \(\ne\) DND