Shen Shen
April 23, 2025
2:30pm, Room 32-144
Recap
Identity (quite useful in ML)
\(\begin{aligned} \nabla_\theta p_\theta(\tau) & =p_\theta(\tau) \frac{\nabla_\theta p_\theta(\tau)}{p_\theta(\tau)} \\ & =p_\theta(\tau) \nabla_\theta \log p_\theta(\tau)\end{aligned}\)
\(\begin{aligned} \nabla_\theta U(\theta) & =\nabla_\theta \sum_\tau P(\tau ; \theta) R(\tau) \\ & =\sum_\tau \nabla_\theta P(\tau ; \theta) R(\tau) \\ & =\sum_\tau \frac{P(\tau ; \theta)}{P(\tau ; \theta)} \nabla_\theta P(\tau ; \theta) R(\tau) \\ & =\sum_\tau P(\tau ; \theta) \frac{\nabla_\theta P(\tau ; \theta)}{P(\tau ; \theta)} R(\tau) \\ & =\sum_\tau P(\tau ; \theta) \nabla_\theta \log P(\tau ; \theta) R(\tau)\end{aligned}\)
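As a sanity check on this derivation (not from the slides), here is a minimal NumPy sketch of the score-function trick on a toy Gaussian, where the exact gradient is known:

```python
# Sketch (toy example): the likelihood-ratio / score-function trick,
# estimating d/dtheta E_{x ~ N(theta, 1)}[x^2].
# Analytically the gradient is 2*theta (since E[x^2] = theta^2 + 1).
import numpy as np

rng = np.random.default_rng(0)
theta, m = 1.5, 200_000

x = rng.normal(theta, 1.0, size=m)      # samples from p_theta
f = x ** 2                              # "reward" of each sample
score = x - theta                       # d/dtheta log N(x; theta, 1)
grad_estimate = np.mean(score * f)      # (1/m) sum_i score_i * f_i

print(grad_estimate, 2 * theta)         # ~3.0 vs exact 3.0
```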
Recap
\(\nabla_\theta U(\theta)=\sum_\tau P(\tau ; \theta) \nabla_\theta \log P(\tau ; \theta) R(\tau)\)
where \(P(\tau ; \theta)=\prod_{t=0} \underbrace{P\left(s_{t+1} \mid s_t, a_t\right)}_{\text {transition }} \cdot \underbrace{\pi_\theta\left(a_t \mid s_t\right)}_{\text {policy }}\)
Transition is unknown.... Stuck?
Recap
\(\nabla_\theta U(\theta)=\sum_\tau P(\tau ; \theta) \nabla_\theta \log P(\tau ; \theta) R(\tau)\)
Approximate with the empirical (Monte-Carlo) estimate over \(m\) sample trajectories under policy \(\pi_\theta\):
\(\nabla_\theta U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right) R\left(\tau^{(i)}\right)\)
Valid even when \(R\) is discontinuous or unknown, and when the sample space of trajectories is discrete.
Recap
\(\nabla_\theta U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right) R\left(\tau^{(i)}\right)\)
where \(P(\tau ; \theta)=\prod_{t=0} \underbrace{P\left(s_{t+1} \mid s_t, a_t\right)}_{\text {transition }} \cdot \underbrace{\pi_\theta\left(a_t \mid s_t\right)}_{\text {policy }}\)
\(\begin{aligned} \nabla_\theta \log P(\tau ; \theta) & =\nabla_\theta \log \Big[\prod_{t=0} \underbrace{P\left(s_{t+1} \mid s_t, a_t\right)}_{\text {transition }} \cdot \underbrace{\pi_\theta\left(a_t \mid s_t\right)}_{\text {policy }}\Big] \\ & =\nabla_\theta\Big[\sum_{t=0} \log P\left(s_{t+1} \mid s_t, a_t\right)+\sum_{t=0} \log \pi_\theta\left(a_t \mid s_t\right)\Big] \\ & =\nabla_\theta \sum_{t=0} \log \pi_\theta\left(a_t \mid s_t\right) \\ & =\sum_{t=0} \underbrace{\nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)}_{\text {no transition model required }}\end{aligned}\)
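A minimal sketch of this estimator in code, assuming a small softmax policy and pre-collected trajectories (the names `policy_logits` and `trajs` are illustrative, not from the lecture):

```python
# Sketch: g_hat = (1/m) sum_i ( sum_t grad log pi(a_t|s_t) ) * R(tau_i),
# obtained via autodiff: the gradient of the surrogate below equals -g_hat.
import torch
from torch.distributions import Categorical

n_states, n_actions = 4, 2
policy_logits = torch.nn.Parameter(torch.zeros(n_states, n_actions))

# Each trajectory: (state indices, action indices, total return R(tau))
trajs = [([0, 1, 2], [1, 0, 1], 3.0),
         ([0, 3, 2], [0, 0, 1], 1.0)]

surrogate = 0.0
for states, actions, R in trajs:
    dist = Categorical(logits=policy_logits[torch.tensor(states)])
    logp = dist.log_prob(torch.tensor(actions)).sum()  # sum_t log pi(a_t|s_t)
    surrogate = surrogate - logp * R                   # minimizing this ascends U
surrogate = surrogate / len(trajs)

surrogate.backward()
print(policy_logits.grad)   # equals -g_hat; a gradient step uses -= lr * grad
```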
Recap
The following expression gives an unbiased estimate of the gradient, computable without access to the transition model:
\(\nabla_\theta U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right) R\left(\tau^{(i)}\right)\), where \(\nabla_\theta \log P(\tau ; \theta)=\sum_{t=0} \underbrace{\nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)}_{\text {no transition model required }}\)
Unbiased estimator: \(\mathrm{E}[\hat{g}]=\nabla_\theta U(\theta)\), but very noisy.
Recap
\(U(\theta)=\sum_\tau P(\tau ; \theta) R(\tau)\)
\(\nabla_\theta U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right) R\left(\tau^{(i)}\right)=\frac{1}{m} \sum_{i=1}^{m}\left(\sum_{t=0}\nabla_\theta \log \pi_\theta\left(a_t^{(i)}\mid s_t^{(i)}\right)\right)R\left(\tau^{(i)}\right)\)
This policy gradient estimator typically has high variance; the next slides reduce it by exploiting temporal structure and by subtracting a baseline.
Variance reduction: temporal structure
\(\begin{aligned} \hat{g} & =\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right) R\left(\tau^{(i)}\right) \\ & =\frac{1}{m} \sum_{i=1}^m\left(\sum_{t=0}^{h-1} \nabla_\theta \log \pi_\theta\left(a_t^{(i)} \mid s_t^{(i)}\right)\right)\left(\sum_{t=0}^{h-1} R\left(s_t^{(i)}, a_t^{(i)}\right)\right) \\ & =\frac{1}{m} \sum_{i=1}^m\left(\sum_{t=0}^{h-1} \nabla_\theta \log \pi_\theta\left(a_t^{(i)} \mid s_t^{(i)}\right)\left[\left(\sum_{k=0}^{t-1} R\left(s_k^{(i)}, a_k^{(i)}\right)\right)+\left(\sum_{k=t}^{h-1} R\left(s_k^{(i)}, a_k^{(i)}\right)\right)\right]\right)\end{aligned}\)
Removing the terms that don't depend on the current action (the rewards before time \(t\)) can lower variance:
\(\hat{g}=\frac{1}{m} \sum_{i=1}^m \sum_{t=0}^{h-1} \nabla_\theta \log \pi_\theta\left(a_t^{(i)} \mid s_t^{(i)}\right)\left(\sum_{k=t}^{h-1} R\left(s_k^{(i)}, a_k^{(i)}\right)\right)\)
[Policy Gradient Theorem: Sutton et al., 1999; GPOMDP: Bartlett & Baxter, 2001; Survey: Peters & Schaal, 2006]
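A small illustrative helper (ours, not from the slides) for the reward-to-go term \(\sum_{k=t}^{h-1} R\left(s_k, a_k\right)\) used above:

```python
# Sketch: compute rtg[t] = sum_{k>=t} rewards[k] for one trajectory.
import numpy as np

def reward_to_go(rewards):
    """rewards[t] = R(s_t, a_t); returns the suffix sums."""
    rtg = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg

print(reward_to_go(np.array([1.0, 0.0, 2.0, 1.0])))   # [4. 3. 3. 1.]
```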
Variance reduction: baseline
\(\nabla U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right)\left(R\left(\tau^{(i)}\right)-b\right)\)
[Williams, REINFORCE paper, 1992]
❓ Why still unbiased, despite the additional term \(\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right)(-b)\)?
✅ Still unbiased, because the baseline term has zero expectation:
\(\begin{aligned} \mathbb{E}\left[\nabla_\theta \log P(\tau ; \theta) b\right] & =\sum_\tau P(\tau ; \theta) \nabla_\theta \log P(\tau ; \theta) b \\ & =\sum_\tau P(\tau ; \theta) \frac{\nabla_\theta P(\tau ; \theta)}{P(\tau ; \theta)} b \\ & =\sum_\tau \nabla_\theta P(\tau ; \theta) b \\ & =\nabla_\theta\left(\sum_\tau P(\tau ; \theta)\right) b=b \times 0=0\end{aligned}\)
❓ With a good choice of \(b\), we can reduce the variance of \(\hat{g}\).
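A toy numerical check of both points, reusing the Gaussian example from earlier (assumed setup, not from the slides): the \(-b\) term averages to zero, and subtracting a baseline shrinks the variance.

```python
# Sketch: baseline term is zero-mean; a baseline reduces estimator variance.
import numpy as np

rng = np.random.default_rng(0)
theta, b, m = 1.5, 10.0, 500_000

x = rng.normal(theta, 1.0, size=m)
score = x - theta                   # grad_theta log N(x; theta, 1)
f = x ** 2                          # "return"

print(np.mean(score * (-b)))        # ~0: the extra -b term is zero-mean (here b = 10)
print(np.var(score * f),            # variance with no baseline ...
      np.var(score * (f - np.mean(f))))   # ... vs. with the mean return as baseline
```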
Control Variates
Given an unbiased estimator \(X\) and a correlated random variable \(Y\) with known mean \(\nu=\mathbb{E}[Y]\), we can define a new estimator \(X' = X - \alpha(Y - \nu)\): it has the same expectation as \(X\), and \(\alpha\) can be chosen optimally (\(\alpha^\star=\operatorname{Cov}(X, Y)/\operatorname{Var}(Y)\)) to minimize variance.
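A minimal generic Monte-Carlo sketch of a control variate (toy problem assumed, not from the slides): estimating \(\mathbb{E}[e^U]\) for \(U \sim \mathrm{Uniform}(0,1)\) with \(Y=U\) and known mean \(\nu=1/2\).

```python
# Sketch: control variate X' = X - alpha*(Y - nu) with alpha* = Cov(X,Y)/Var(Y).
import numpy as np

rng = np.random.default_rng(0)
u = rng.uniform(size=200_000)
X, Y, nu = np.exp(u), u, 0.5                 # E[Y] = 0.5 is known exactly

alpha = np.cov(X, Y)[0, 1] / np.var(Y)
X_cv = X - alpha * (Y - nu)                  # same mean, lower variance

print(np.mean(X), np.mean(X_cv))             # both ~ e - 1 = 1.718...
print(np.var(X), np.var(X_cv))               # variance shrinks substantially
```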
Control variates in RL
With a constant baseline \(-b\): \(\nabla U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right)\left(R\left(\tau^{(i)}\right)-b\right)\)
With a state-dependent baseline, the estimated value function: \(\nabla U(\theta) \approx \hat{g}=\frac{1}{m} \sum_{i=1}^m \nabla_\theta \log P\left(\tau^{(i)} ; \theta\right)\left(R\left(\tau^{(i)}\right)-\hat{V}^\pi(s)\right)\), where \(R(\tau)-\hat{V}^\pi(s)\) plays the role of an advantage function.
Intuition: instead of increasing the likelihood of all "winning" games, increase the likelihood of games with a better-than-average score.
[Greensmith, Bartlett & Baxter, JMLR 2004, for variance reduction techniques]
How to estimate \(V^\pi\)?
One option (sketched below): collect \(\tau^{(1)}, \ldots, \tau^{(m)}\) and regress against the empirical return: \(\phi \leftarrow \underset{\phi}{\arg \min } \frac{1}{m} \sum_{i=1}^m \sum_{t=0}^{H-1}\left(V_\phi^\pi\left(s_t^{(i)}\right)-\left(\sum_{k=t}^{H-1} R\left(s_k^{(i)}, a_k^{(i)}\right)\right)\right)^2\)
Or, similar to fitted Q-learning, do fitted V-learning: \(\phi_{j+1} \leftarrow \underset{\phi}{\arg \min } \sum_{\left(s, a, s^{\prime}, r\right)}\left\|r+V_{\phi_j}^\pi\left(s^{\prime}\right)-V_\phi^\pi(s)\right\|_2^2\)
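A minimal PyTorch sketch of the Monte-Carlo regression option, assuming state features and reward-to-go targets have already been pooled across trajectories (the network and names are illustrative):

```python
# Sketch: regress V_phi(s) against empirical reward-to-go targets.
import torch
import torch.nn as nn

state_dim = 8
value_net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 1))
opt = torch.optim.Adam(value_net.parameters(), lr=1e-3)

# Stand-ins for (s_t^{(i)}, sum_{k>=t} R_k^{(i)}) pooled over trajectories.
states = torch.randn(256, state_dim)
returns_to_go = torch.randn(256)

for _ in range(50):                           # a few regression epochs
    v = value_net(states).squeeze(-1)
    loss = ((v - returns_to_go) ** 2).mean()  # squared error vs empirical return
    opt.zero_grad()
    loss.backward()
    opt.step()
```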
How to estimate advantage?
Generalized Advantage Estimation (GAE) [Schulman et al., ICLR 2016] (sketched below)
TD(\(\lambda\)) / eligibility traces [Sutton and Barto, 1990]
Asynchronous Advantage Actor-Critic (A3C) [Mnih et al., 2016]
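A minimal sketch of GAE for one trajectory, assuming per-step rewards and value estimates are given (the \(\gamma, \lambda\) defaults are illustrative):

```python
# Sketch: Generalized Advantage Estimation for one trajectory.
# values has length h+1; values[h] bootstraps the state after the last step.
import numpy as np

def gae(rewards, values, gamma=0.99, lam=0.95):
    h = len(rewards)
    adv = np.zeros(h)
    running = 0.0
    for t in reversed(range(h)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]   # TD error
        running = delta + gamma * lam * running
        adv[t] = running
    return adv

print(gae(np.array([1.0, 0.0, 1.0]), np.array([0.5, 0.4, 0.6, 0.0])))
```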
Actor-critic method
[Williams, REINFORCE paper, 1992]
actor: the policy \(\pi_\theta(a \mid s)\), updated via the policy gradient
critic: the value estimate \(\hat{V}_\phi^\pi(s)\), used as the baseline / in the advantage
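A minimal single-update actor-critic sketch in PyTorch, assuming discrete actions and precomputed return targets (all module names are illustrative):

```python
# Sketch: one actor-critic update. The actor ascends E[grad log pi(a|s) * A(s,a)];
# the critic regresses V_phi toward return targets.
import torch
import torch.nn as nn
from torch.distributions import Categorical

state_dim, n_actions = 8, 4
actor = nn.Linear(state_dim, n_actions)           # logits of pi_theta(a|s)
critic = nn.Linear(state_dim, 1)                  # V_phi(s)
opt = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=3e-4)

states = torch.randn(64, state_dim)
actions = torch.randint(0, n_actions, (64,))
returns = torch.randn(64)                         # e.g. reward-to-go targets

values = critic(states).squeeze(-1)
advantages = (returns - values).detach()          # critic as baseline / control variate

logp = Categorical(logits=actor(states)).log_prob(actions)
actor_loss = -(logp * advantages).mean()          # policy-gradient surrogate
critic_loss = ((values - returns) ** 2).mean()    # value regression

opt.zero_grad()
(actor_loss + 0.5 * critic_loss).backward()
opt.step()
```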
Vanilla policy gradient/REINFORCE: step-sizing issue
Trust region policy optimization (TRPO)
\(\max_\pi L(\pi)=\mathbb{E}_{\pi_{\text{old}}}\left[\frac{\pi(a \mid s)}{\pi_{\text{old}}(a \mid s)} A^{\pi_{\text{old}}}(s, a)\right]\)
Constraint: \(\mathbb{E}_{\pi_{\text{old}}}\left[\operatorname{KL}\left(\pi_{\text{old}} \,\|\, \pi\right)\right] \leq \epsilon\)
importance sampling
\(\mathbb{E}_{x \sim q}\left[\frac{p(x)}{q(x)} f(x)\right]=\mathbb{E}_{x \sim p}[f(x)]\)
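A toy numerical check of this identity (assumed example, not from the slides), with \(p=\mathcal{N}(1,1)\), \(q=\mathcal{N}(0,2)\), \(f(x)=x^2\):

```python
# Sketch: E_{x~q}[ p(x)/q(x) f(x) ] matches E_{x~p}[ f(x) ].
import numpy as np

rng = np.random.default_rng(0)
m = 500_000

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# target p = N(1, 1), proposal q = N(0, 2), f(x) = x^2; E_p[x^2] = 1 + 1 = 2
x_q = rng.normal(0.0, 2.0, size=m)
w = normal_pdf(x_q, 1.0, 1.0) / normal_pdf(x_q, 0.0, 2.0)   # importance weights
print(np.mean(w * x_q ** 2))                     # importance-sampled estimate, ~2.0

x_p = rng.normal(1.0, 1.0, size=m)
print(np.mean(x_p ** 2))                         # direct estimate, ~2.0
```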
\(U(\theta)=\mathbb{E}_{\tau \sim \theta_{\text{old}}}\left[\frac{P(\tau \mid \theta)}{P\left(\tau \mid \theta_{\text{old}}\right)} R(\tau)\right]\)
(the transition terms cancel in the ratio, leaving \(\frac{P(\tau \mid \theta)}{P\left(\tau \mid \theta_{\text{old}}\right)}=\prod_t \frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)}\))
\(\nabla_\theta U(\theta)=\mathbb{E}_{\tau \sim \theta_{\text{old}}}\left[\frac{\nabla_\theta P(\tau \mid \theta)}{P\left(\tau \mid \theta_{\text{old}}\right)} R(\tau)\right]\)
[Tang and Abbeel, On a Connection between Importance Sampling and the Likelihood Ratio Policy Gradient, 2011]
TRPO
\(\max_\pi L(\pi)=\mathbb{E}_{\pi_{\text{old}}}\left[\frac{\pi(a \mid s)}{\pi_{\text{old}}(a \mid s)} A^{\pi_{\text{old}}}(s, a)\right]\)
Constraint: \(\mathbb{E}_{\pi_{\text{old}}}\left[\operatorname{KL}\left(\pi_{\text{old}} \,\|\, \pi\right)\right] \leq \epsilon\)
Proximal Policy Optimization (PPO, v1): replace the hard KL constraint with a penalty in the objective:
\(\max_\pi \mathbb{E}_{\pi_{\text{old}}}\left[\frac{\pi(a \mid s)}{\pi_{\text{old}}(a \mid s)} A^{\pi_{\text{old}}}(s, a)\right]-\beta\left(\mathbb{E}_t\left[\operatorname{KL}\left(\pi_{\text{old}} \,\|\, \pi\right)\right]-\delta\right)\)
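A minimal sketch of this penalized surrogate for discrete actions (assumed setup; names and coefficients are illustrative):

```python
# Sketch: PPO-v1 style penalized surrogate E[ratio * A] - beta*(E[KL(pi_old||pi)] - delta).
import torch
from torch.distributions import Categorical, kl_divergence

def ppo_v1_objective(logits, logits_old, actions, advantages, beta=1.0, delta=0.01):
    pi, pi_old = Categorical(logits=logits), Categorical(logits=logits_old)
    ratio = torch.exp(pi.log_prob(actions) - pi_old.log_prob(actions))
    kl = kl_divergence(pi_old, pi).mean()            # E[ KL(pi_old || pi) ]
    return (ratio * advantages).mean() - beta * (kl - delta)

# Example call with dummy batch data (gradient ascent would maximize this):
logits = torch.randn(32, 4, requires_grad=True)
logits_old = logits.detach() + 0.1 * torch.randn(32, 4)
actions = torch.randint(0, 4, (32,))
advantages = torch.randn(32)
print(ppo_v1_objective(logits, logits_old, actions, advantages))
```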
Proximal Policy Optimization (PPO, v2)
Recall the objective:
\(\hat{\mathbb{E}}_t\left[\frac{\pi_\theta\left(a_t \mid s_t\right)}{\pi_{\theta_{\text{old}}}\left(a_t \mid s_t\right)} \hat{A}_t\right]=\hat{\mathbb{E}}_t\left[\rho_t(\theta) \hat{A}_t\right]\)
A lower bound of the above:
\(L^{CLIP}(\theta)=\hat{\mathbb{E}}_t\left[\min \left(\rho_t(\theta) \hat{A}_t, \operatorname{clip}\left(\rho_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\right]\)
(Table of clipping cases omitted; table credit: Daniel Bick)
- The main loss (see the sketch after this list): \(L_t(\theta)=\min \left(\rho_t(\theta) \hat{A}_t, \operatorname{clip}\left(\rho_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t\right)\)
- Advantage estimate based on truncated GAE
- Adding an entropy term
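A minimal PyTorch sketch combining these pieces, assuming log-probabilities under the current and old policies, GAE advantages, and entropies are already available (names and coefficients are illustrative):

```python
# Sketch: PPO clipped surrogate L^CLIP plus the entropy bonus mentioned above.
import torch

def ppo_clip_loss(logp, logp_old, advantages, entropy, eps=0.2, ent_coef=0.01):
    rho = torch.exp(logp - logp_old)                      # rho_t(theta)
    unclipped = rho * advantages
    clipped = torch.clamp(rho, 1.0 - eps, 1.0 + eps) * advantages
    l_clip = torch.min(unclipped, clipped).mean()         # lower-bound surrogate
    return -(l_clip + ent_coef * entropy.mean())          # minimize the negative

# Example call with dummy tensors:
logp = torch.randn(64); logp_old = logp.detach() + 0.1 * torch.randn(64)
adv = torch.randn(64); ent = torch.rand(64)
print(ppo_clip_loss(logp, logp_old, adv, ent))
```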
Beyond what we covered
- POMDP
- Inverse RL
- Apprenticeship learning
- Behavioral cloning (and Dataset Aggregation, a.k.a. DAgger)
- Transfer learning (sim to real)
- Domain randomization
- Multi-task learning
- Curriculum learning
- Hierarchical RL
- Safe/verifiable RL
- Multi-agent RL
- Offline RL
- Reward shaping
- Fair, ethical, and explainable AI (value alignment)
We'd love to hear your thoughts.
Variance Reduction Achieved