CS 4/5789: Introduction to Reinforcement Learning

Lecture 11: Model-Based RL

Prof. Sarah Dean

MW 2:45-4pm
255 Olin Hall

Reminders

  • Homework
    • PSet 2 regrade requests today-Friday
    • PA 2 due Friday
    • PSet 4 released next week
    • 5789 Paper Review Assignment
  • Midterm 3/15 during lecture
    • Let us know about conflicts/accommodations ASAP! (EdStem)
    • Review Lecture on 3/12 (last year's slides/recording)
    • Materials: slides (Lectures 1-10, some of 11-13), PSets 1-4
      • also: equation sheet (next week), 2023 notes, PAs

Agenda

1. Recap: MDPs and Control

2. MBRL with Query Model

3. Sub-Optimality

4. Model Error

Recap: MDPs and Control

  • So far, we have algorithms (VI, PI, DP, LQR) for when transitions \(P\) or \(f\) are known and tractable to optimize
  • In Unit 2, we develop algorithms for unknown (or intractable) transitions

\(\mathcal M = \{\mathcal{S}, \mathcal{A}, r, P, \gamma\}\): infinite horizon discounted MDP with finite states and actions

maximize over \(\pi\)   \(\displaystyle \mathbb E\left[\sum_{t=0}^\infty \gamma^t r(s_t, a_t)\right]\)

s.t.   \(s_{t+1}\sim P(s_t, a_t), ~~a_t\sim \pi(s_t)\)

\(\mathcal M = \{\mathcal{S}, \mathcal{A}, c, f, H\}\): finite horizon deterministic MDP with continuous states/actions

minimize over \(\pi\)   \(\displaystyle\sum_{t=0}^{H-1} c(s_t, a_t)\)

s.t.   \(s_{t+1}=f(s_t, a_t), ~~a_t=\pi_t(s_t)\)

Recap: MDPs and Control

[Figure: agent-environment interaction loop; the agent observes state \(s_t\) and reward \(r_t\) and selects action \(a_t\)]

  • In Unit 2, we develop algorithms for unknown (or intractable) transitions

Recap: Distributions

  • Recall the state distribution for a policy \(\pi\) (Lecture 2)
    • \( d^{\pi}_{\mu_0,t}(s) = \mathbb{P}\{s_t=s\mid s_0\sim \mu_0,s_{k+1}\sim P(s_k, \pi(s_k))\} \)
  • We showed that it can be written as \(d^{\pi}_{\mu_0,t} = P_\pi^\top  d^{\pi}_{\mu_0,t-1} = (P_\pi^t)^\top \mu_0\)
  • The discounted distribution (PSet) \(d_{\mu_0}^{\pi} = (1-\gamma)  \sum_{t=0}^{\infty} \gamma^t \underbrace{(P_\pi ^t)^\top \mu_0}_{d_{\mu_0,t}^{\pi}} \)
  • In Unit 2 we no longer know distributions
    • e.g. a Binomial \( p(x;n,p) = \binom{n}{x}p^x(1-p)^{n-x}\)
  • Instead, we observe samples
    • e.g. we observe \(x_1, x_2, x_3,\dots\) drawn from the Binomial

When the initial state is fixed to a known \(s_0\), i.e. \(\mu_0=e_{s_0}\), we write \(d_{s_0,t}^{\pi}\)
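As a quick numerical illustration of the discounted distribution above (not from the slides), the following sketch computes \(d^{\pi}_{s_0}\) for a small random \(P_\pi\) in two ways: by truncating the sum \((1-\gamma)\sum_t \gamma^t (P_\pi^t)^\top \mu_0\) and via the closed form \((1-\gamma)(I-\gamma P_\pi^\top)^{-1}\mu_0\). All names and values are illustrative.

    import numpy as np

    rng = np.random.default_rng(0)
    S, gamma = 4, 0.9

    # Row-stochastic transition matrix for a fixed policy: P_pi[s, s'] = P(s' | s, pi(s))
    P_pi = rng.random((S, S))
    P_pi /= P_pi.sum(axis=1, keepdims=True)

    mu0 = np.zeros(S)
    mu0[0] = 1.0                             # initial state fixed to s_0 = 0, i.e. mu0 = e_{s_0}

    # Truncated sum: (1 - gamma) * sum_t gamma^t (P_pi^t)^T mu0
    d_sum, dist_t = np.zeros(S), mu0.copy()
    for t in range(1000):
        d_sum += (1 - gamma) * gamma**t * dist_t
        dist_t = P_pi.T @ dist_t             # d_{t+1} = P_pi^T d_t

    # Closed form: (1 - gamma) (I - gamma P_pi^T)^{-1} mu0
    d_closed = (1 - gamma) * np.linalg.solve(np.eye(S) - gamma * P_pi.T, mu0)

    print(np.allclose(d_sum, d_closed), d_closed.sum())   # True, and the sum is 1 (it is a distribution)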

Agenda

1. Recap: MDPs and Control

2. MBRL with Query Model

3. Sub-Optimality

4. Model Error

  • Query model (also called generative model)
    • setting where we can query, for any \(s\) and \(a\), the transition/dynamics to sample $$ s'\sim P(s,a)$$
  • This is black-box sampling access
    • Directly applicable to games, physics simulators
  • Simple starting point to understand sample complexity: how many samples are required for good performance?

Query Model

Algorithm: MBRL with Queries

  • Inputs: sample points \(\{(s_i,a_i)\}_{i=1}^N\)
  • For \(i=1,\dots, N\):
    • sample \(s'_i \sim P(s_i, a_i)\) and record \((s'_i,s_i,a_i)\)
  • Fit transition model \(\hat P\) from data \(\{(s'_i,s_i,a_i)\}_{i=1}^N\)
  • Design \(\hat \pi\) using \(\hat P\)

Model-based RL
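A minimal sketch of the query-model MBRL loop above for the tabular case. Here query, plan, and the uniform fallback for unvisited pairs are illustrative assumptions, not the course's prescribed choices.

    import numpy as np

    def mbrl_with_queries(query, sample_points, S, A, plan):
        """Model-based RL with a query (generative) model -- sketch.

        query(s, a)    -- black-box sampler returning s' ~ P(s, a)
        sample_points  -- list of (s_i, a_i) pairs to query
        plan(hat_P)    -- planner run on the fitted model (e.g. policy iteration)
        """
        # Query the transition once at each requested (s_i, a_i) and record (s', s, a)
        data = [(query(s, a), s, a) for (s, a) in sample_points]

        # Fit a tabular transition model hat_P(s' | s, a) by empirical frequencies
        counts = np.zeros((S, A, S))
        for s_next, s, a in data:
            counts[s, a, s_next] += 1
        totals = counts.sum(axis=2, keepdims=True)
        hat_P = np.where(totals > 0, counts / np.maximum(totals, 1), 1.0 / S)

        # Design hat_pi using hat_P
        return plan(hat_P)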

Agenda

1. Recap: MDPs and Control

2. MBRL with Query Model

3. Sub-Optimality

4. Model Error

Example: Sub-Optimality

[Figure: two-state MDP with states \(0\) and \(1\) and actions stay/switch; arrows labeled with transition probabilities stay: \(1\), switch: \(1\), stay: \(p_1\), switch: \(1-p_2\), stay: \(1-p_1\), switch: \(p_2\) (example from Lecture 4)]

  • The reward is:
    • \(+1\) for \(s=0\) and \(-\frac{1}{2}\) for
      \(a=\) switch
    • Let \(\gamma=\frac{1}{2}\)
  • Recall from Lecture 4 that the policy \(\pi(s)=\)stay is optimal if $$p_2\leq \frac{2p_1}{2- p_1}+\frac{1}{4}$$
  • If we mis-estimate \(\hat p_1,\hat p_2\), may choose the wrong policy
  • Model error in \(\hat P\) leads to sub-optimal policies
  • Notation:
    • \(P\) is the true transition function
    • \(V^{\pi}\) is the true value of a policy \(\pi\)
      • i.e., \(\mathsf{PolicyEval}(P,r,\pi)\) "value on \(P\)"
    • \(\hat V^{\pi}\) is the estimated value of a policy \(\pi\)
      • i.e., \(\mathsf{PolicyEval}(\hat P,r,\pi)\) "value on \(\hat P\)"
    • \(V^\star\) and \(\pi^\star\) are the true optimal value and policy
  • Assumption for today's lecture: \(0\leq r(s,a) \leq 1\) for all states and actions
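Since the notation above relies on \(\mathsf{PolicyEval}\), here is a minimal tabular sketch based on the Bellman equation \(V^\pi = R^\pi + \gamma P_\pi V^\pi\), i.e. \(V^\pi = (I-\gamma P_\pi)^{-1}R^\pi\). Calling it with \(P\) gives the "value on \(P\)" and with \(\hat P\) the "value on \(\hat P\)"; array shapes and names are assumptions.

    import numpy as np

    def policy_eval(P, r, pi, gamma):
        """Exact tabular policy evaluation: V^pi = (I - gamma P_pi)^{-1} R^pi.

        P  -- transitions, shape (S, A, S), with P[s, a, s'] = P(s' | s, a)
        r  -- rewards, shape (S, A)
        pi -- deterministic policy, shape (S,), with pi[s] = a
        """
        S = P.shape[0]
        P_pi = P[np.arange(S), pi]           # rows P(. | s, pi(s)), shape (S, S)
        R_pi = r[np.arange(S), pi]           # r(s, pi(s)), shape (S,)
        return np.linalg.solve(np.eye(S) - gamma * P_pi, R_pi)

    # "value on P" vs. "value on hat_P" for the same policy:
    #   V_pi     = policy_eval(P, r, pi, gamma)
    #   hat_V_pi = policy_eval(hat_P, r, pi, gamma)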

Sub-Optimality

  • Suppose Policy Iteration on \(\hat P, r\) converges to a fixed policy \(\hat \pi^\star\)
    • What do we know about \(\hat\pi^\star\)? PollEv
  • The sub-optimality is
    • \(V^\star(s_0) - V^{\hat\pi^\star}(s_0)\)
      • \(=V^\star(s_0)  - \hat V^{\hat\pi^\star}(s_0) + \hat V^{\hat\pi^\star}(s_0) - V^{\hat\pi^\star}(s_0)\)
      • \(\leq V^\star(s_0)  - \hat V^{\pi^\star}(s_0) + \hat V^{\hat\pi^\star}(s_0) - V^{\hat\pi^\star}(s_0)\)
        • (\(\hat\pi^\star\) optimal for \(\hat P\))
    • Two terms: value of \(\pi^\star\) and value of \(\hat \pi^\star\) on \(P\) vs. \(\hat P\)

Error in Policies

Simulation Lemma: For a deterministic policy \(\pi\), $$|\hat V^\pi(s_0) - V^\pi(s_0)| \leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_{\pi}^{s_0}}\left[ \|\hat P(\cdot |s,\pi(s)) - P(\cdot|s,\pi(s))\|_1\right]$$

Difference in Value

For a fixed policy, what is the difference in value when computed using \(P\) vs. when using \(\hat P\)?

  • error in predicted next state
  • averaged over discounted state distribution

\(\|\hat P(\cdot |s,\pi(s)) - P(\cdot|s,\pi(s))\|_1 = \sum_{s'\in\mathcal S} |\hat P(s'|s,\pi(s)) - P(s'|s,\pi(s))| \): total variation distance on the distribution over \( s'\)

Simulation Lemma: For a deterministic policy, $$|\hat V^\pi(s_0) - V^\pi(s_0)| \leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d^{\pi}_{s_0}}\left[ \|\hat P(\cdot |s,\pi(s)) - P(\cdot|s,\pi(s))\|_1\right]$$

Simulation Lemma Proof

  • First, derive a recursion:
    • \(\hat V^\pi-V^\pi = R^\pi + \gamma \hat P_\pi \hat V^\pi - R^\pi - \gamma  P_\pi  V^\pi\) (Bellman Eq)
    • \(=\gamma (\hat P_\pi \hat V^\pi -  P_\pi \hat V^\pi + P_\pi \hat V^\pi - P_\pi  V^\pi)\) (simplify and add zero)
    • \(=\gamma (\hat P_\pi  -  P_\pi) \hat V^\pi + \gamma P_\pi (\hat V^\pi - V^\pi)\)
  • Iterating expression \(k\) times:
    • \(\hat V^\pi-V^\pi =\sum_{\ell=1}^k \gamma^{\ell} P_\pi^{\ell-1} (\hat P_\pi  -  P_\pi) \hat V^\pi + \gamma^k P_\pi^k (\hat V^\pi - V^\pi)\)
    • \(\hat V^\pi-V^\pi =\sum_{\ell=1}^\infty \gamma^\ell P_\pi^{\ell-1} (\hat P_\pi  -  P_\pi) \hat V^\pi \)    (limit \(k\to\infty\))

For an alternative proof without vector notation, see https://sdean.website/cs4789sp22/lec10-notes.pdf

Simulation Lemma: For a deterministic policy, $$|\hat V^\pi(s_0) - V^\pi(s_0)| \leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d^{\pi}_{s_0}}\left[ \|\hat P(\cdot |s,\pi(s)) - P(\cdot|s,\pi(s))\|_1\right]$$

Simulation Lemma Proof

  • Recursive argument: \(\hat V^\pi-V^\pi =\gamma \sum_{\ell=0}^\infty \gamma^\ell P_\pi^\ell (\hat P_\pi  -  P_\pi) \hat V^\pi \)  
  • \(\hat V^\pi(s_0) - V^\pi(s_0) = (\hat V^\pi-V^\pi)^\top e_{s_0}\)
    • \(=\gamma\sum_{\ell=0}^\infty \gamma^\ell [(\hat P_\pi  -  P_\pi) \hat V^\pi]^\top (P_\pi^\ell)^\top e_{s_0}\)
    • \(= \gamma [(\hat P_\pi  -  P_\pi) \hat V^\pi]^\top \sum_{\ell=0}^\infty \gamma^\ell (P_\pi^\ell)^\top e_{s_0}\)
    • \(=\frac{\gamma}{1-\gamma} [(\hat P_\pi  -  P_\pi) \hat V^\pi]^\top d_\pi^{s_0} \) (definition of discounted distribution)
    • \(= \frac{\gamma}{1-\gamma}\sum_{s\in\mathcal S} \left((\hat P_\pi  -  P_\pi) \hat V^\pi\right) [s] d_\pi^{s_0}[s] \)
    • \(=\frac{\gamma}{1-\gamma} \mathbb E_{s\sim d^\pi_{s_0}}\left [\left((\hat P_\pi  -  P_\pi) \hat V^\pi\right) [s] \right ] \) (definition of expectation)

Simulation Lemma: For a deterministic policy, $$|\hat V^\pi(s_0) - V^\pi(s_0)| \leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d^{\pi}_{s_0}}\left[ \|\hat P(\cdot |s,\pi(s)) - P(\cdot|s,\pi(s))\|_1\right]$$

Simulation Lemma Proof

  • Recursive argument: \(\hat V^\pi-V^\pi =\gamma \sum_{\ell=0}^\infty \gamma^\ell P_\pi^\ell (\hat P_\pi  -  P_\pi) \hat V^\pi \)  
  • \(\hat V^\pi(s_0) - V^\pi(s_0) = \frac{\gamma}{1-\gamma}\mathbb E_{s\sim d_\pi^{s_0}}\left [\left((\hat P_\pi  -  P_\pi) \hat V^\pi\right) [s] \right ] \)
    • \(= \frac{\gamma}{1-\gamma}\mathbb E_{s\sim d_\pi^{s_0}}\left [(\hat P_\pi  -  P_\pi) [s]^\top  \hat V^\pi \right ]\) (defn of matrix multiplication)
    • \(= \frac{\gamma}{1-\gamma}\mathbb E_{s\sim d_\pi^{s_0}}\left [\sum_{s'\in\mathcal S}\left(\hat P(s'|s,\pi(s))  -  P(s'|s,\pi(s))\right)\hat V^\pi(s') \right ]\)
  • \(|\hat V^\pi(s_0) - V^\pi(s_0)|\leq\frac{\gamma}{1-\gamma} \mathbb E_{s\sim d_\pi^{s_0}}\left [\sum_{s'\in\mathcal S}|\hat P(s'|s,\pi(s))  -  P(s'|s,\pi(s))|\,|\hat V^\pi(s')| \right ]\)
    • \(\leq\frac{\gamma}{1-\gamma} \mathbb E_{s\sim d_\pi^{s_0}}\left [\sum_{s'\in\mathcal S}|\hat P(s'|s,\pi(s))  -  P(s'|s,\pi(s))|\,\frac{1}{1-\gamma} \right ]\) (bounded reward: \(0\leq r\leq 1\) implies \(\hat V^\pi(s')\leq \tfrac{1}{1-\gamma}\))
    • \(=\frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_\pi^{s_0}}\left [\|\hat P(\cdot|s,\pi(s)) - P(\cdot|s,\pi(s))\|_1 \right ]\), which is the claimed bound
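A sanity check on the Simulation Lemma (not part of the slides): build a small random \(P_\pi\), perturb it to get \(\hat P_\pi\), evaluate the same policy on both, and confirm that \(|\hat V^\pi(s_0)-V^\pi(s_0)|\) is below \(\frac{\gamma}{(1-\gamma)^2}\mathbb E_{s\sim d^\pi_{s_0}}[\|\hat P(\cdot|s,\pi(s))-P(\cdot|s,\pi(s))\|_1]\). Rewards are drawn in \([0,1]\) to match today's assumption; everything else is illustrative.

    import numpy as np

    rng = np.random.default_rng(1)
    S, gamma, s0 = 5, 0.9, 0

    def normalize(M):                        # make each row a probability distribution
        return M / M.sum(axis=1, keepdims=True)

    # True and perturbed transition matrices under a fixed policy pi
    P_pi = normalize(rng.random((S, S)))
    hat_P_pi = normalize(P_pi + 0.05 * rng.random((S, S)))
    R_pi = rng.random(S)                     # rewards in [0, 1]

    def value(P):                            # V^pi = (I - gamma P)^{-1} R^pi
        return np.linalg.solve(np.eye(S) - gamma * P, R_pi)

    V, hat_V = value(P_pi), value(hat_P_pi)

    # Discounted state distribution d^pi_{s0} = (1 - gamma)(I - gamma P_pi^T)^{-1} e_{s0}
    e0 = np.zeros(S)
    e0[s0] = 1.0
    d = (1 - gamma) * np.linalg.solve(np.eye(S) - gamma * P_pi.T, e0)

    lhs = abs(hat_V[s0] - V[s0])
    rhs = gamma / (1 - gamma) ** 2 * (d @ np.abs(hat_P_pi - P_pi).sum(axis=1))
    print(lhs <= rhs)                        # True: the Simulation Lemma bound holds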

Agenda

1. Recap: MDPs and Control

2. MBRL with Query Model

3. Sub-Optimality

4. Model Error

  • If transitions are deterministic in finite MDP,
    • sample every state and every action once
    • exactly reconstruct deterministic \(P\)
    • compute optimal policy
  • Sample complexity?
    • \(SA\)
  • If transitions are stochastic, we need to estimate probabilities

Easy Case: Deterministic Transitions

  • If transitions are deterministic in linear continuous MDP,
    • sample \(s'=f(s,a)\) for every standard basis vector $$\{(e_i,0)\}_{i=1}^{n_s}\cup \{(0,e_j)\}_{j=1}^{n_a}$$
    • exactly reconstruct \(A\) and \(B\)
      • \(s'_i = Ae_i\) is a column of \(A\)
      • \(s'_j = Be_j\) is a column of \(B\)
    • compute optimal LQR policy
  • Sample complexity?
    • \(n_s+n_a\) (one query per basis vector; see the sketch below)

Easy Case: Deterministic Transitions
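A minimal sketch of the basis-vector identification above, assuming deterministic linear dynamics \(s' = As + Ba\) (the linear form implicit in "reconstruct \(A\) and \(B\)"); the matrices here are made up for illustration.

    import numpy as np

    n_s, n_a = 3, 2
    rng = np.random.default_rng(2)
    A_true, B_true = rng.random((n_s, n_s)), rng.random((n_s, n_a))

    def f(s, a):                             # deterministic linear dynamics: s' = A s + B a
        return A_true @ s + B_true @ a

    # Query f at each standard basis vector: n_s + n_a samples in total
    A_hat = np.column_stack([f(np.eye(n_s)[:, i], np.zeros(n_a)) for i in range(n_s)])   # columns A e_i
    B_hat = np.column_stack([f(np.zeros(n_s), np.eye(n_a)[:, j]) for j in range(n_a)])   # columns B e_j

    print(np.allclose(A_hat, A_true), np.allclose(B_hat, B_true))   # True True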

  • Consider a biased coin which comes up heads with probability \(p\) for an unknown value of \(p\in[0,1]\)
  • How to estimate from trials?
    • Flip coin \(N\) times $$\hat p =\frac{\mathsf{\# heads}}{N} $$
  • Consider an \(S\)-sided die which lands on side \(s\) with probability \(p_s\) for \(s\in\{1,\dots,S\}=[S]\), where the \(p_s\) are unknown
  • How to estimate from trials?
    • Roll the die \(N\) times $$\hat p_s =\frac{\mathsf{\# times~land~on}~s}{N} $$

Warmup: Coin & Dice

  • For the weighted coin,
    • The estimate satisfies \(N\hat p\sim\) Binomial\((N, p)\)
    • \(|p-\hat p| \leq\sqrt{\frac{\log(2/\delta)}{N}}\) with probability \(1-\delta\)
  • For the \(S\)-sided die, with probability \(1-\delta\), $$\max_{s\in[S]} |p_s-\hat p_s| \leq \sqrt{\frac{\log(2S/\delta)}{N}} $$
  • Alternatively, the total variation distance w.p. \(1-\delta\) $$ \sum_{s\in[S]} |p_s-\hat p_s| \leq \sqrt{\frac{S\log(2/\delta)}{N}} $$
  • Why? Unbiased \(\mathbb E[\hat p]=p\) & concentration (Hoeffding's)

Estimation Errors
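A quick simulation of the coin warmup and its error bound (not from the slides): estimate \(\hat p\) from \(N\) flips, repeat many times, and check that \(|p-\hat p|\leq\sqrt{\log(2/\delta)/N}\) holds in at least a \(1-\delta\) fraction of trials. Parameter values are arbitrary.

    import numpy as np

    rng = np.random.default_rng(3)
    p, N, delta, trials = 0.3, 500, 0.1, 2000

    bound = np.sqrt(np.log(2 / delta) / N)
    flips = rng.random((trials, N)) < p      # each row holds N Bernoulli(p) coin flips
    p_hat = flips.mean(axis=1)               # hat p = (# heads) / N, one estimate per trial

    coverage = np.mean(np.abs(p_hat - p) <= bound)
    print(coverage)                          # typically well above 1 - delta = 0.9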

  • This slide is not in scope. Hoeffding's inequality states that for independent random variables \(X_1,\dots,X_n\) with \(a_i\leq X_i\leq b_i\), let \(S_n=\sum_{i=1}^n X_i\), then $$ \mathbb P\{|S_n-\mathbb E[S_n]|\geq t\} \leq 2\exp\left(\frac{-2t^2}{\sum_{i=1}^n (b_i-a_i)^2}\right)$$
  • Rearranging, this is equivalent to $$ |\frac{1}{n}S_n-\mathbb E[\frac{1}{n} S_n]|\leq \frac{1}{n}\sqrt{\frac{1}{2}\log(2/\delta)\sum_{i=1}^n (b_i-a_i)^2} \quad\text{w.p. at least}\quad  1-\delta$$

Hoeffding's Inequality

Transition Estimation

  • How to estimate \(P(s,a)\) from queries?
    • Query \(s_i'\sim P(s,a)\) for \(i=1,\dots n\) samples $$\hat P(s'|s,a) =\frac{\mathsf{\# times~}~s'_i=s'}{n} $$
    • Repeat procedure for all \(s,a\)
    • Total number of samples \(N=SA\cdot n\)
  • Lemma: Estimation error of above, with probability \(1-\delta\) $$\max_{s,a }\sum_{s'\in\mathcal S} |P(s'|s,a)-\hat P(s'|s,a)| \leq \sqrt{\frac{S^2A \log(2SA/\delta)}{N}} $$
  • The sub-optimality of \(\hat \pi^\star=\mathsf{PolicyIteration}(\hat P, r)\)
  • \(V^\star(s_0) - V^{\hat\pi^\star}(s_0)\)
    • \(\leq V^\star(s_0)  - \hat V^{\pi^\star}(s_0) + \hat V^{\hat\pi^\star}(s_0) - V^{\hat\pi^\star}(s_0)\)
    • \(\leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_{\pi^\star}^{s_0}}\left[ \|\hat P(\cdot |s,\pi^\star(s)) - P(\cdot|s,\pi^\star(s))\|_1\right]\)
      \(+\frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_{\hat\pi^\star}^{s_0}}\left[ \|\hat P(\cdot |s,\hat\pi^\star(s)) - P(\cdot|s,\hat\pi^\star(s))\|_1\right]\) (Simulation Lemma x2)
    • \(\leq \frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_{\pi^\star}^{s_0}}\left[ \max_a \|\hat P(\cdot |s,a) - P(\cdot|s,a)\|_1\right]\)
      \(+\frac{\gamma}{(1-\gamma)^2} \mathbb E_{s\sim d_{\hat\pi^\star}^{s_0}}\left[ \max_a\|\hat P(\cdot |s,a) - P(\cdot|s,a)\|_1\right]\)
    • \(\leq \frac{2\gamma}{(1-\gamma)^2} \max_{s,a} \|\hat P(\cdot |s,a) - P(\cdot|s,a)\|_1\)

Sub-Optimality

Theorem: For \(0< \delta< 1\), run the MBRL Algorithm with \(N \geq \frac{4\gamma^2 S^2 A\log(2SA/\delta)}{\epsilon^2(1-\gamma)^4}\). Then with probability at least \(1-\delta\), $$V^\star(s)-V^{\hat \pi^\star}(s) \leq \epsilon\quad\forall~~s\in\mathcal S$$

Sample Complexity

Algorithm: Tabular MBRL with Queries

  • Inputs: total number of samples \(N\)
  • For all \((s,a)\in\mathcal S\times \mathcal A\):
    • For \(i=1,\dots, \frac{N}{SA}\):
      • sample \(s'_i \sim P(s, a)\) and record \((s'_i,s,a)\)
  • Fit transition model \(\hat P\) and design \(\hat \pi^\star\) with PI

Proof Outline:

  • From Suboptimality slide and Model Error Lemma, with probability \(1-\delta\) $$V^\star(s_0) - V^{\hat\pi^\star}(s_0) \leq \frac{2\gamma}{(1-\gamma)^2} \sqrt{\frac{S^2A \log(2SA/\delta)}{N}}$$
  • Set RHS equal to \(\epsilon\), solve for \(N\), and use the fact that \(\gamma<1\).
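To make the theorem concrete, here is a small helper (illustrative, not from the slides) that evaluates the required number of queries \(N\) and the per-pair budget \(n = N/(SA)\) used by the tabular algorithm.

    import numpy as np

    def queries_needed(S, A, gamma, eps, delta):
        """N >= 4 gamma^2 S^2 A log(2 S A / delta) / (eps^2 (1 - gamma)^4)."""
        N = 4 * gamma**2 * S**2 * A * np.log(2 * S * A / delta) / (eps**2 * (1 - gamma) ** 4)
        return int(np.ceil(N))

    N = queries_needed(S=10, A=4, gamma=0.9, eps=0.1, delta=0.05)
    print(N, N // (10 * 4))                  # total queries N and per-(s,a) budget n = N / (SA)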

Recap

  • PA due Friday
  • Prelim in class 3/15

 

  • Query Setting of MBRL
  • Simulation Lemma
  • Estimation

 

  • Next lecture: Approximate Policy Iteration

Sp23 CS 4/5789: Lecture 11

By Sarah Dean
