CS 4/5789: Introduction to Reinforcement Learning

Lecture 12: Supervised Learning

Prof. Sarah Dean

MW 2:55-4:10pm
255 Olin Hall

Reminders

  • Homework
    • PA 2 due Friday
    • PSet 4 released Friday
    • PA 3 released next week
  • Prelim
    • Grades released by next week
    • Corrections at end of semester

Agenda

  1. Recap
  2. Supervised Learning
    • Setting, ERM, Optimization
  3. Example: Linear and Deep
  4. Supervised Learning in MDPs

Recap: Unit 1

  • So far, we have algorithms (VI, PI, DP, LQR) for when transitions \(P\) or \(f\) are known and tractable to optimize
  • We covered some sample-based approximate algorithms for nonlinear control (iLQR, DDP)
  • In Unit 2, we develop algorithms for unknown (or intractable) transitions

Recap: MDPs and Control

[Diagram: agent-environment loop with state \(s_t\), action \(a_t\), and reward \(r_t\)]

  • In Unit 2, we develop algorithms for unknown (or intractable) transitions

Feedback in RL

[Diagram: control feedback loop between policy \(\pi\) and transitions \(P,f\), via action \(a_t\), state \(s_t\), and reward \(r_t\)]

Control feedback

  • between states and actions
    • "reaction"
  • studied in control theory "automatic feedback control"
  • our focus for Unit 1


Feedback in RL

[Diagram: control feedback loop (policy \(\pi\) and transitions \(P,f\)) together with a data feedback loop: experience produces data \((s_t,a_t,r_t)\), which is used to update the policy]

  1. Control feedback: between states and actions


Data feedback

  • between data and policy
    • "adaptation"
  • connection to machine learning
  • our new focus in Unit 2

 

Feedback in RL

[Diagram: control and data feedback loops as above; the transitions \(P,f\) are unknown in Unit 2]

  1. Control feedback: between states and actions
  2. Data feedback: between data and policy


Today: Supervised Learning

[Diagram: data feedback loop between data and predictor]

Agenda

  1. Recap
  2. Supervised Learning
    • Setting, ERM, Optimization
  3. Example: Linear and Deep
  4. Supervised Learning in MDPs

Setting: Supervised Learning

  • Dataset drawn i.i.d. from distribution \(\mathbb P(x,y)=\mathbb P(y|x)\mathbb P(x)\) $$\mathcal D= \{(x_1,y_1),\dots,(x_n,y_n)\}$$
  • Goal: learn a good predictor \(f:\mathcal X\to\mathcal Y\)
  • Brainstorm: any ideas?
  • Definition: the mean squared error of a predictor \(f\) is $$\mathsf{MSE}(f) = \mathbb E[(y-f(x))^2]$$
  • Fact: the conditional expectation \(\mathbb E[y|x]\) minimizes the mean squared error

MSE and Conditional Exp.

  • Fact: the conditional expectation \(\mathbb E[y|x]\) minimizes the mean squared error \(\mathsf{MSE}(f) = \mathbb E[(y-f(x))^2]\)
  • Proof:
    • \(\mathsf{MSE}(f) = \mathbb E[(y-\mathbb E[y|x]+\mathbb E[y|x]-f(x))^2]\)
    • \(= \mathbb E[(y-\mathbb E[y|x])^2+(\mathbb E[y|x]-f(x))^2 + 2(y-\mathbb E[y|x])(\mathbb E[y|x]-f(x))]\)
    • The cross term has zero expectation: \(\mathbb E[(y-\mathbb E[y|x])(\mathbb E[y|x]-f(x))]\)
      • \( =\mathbb E[\mathbb E[(y-\mathbb E[y|x])(\mathbb E[y|x]-f(x))\mid x ]]\) (tower rule)
      • \( =\mathbb E[\mathbb E[(y-\mathbb E[y|x])\mid x ]\,(\mathbb E[y|x]-f(x))]\) (\(\mathbb E[y|x]-f(x)\) is constant given \(x\))
      • \( =\mathbb E[(\mathbb E[y| x ]-\mathbb E[y|x])(\mathbb E[y|x]-f(x))]=0\) (linearity of conditional expectation)
    • Therefore \(\min_f \mathsf{MSE}(f)= \min_f \mathbb E[(y-\mathbb E[y|x])^2+(\mathbb E[y|x]-f(x))^2]\)
    • The first term does not depend on \(f\), and the second is nonnegative and equals zero when \(f(x)=\mathbb E[y|x]\), so the conditional expectation is a minimizer. \(\square\)
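As a quick numerical sanity check (a minimal sketch, not from the lecture; the data-generating process \(y=x^2+\varepsilon\) is an arbitrary illustration), the conditional expectation achieves the smallest empirical MSE:

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulate y = x^2 + noise, so that E[y|x] = x^2.
n = 100_000
x = rng.uniform(-1, 1, size=n)
y = x**2 + rng.normal(scale=0.5, size=n)

def mse(f):
    """Monte Carlo estimate of E[(y - f(x))^2]."""
    return np.mean((y - f(x))**2)

print(mse(lambda x: x**2))              # conditional expectation: ~0.25 (the noise variance)
print(mse(lambda x: x**2 + 0.3))        # any other predictor has larger MSE
print(mse(lambda x: np.zeros_like(x)))  # e.g. predicting zero everywhere
```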

PollEV

Empirical risk minimization

  • Fact: the conditional expectation \(\mathbb E[y|x]\) minimizes the mean squared error \(\mathsf{MSE}(f) = \mathbb E[(y-f(x))^2]\)
  • This fact motivates \(\mathbb E[y|x]\) as the prediction target, and suggests approximating the expectation with an average over the dataset $$\mathbb E[(y-f(x))^2]\approx \frac{1}{n}\sum_{i=1}^n (y_i-f(x_i))^2$$
  • Definition: empirical risk minimization (ERM) is $$\hat f(x)=\arg\min_f \frac{1}{n}\sum_{i=1}^n (y_i-f(x_i))^2$$
  • Any potential problems?
    • ex: \(f(x) = \sum_{i=1}^ny_i\mathbb 1\{x=x_i\}\) overfits!

Function classes

  • We constrain ERM to a function class \(\mathcal F\) $$\hat f(x)=\arg\min_{f\in\mathcal F} \frac{1}{n}\sum_{i=1}^n (y_i-f(x_i))^2$$
  • Example: if \(x\) is scalar then quadratic functions are $$\mathcal F =\{ax^2+bx+c\mid a, b, c \in\mathbb R\}$$
  • How to choose a function class?
    1. Approximation: \(\mathbb E[y|x]\approx \arg\min_{f\in\mathcal F} \mathbb E[(y-f(x))^2]\)
    2. Complexity: \(\mathcal F\) doesn't contain "too many" functions
    3. Optimizability: able to compute the argmin (or close)
  • Statistical learning theory tells us that ERM optimum (3) performs well when approximation error (1) and complexity (2) are low
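For instance, ERM over the quadratic class above reduces to least squares on the features \((x^2, x, 1)\); a minimal numpy sketch with synthetic data (the true coefficients and noise level are chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data from a noisy quadratic.
x = rng.uniform(-2, 2, size=200)
y = 1.5 * x**2 - 0.5 * x + 2.0 + rng.normal(scale=0.3, size=x.shape)

# ERM over F = {a x^2 + b x + c}: least squares on the features (x^2, x, 1).
Phi = np.column_stack([x**2, x, np.ones_like(x)])
coef, *_ = np.linalg.lstsq(Phi, y, rcond=None)
a, b, c = coef
f_hat = lambda x: a * x**2 + b * x + c   # the fitted quadratic predictor
```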

Optimization

  • Usually, \(\mathcal F\) is parametrized by some \(\theta\in\mathbb R^d\), i.e. \(f_\theta(x)\)
  • Then parametrized ERM is $$\hat\theta = \arg\min_{\theta\in\mathbb R^d} \frac{1}{n}\sum_{i=1}^n \underbrace{ (y_i-f_\theta(x_i))^2}_{L_i(\theta)}$$
  • Notation: write ERM objective \(L(\theta)=\frac{1}{n}\sum_{i=1}^nL_i(\theta)\)
  • How to solve optimization?
    1. In simple cases, closed-form \(\nabla_\theta L(\theta)=0\)
    2. Iteratively with gradient descent (GD) \(\theta^{(i+1)} = \theta^{(i)}-\eta\nabla_\theta L(\theta^{(i)})\)
    3. Efficiently with stochastic GD (SGD) \(\theta^{(i+1)} = \theta^{(i)}-\eta\nabla_\theta L_{j}(\theta^{(i)})\) for an index \(j\) sampled uniformly at random
  • More on this in a few weeks!
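A minimal sketch of options 2 and 3, assuming gradient routines for \(L\) and \(L_i\) are available (the linear-model gradients in the comment come from the next slide):

```python
import numpy as np

def gd(grad_L, theta0, eta=0.1, iters=1000):
    """Gradient descent: theta <- theta - eta * grad L(theta)."""
    theta = theta0.copy()
    for _ in range(iters):
        theta -= eta * grad_L(theta)
    return theta

def sgd(grad_Li, n, theta0, eta=0.01, iters=10_000, seed=0):
    """Stochastic GD: each step uses the gradient of one randomly chosen L_i."""
    rng = np.random.default_rng(seed)
    theta = theta0.copy()
    for _ in range(iters):
        j = rng.integers(n)            # sample an index uniformly at random
        theta -= eta * grad_Li(j, theta)
    return theta

# For the linear model f_theta(x) = theta^T x (next slide):
#   grad L(theta)   = -(2/n) * sum_i x_i (y_i - theta^T x_i)
#   grad L_i(theta) = -2 * x_i * (y_i - theta^T x_i)
```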

Agenda

  1. Recap
  2. Supervised Learning
    • Setting, ERM, Optimization
  3. Example: Linear and Deep
  4. Supervised Learning in MDPs

Linear models

  • Linear models \(\mathcal F = \{f_\theta(x)=\theta^\top x \mid \theta\in\mathbb R^d\}\)
  • ERM optimization problem is convex! $$\hat\theta = \arg\min_{\theta\in\mathbb R^d} \frac{1}{n}\sum_{i=1}^n(y_i-\theta^\top x_i)^2$$
  • In this case the gradient \(\nabla_\theta L(\theta)=-\frac{2}{n}\sum_{i=1}^nx_i(y_i-\theta^\top x_i)\)
  1. Closed form \(\hat\theta=\left(\sum_{i=1}^nx_ix_i^\top \right)^{-1} \sum_{i=1}^ny_ix_i\)
    • unique only when inverse exists, otherwise family of solutions $$\theta\quad\text{s.t.}\quad \textstyle\left(\sum_{i=1}^nx_ix_i^\top\right) \theta=\sum_{i=1}^nx_iy_i$$
  2. GD (with proper step size, initialized at \(\theta^{(0)}=0\)) will converge to the solution with the smallest norm
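A minimal numpy sketch of the closed form on synthetic data; `np.linalg.pinv` (or `np.linalg.lstsq`) recovers the minimum-norm solution when the inverse does not exist:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 3
X = rng.normal(size=(n, d))                  # rows are the features x_i
theta_true = np.array([1.0, -2.0, 0.5])      # arbitrary ground-truth parameter
y = X @ theta_true + rng.normal(scale=0.1, size=n)

A = X.T @ X        # sum_i x_i x_i^T
b = X.T @ y        # sum_i y_i x_i

theta_hat = np.linalg.solve(A, b)        # closed form when A is invertible
theta_min_norm = np.linalg.pinv(A) @ b   # minimum-norm solution otherwise
```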

Linear models

  • Linear models \(\mathcal F = \{f_\theta(x)=\theta^\top x \mid \theta\in\mathbb R^d\}\)
  • Can work very well in practice, especially in high dimensions
  • Need good features
    • Construct a transformation \(\varphi(x)\) using domain knowledge
    • Then define \(f_\theta(x) = \theta^\top\varphi(x)\)
  • In high dimensions regularization is important
    • Ridge penalty: add \(\lambda \|\theta\|_2^2\) to discourage large coefficients
    • Lasso penalty: add \(\lambda \|\theta\|_1\) to encourage sparse coefficients
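A minimal sketch of ridge regression with a hand-designed feature map (the particular \(\varphi\) and \(\lambda\) below are arbitrary illustrations, not from the lecture):

```python
import numpy as np

def phi(x):
    """Hypothetical feature map built from domain knowledge: [1, x, x^2, sin(x)]."""
    return np.stack([np.ones_like(x), x, x**2, np.sin(x)], axis=-1)

def ridge_fit(x, y, lam=0.1):
    """Minimize (1/n) sum_i (y_i - theta^T phi(x_i))^2 + lam * ||theta||_2^2."""
    Phi = phi(x)                              # n x d feature matrix
    n, d = Phi.shape
    A = Phi.T @ Phi / n + lam * np.eye(d)
    b = Phi.T @ y / n
    return np.linalg.solve(A, b)              # unique solution for any lam > 0
```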

Deep models

  • a.k.a. deep networks, neural networks
  • Building blocks
    1. Linear transformation: \(Wx+b\)
    2. Nonlinear transformation, e.g. ReLU \(\sigma(u) = \max(u,0)\)
  • Simple but nontrivial example: \(f(x) = W_2\sigma(W_1x+b_1)+b_2\)
    1. Start with input \(x\in\mathbb R^d\)
    2. Linearly transform with \(W_1\in\mathbb R^{m\times d}\) and \(b_1\in\mathbb R^{m}\)
    3. Apply nonlinearity elementwise to get \(\sigma(W_1x+b_1)\)
    4. Linearly transform with \(W_2\in\mathbb R^{1\times m}\) and \(b_2\in\mathbb R\)
  • Parameter \(\theta\) contains all \(W_i,b_i\), dimension scales with width and depth
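The same two-layer example written out in numpy (a minimal sketch; the width \(m\) and the random initialization are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 4, 32                                   # input dimension, hidden width

# Parameters theta = (W1, b1, W2, b2)
W1 = rng.normal(scale=1 / np.sqrt(d), size=(m, d))
b1 = np.zeros(m)
W2 = rng.normal(scale=1 / np.sqrt(m), size=(1, m))
b2 = np.zeros(1)

def relu(u):
    return np.maximum(u, 0.0)

def f(x):
    """f(x) = W2 relu(W1 x + b1) + b2 for a single input x in R^d."""
    return W2 @ relu(W1 @ x + b1) + b2

y_hat = f(rng.normal(size=d))   # scalar output (shape (1,))
```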

Optimizing Deep models

  • Computing gradients requires chain rule over many linear and nonlinear functions composed together
  • A trick called backpropagation allows this to be done efficiently
    • Details not in scope for this class
  • However, \(L(\theta)\) is not generally convex, so there may be many local optima
  • In practice, many optimization tricks
    • Popular alternative to SGD is called Adam
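For concreteness, a minimal PyTorch-style training loop with Adam (a sketch assuming PyTorch is available; the architecture, learning rate, and synthetic data are placeholders):

```python
import torch
import torch.nn as nn

# Same two-layer architecture as before: d inputs, hidden width m, scalar output.
d, m = 4, 32
model = nn.Sequential(nn.Linear(d, m), nn.ReLU(), nn.Linear(m, 1))

# Synthetic regression data, for illustration only.
X = torch.randn(256, d)
y = torch.randn(256, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)   # empirical risk on the batch
    loss.backward()               # backpropagation computes the gradients
    optimizer.step()              # Adam update
```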

Agenda

  1. Recap
  2. Supervised Learning
    • Setting, ERM, Optimization
  3. Example: Linear and Deep
  4. Supervised Learning in MDPs

Supervised Learning for MDPs

  • Supervised learning: features \(x\) and labels \(y\)
    • Goal: predict labels with \(\hat f(x)\approx \mathbb E[y|x]\)
    • Requirements: dataset \(\{x_i,y_i\}_{i=1}^N\)
    • Method: \(\hat f = \arg\min_{f\in\mathcal F} \sum_{i=1}^N (f(x_i)-y_i)^2\)
  • Important functions in MDPs
    • Transitions \(P(s'|s,a)\)
    • Value/Q of a policy \(V^\pi(s)\) and \(Q^\pi(s,a)\)
    • Optimal Value/Q \(V^\star(s)\) and \(Q^\star(s,a)\)
    • Optimal policy \(\pi^\star(s)\)

Supervised Learning for MDPs

  • Supervised learning: features \(x\) and labels \(y\)
  • Important functions in MDPs
    • Transitions \(P(s'|s,a)\)
      • features \(s,a\); the sampled outcomes \(s'\) are observed and serve as labels
    • Value/Q of a policy \(V^\pi(s)\) and \(Q^\pi(s,a)\)
      • features \(s\) or \(s,a\), labels ?
    • Optimal Value/Q \(V^\star(s)\) and \(Q^\star(s,a)\)
      • features \(s\) or \(s,a\), labels ?
    • Optimal policy \(\pi^\star(s)\)
      • features \(s\), labels from expert, otherwise ?
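For the transition case, where sampled outcomes are observed directly, a minimal sketch of a count-based estimate of \(P(s'\mid s,a)\) in a tabular MDP (the data format here is an assumption for illustration):

```python
import numpy as np

def estimate_transitions(data, num_states, num_actions):
    """Estimate P(s'|s,a) by empirical frequencies from (s, a, s') tuples."""
    counts = np.zeros((num_states, num_actions, num_states))
    for s, a, s_next in data:
        counts[s, a, s_next] += 1
    totals = counts.sum(axis=2, keepdims=True)
    # Leave unvisited (s, a) pairs at zero instead of dividing by zero.
    P_hat = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    return P_hat

# Example usage with transition tuples collected from experience:
# data = [(0, 1, 2), (0, 1, 2), (0, 1, 0), ...]
# P_hat = estimate_transitions(data, num_states=3, num_actions=2)
```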

Supervised Learning for MDPs

[Slide annotations marking when each function above is revisited: PSet 4 / Unit 3, next week, next week, two weeks]

Recap

  • PA 2 due Friday
  • PSet 4 released Friday
  • Prelim grades released next week

 

  • Supervised learning

 

  • Next lecture: Fitted Value Iteration