Super-efficiency of automatic differentiation for functions defined as a minimum
Pierre Ablin\(^1\), Gabriel Peyré\(^1\), Thomas Moreau\(^2\)
1: ENS and CNRS, PSL university 2: Inria and CEA
Context
Context: computing the gradient of a minimum
- Let \(\mathcal{L}(z, x)\) be a smooth function where \(z\in \mathbb{R}^m\) and \(x\in\mathbb{R}^n\), and
$$\boxed{\ell(x) = \min_z \mathcal{L}(z, x)}$$
- We want to compute \(\nabla \ell(x)\)
Examples: min-min optimization
- Dictionary learning
- Alternate optimization
- Fréchet means
max-min optimization
- Game theory
- GANs
Technical assumptions
- \(\ell(x) = \min_z \mathcal{L}(z, x)\), we want to compute \(\nabla \ell(x)\)
We assume:
- For all \(x\), \(\mathcal{L}\) has a unique minimizer \(z^*(x)=\arg\min_z \mathcal{L}(z, x)\)
$$\ell(x) = \mathcal{L}(z^*(x), x)$$
- \(x\to z^*(x) \) is differentiable
Consequence (Danskin):
$$\nabla \ell(x) = \nabla_2 \mathcal{L}(z^*(x), x)$$
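A short derivation of this consequence, using only the two assumptions above. The notation \(\frac{\partial z^*}{\partial x}\) for the \(n \times m\) matrix of partial derivatives of \(z^*\) is a convention assumed here so that the products are well defined:

$$\nabla \ell(x) = \frac{\partial z^*}{\partial x}\,\nabla_1 \mathcal{L}(z^*(x), x) + \nabla_2 \mathcal{L}(z^*(x), x) = \nabla_2 \mathcal{L}(z^*(x), x)$$

The first equality is the chain rule applied to \(\ell(x) = \mathcal{L}(z^*(x), x)\). The second holds because \(z^*(x)\) minimizes \(\mathcal{L}(\cdot, x)\), so \(\nabla_1 \mathcal{L}(z^*(x), x) = 0\).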
Approximate optimization
- \(\ell(x) = \min_z \mathcal{L}(z, x)\), we want to compute \(\nabla \ell(x)\)
- No closed form: we assume that we only have access to a sequence \(z_t\) such that \(z_t \to z^*\)
Example: \(z_t\) produced by gradient descent on \(\mathcal{L}\)
- \(z_0 = 0\)
- \(z_{t+1} = z_t - \rho \nabla_1 \mathcal{L}(z_t, x)\)
- \(z_t\) depends on \(x\): \(z_t(x)\)
How can we approximate \(\nabla \ell(x)\) using \(z_t(x)\)? How fast does the approximation converge?
The analytic estimator
$$g^1_t(x) = \nabla_2 \mathcal{L}(z_t(x), x)$$
Very simple to compute: simply plug \(z_t\) into \(\nabla_2 \mathcal{L}\)!
The automatic estimator
\(g^2_t\) is computed by automatic differentiation:
- Easy to code (e.g. in Python, use PyTorch, TensorFlow, autograd...)
- About as costly to compute as \(z_t\)
- Memory cost linear in the number of iterations
The implicit estimator
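Implicit function theorem:
$$\frac{\partial z^*(x)}{\partial x} = \mathcal{J}(z^*(x), x), \quad \mathcal{J}(z, x) = -\nabla_{21}\mathcal{L}(z, x)\left[\nabla_{11}\mathcal{L}(z, x)\right]^{-1}$$
- Need to solve a linear system: might be too costly to compute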
Three gradient estimators
- Analytic: \(g^1_t(x) = \nabla_2 \mathcal{L}(z_t(x), x)\)
- Automatic: \(g^2_t(x) = \frac{\partial z_t}{\partial x} \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)\)
- Implicit: \(g^3_t(x) = \mathcal{J}(z_t(x), x) \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)\)
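These estimators are all "consistent":
If \(z_t(x)= z^*(x)\), then \(\nabla_1\mathcal{L}(z_t(x), x) = 0\) and
$$g^1_t(x) = g^2_t(x) = g^3_t(x) = \nabla \ell(x)$$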
Convergence speed of the estimators
Toy example (regularized logistic regression)
- \(D\) rectangular matrix
- \(\mathcal{L}(z, x) = \sum_{i=1}^n\log\left(1 + \exp(-x_i[Dz]_i)\right) + \frac{\lambda}{2}\|z\|^2\)
- \(z_t\) produced by gradient descent
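A minimal NumPy sketch of the three estimators on this toy problem. Several choices here are assumptions made for illustration and are not taken from the slides: the problem sizes, the random data, \(\lambda = 0.5\), a fixed step derived from the bound \(\sigma' \le 1/4\), and the function names `derivatives` and `estimators`. The automatic estimator is emulated by propagating \(J_t = \frac{\partial z_t}{\partial x}\) by hand alongside the iterates, which stands in for what an automatic differentiation library would compute.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def derivatives(z, x, D, lam):
    """Derivatives of L(z, x) = sum_i log(1 + exp(-x_i [Dz]_i)) + lam/2 |z|^2."""
    u = D @ z
    s = sigmoid(-x * u)                      # s_i = sigma(-x_i [Dz]_i)
    w = s * (1.0 - s)
    grad_z = D.T @ (-x * s) + lam * z        # nabla_1 L, shape (m,)
    grad_x = -u * s                          # nabla_2 L, shape (n,)
    hess_zz = D.T @ ((x ** 2 * w)[:, None] * D) + lam * np.eye(len(z))  # nabla_11 L, (m, m)
    hess_xz = (-s + x * u * w)[:, None] * D  # nabla_21 L, shape (n, m)
    return grad_z, grad_x, hess_zz, hess_xz

def estimators(x, D, lam, step, n_iter):
    """Run n_iter gradient steps from z_0 = 0 and return (g1, g2, g3)."""
    n, m = D.shape
    z = np.zeros(m)
    J = np.zeros((n, m))                     # z_0 does not depend on x, so J_0 = 0
    for _ in range(n_iter):
        grad_z, _, hess_zz, hess_xz = derivatives(z, x, D, lam)
        # Differentiate the gradient step w.r.t. x; both updates use z_t.
        J = J - step * (J @ hess_zz + hess_xz)
        z = z - step * grad_z
    grad_z, grad_x, hess_zz, hess_xz = derivatives(z, x, D, lam)
    g1 = grad_x                                                  # analytic
    g2 = J @ grad_z + grad_x                                     # automatic
    g3 = -hess_xz @ np.linalg.solve(hess_zz, grad_z) + grad_x    # implicit
    return g1, g2, g3

rng = np.random.default_rng(0)
D = rng.standard_normal((20, 5))
x = rng.standard_normal(20)
lam = 0.5
# Fixed step 1 / L_bound, treated as a constant that does not depend on x.
step = 1.0 / (lam + 0.25 * np.linalg.norm(x[:, None] * D, 2) ** 2)

# Reference gradient: analytic estimator after many iterations.
g_star = estimators(x, D, lam, step, 5000)[0]
for t in (10, 20, 30):
    errs = [np.linalg.norm(g - g_star) for g in estimators(x, D, lam, step, t)]
    print(t, errs)
```

The printed triples are the errors of \(g^1_t\), \(g^2_t\) and \(g^3_t\) in that order. On an instance like this one would expect the last two to shrink markedly faster than the first as \(t\) grows.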
Analytic estimator
If \(\nabla_2 \mathcal{L} \) is \(L\)-Lipschitz w.r.t. \(z\):
$$|g^1_t(x) - \nabla \ell(x)| \leq L\,|z_t(x) - z^*(x)|$$
- The analytic estimator converges at the same speed as \(z_t\)
Automatic estimator
It is of the form \(g^2_t(x) = J \nabla_1 \mathcal{L}(z_t, x) + \nabla_2 \mathcal{L}(z_t, x)\), \(J\) rectangular matrix
Taylor expansion around \(z^*\):
$$g^2_t(x) = \nabla \ell(x) + R(J)(z_t - z^*) + O(|z_t - z^*|^2)$$
Automatic estimator
Implicit function theorem: \(R(J) = J\nabla_{11}\mathcal{L}(z^*, x) + \nabla_{21}\mathcal{L}(z^*, x)\) cancels when \(J = \frac{\partial z^*}{\partial x}\)
If \(\frac{\partial z_t}{\partial x} \to \frac{\partial z^*}{\partial x}\):
$$g^2_t(x) - \nabla \ell(x) = o(|z_t(x) - z^*(x)|)$$
Super-efficiency: strictly faster than \(g^1_t\)!
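The cancellation of \(R\) at the true Jacobian can be checked directly. A sketch, assuming \(\nabla_{11}\mathcal{L}(z^*, x)\) is invertible: differentiating the optimality condition \(\nabla_1 \mathcal{L}(z^*(x), x) = 0\) with respect to \(x\) gives

$$\frac{\partial z^*}{\partial x}\,\nabla_{11}\mathcal{L}(z^*, x) + \nabla_{21}\mathcal{L}(z^*, x) = 0 \quad\Longleftrightarrow\quad \frac{\partial z^*}{\partial x} = -\nabla_{21}\mathcal{L}(z^*, x)\left[\nabla_{11}\mathcal{L}(z^*, x)\right]^{-1}$$

The left-hand side of the first identity is exactly \(R\left(\frac{\partial z^*}{\partial x}\right)\). The first-order term of the error of \(g^2_t\) therefore vanishes as \(\frac{\partial z_t}{\partial x}\) approaches \(\frac{\partial z^*}{\partial x}\).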
Implicit estimator
If \(\mathcal{J}\) is Lipschitz w.r.t. \(z\):
$$|g^3_t(x) - \nabla \ell(x)| = O(|z_t(x) - z^*(x)|^2)$$
Twice as fast as \(g^1_t\)!
Example
Gradient descent in the strongly convex case
Assumptions and reminders
- We assume that \(\mathcal{L}\) is \(\mu\)-strongly convex and \(L\)-smooth w.r.t. \(z\) for all \(x\).
- \(z_t\) produced by gradient descent with step \(1/L\):
\(z_{t+1} = z_t - \frac1L \nabla_1\mathcal{L}(z_t, x)\)
Linear convergence:
$$|z_t - z^*| \leq \kappa^t\,|z_0 - z^*|, \quad \kappa = 1 - \frac{\mu}{L}$$
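A small numerical check of linear convergence with rate \(\kappa = 1 - \mu/L\), sketched on a strongly convex quadratic \(\frac12 z^\top H z - b^\top z\). The quadratic, its dimensions and the random data are assumptions chosen for illustration, since \(\mu\) and \(L\) are then simply the extreme eigenvalues of \(H\).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 4))
H = A.T @ A + 0.5 * np.eye(4)        # Hessian of the quadratic, positive definite
b = rng.standard_normal(4)

eigs = np.linalg.eigvalsh(H)          # eigenvalues in ascending order
mu, L = eigs[0], eigs[-1]             # strong convexity and smoothness constants
kappa = 1.0 - mu / L
z_star = np.linalg.solve(H, b)        # exact minimizer

z = np.zeros(4)
e0 = np.linalg.norm(z - z_star)
errors = []
for t in range(1, 51):
    z = z - (H @ z - b) / L           # gradient step with step size 1/L
    errors.append(np.linalg.norm(z - z_star))

# Compare each error with kappa^t |z_0 - z*| (small slack for rounding).
bound_holds = all(
    e <= kappa ** t * e0 * (1 + 1e-9) + 1e-12
    for t, e in zip(range(1, 51), errors)
)
print(kappa, bound_holds)
```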
Convergence of the Jacobian
\(z_{t+1} = z_t - \frac1L \nabla_1\mathcal{L}(z_t, x)\)
Let \(J_t =\frac{\partial z_t}{\partial x}\) :
\(J_{t+1} =J_t - \frac1L \left(J_t\nabla_{11}\mathcal{L}(z_t, x) + \nabla_{21}\mathcal{L}(z_t, x)\right) \)
[Gilbert, 1992]: \(J_t \to \frac{\partial z^*}{\partial x}\) as \(t \to \infty\)
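The recursion on \(J_t\) can be sketched on a case where the limit is known in closed form. The ridge loss \(\mathcal{L}(z, x) = \frac12\|Dz - x\|^2 + \frac{\lambda}{2}\|z\|^2\) used below is an assumption made for this illustration and is not the example of the slides. For it, \(\nabla_{11}\mathcal{L} = D^\top D + \lambda I\), \(\nabla_{21}\mathcal{L} = -D\), and \(\frac{\partial z^*}{\partial x} = D\,(D^\top D + \lambda I)^{-1}\) in the \(n \times m\) convention.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, lam = 6, 3, 1.0
D = rng.standard_normal((n, m))

H = D.T @ D + lam * np.eye(m)          # nabla_11 L, constant in z for this loss
C = -D                                  # nabla_21 L, shape (n, m)
L_smooth = np.linalg.eigvalsh(H).max()  # smoothness constant w.r.t. z

J = np.zeros((n, m))                    # J_0 = 0 because z_0 = 0 does not depend on x
for _ in range(2000):
    J = J - (J @ H + C) / L_smooth      # J_{t+1} = J_t - (1/L)(J_t nabla_11 L + nabla_21 L)

J_star = D @ np.linalg.inv(H)           # closed-form Jacobian of z* w.r.t. x
print(np.abs(J - J_star).max())
```

Because the Hessian does not depend on \(z\) here, the recursion is linear in \(J_t\) and contracts at the same rate as the iterates. In the general case \(\nabla_{11}\mathcal{L}\) and \(\nabla_{21}\mathcal{L}\) are evaluated at the current \(z_t\).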
Convergence speed of the estimators
In the paper...
- Analysis of the stochastic gradient descent case with step size \(\propto t^{-\alpha}\):
\(|g^1_t - g^*| = O(t^{-\alpha})\), \(|g^2_t - g^*| = O(t^{-2\alpha})\), \(|g^3_t - g^*| = O(t^{-2\alpha})\)
- Consequences on optimization: gradients are wrong, so what?
- Practical guidelines taking time and memory into account
Thanks for your attention!
icml presentation
By Pierre Ablin