Super-efficiency of automatic differentiation for functions defined as a minimum
1: ENS and CNRS, PSL University; 2: Inria and CEA
Context
Sparse coding
Consider \(k\) images \(u_1, \dots, u_k\) in \(\mathbb{R}^d\).
[Figure: the images \(u_1, \dots, u_k\)]
Let \(D=[d_1, \dots, d_m]\) be a set of \(m\) points in \(\mathbb{R}^d\): \(D\) is the dictionary.
[Figure: the dictionary atoms \(d_1, \dots, d_m\)]
Sparse coding (cont'd)
Each image is decomposed over the atoms of \(D\):
$$u_i \simeq z^i_1 d_1 +\cdots + z^i_m d_m$$
[Figure: example decomposition \(u_i \simeq 3.1 \times d_1 - 1.4 \times d_2 + \cdots + 0.6 \times d_m\)]
\(u_i \simeq Dz_i\)
Sparse coding (cont'd)
\(u_i \simeq Dz_i\)
Sparse coding: \(z_i\) should be sparse (many zeros). It is estimated with the Lasso:
$$z_i = \arg\min_z \frac12\|u_i - Dz\|^2 + \lambda\|z\|_1$$
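To make the Lasso step concrete, here is a minimal NumPy sketch using ISTA (proximal gradient descent with soft-thresholding). ISTA is one standard Lasso solver among several; the slides do not say which solver is used, and the sizes, `lam` and the helper names `soft_threshold` and `lasso_ista` are hypothetical.

```python
import numpy as np

def soft_threshold(v, tau):
    """Proximal operator of tau * ||.||_1: shrink every entry toward 0 by tau."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def lasso_ista(u, D, lam, n_iter=5000):
    """Approximately solve min_z 0.5 * ||u - D z||^2 + lam * ||z||_1 with ISTA."""
    step = 1.0 / np.linalg.norm(D, ord=2) ** 2   # 1 / Lipschitz constant of the smooth part
    z = np.zeros(D.shape[1])
    for _ in range(n_iter):
        grad = D.T @ (D @ z - u)                 # gradient of 0.5 * ||u - D z||^2
        z = soft_threshold(z - step * grad, step * lam)
    return z

rng = np.random.default_rng(0)
d, m, lam = 20, 50, 1.0                          # hypothetical sizes and penalty
D = rng.standard_normal((d, m))
D /= np.linalg.norm(D, axis=0)                   # unit-norm atoms d_1, ..., d_m
u = rng.standard_normal(d)                       # one "image" u_i
z = lasso_ista(u, D, lam)
print("non-zero coefficients:", np.count_nonzero(z), "out of", m)
```

The soft-thresholding step is what sets many coefficients exactly to zero, which is the sparsity the Lasso is meant to produce.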
Dictionary learning
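Instead of a fixed \(D\), learn it from the data:
$$\min_D \ell(D), \qquad \ell(D) = \sum_{i=1}^k \min_{z_i} \frac12\|u_i - Dz_i\|^2 + \lambda\|z_i\|_1$$
This is a double minimization problem!
We would like to do gradient descent on \(D\).
How do we compute the gradient of \(\ell\)?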
Computing the gradient of a minimum
- Let \(\mathcal{L}(z, x)\) be a smooth function, where \(z\in \mathbb{R}^m\) and \(x\in\mathbb{R}^n\), and
$$\boxed{\ell(x) = \min_z \mathcal{L}(z, x)}$$
- We want to compute \(\nabla \ell(x)\)
Examples of min-min optimization:
- Dictionary learning
- Alternating optimization
- Fréchet means
Examples of max-min optimization:
- Game theory
- GANs
Technical assumptions
- \(\ell(x) = \min_z \mathcal{L}(z, x)\), we want to compute \(\nabla \ell(x)\)
We assume:
- For all \(x\), \(\mathcal{L}(\cdot, x)\) has a unique minimizer \(z^*(x)=\arg\min_z \mathcal{L}(z, x)\), so that
$$\ell(x) = \mathcal{L}(z^*(x), x)$$
- \(x\to z^*(x) \) is differentiable
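Consequence (Danskin): since \(\nabla_1 \mathcal{L}(z^*(x), x) = 0\), the chain rule gives
$$\nabla \ell(x) = \frac{\partial z^*}{\partial x}\nabla_1 \mathcal{L}(z^*(x), x) + \nabla_2 \mathcal{L}(z^*(x), x) = \nabla_2 \mathcal{L}(z^*(x), x)$$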
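A quick numerical sanity check of Danskin's formula, as a sketch: the hypothetical ridge-type objective \(\mathcal{L}(z, x) = \frac12\|Dz - x\|^2 + \frac{\lambda}{2}\|z\|^2\) has a closed-form minimizer, so \(\nabla_2\mathcal{L}(z^*(x), x) = x - Dz^*(x)\) can be compared with central finite differences of \(\ell\). The matrix `D` (unrelated to the dictionary), `lam` and the sizes are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, lam = 5, 3, 0.1                 # hypothetical sizes and regularization
D = rng.standard_normal((n, m))

def loss(z, x):
    # Hypothetical L(z, x) = 0.5 ||D z - x||^2 + 0.5 lam ||z||^2
    return 0.5 * np.sum((D @ z - x) ** 2) + 0.5 * lam * np.sum(z ** 2)

def z_star(x):
    # Closed-form minimizer: (D^T D + lam I) z = D^T x
    return np.linalg.solve(D.T @ D + lam * np.eye(m), D.T @ x)

def ell(x):
    return loss(z_star(x), x)

def grad2_loss(z, x):
    # Partial gradient of L with respect to x
    return x - D @ z

x = rng.standard_normal(n)
danskin = grad2_loss(z_star(x), x)    # nabla_2 L(z*(x), x)

# Central finite differences of ell, one coordinate of x at a time
eps = 1e-6
fd = np.array([(ell(x + eps * e) - ell(x - eps * e)) / (2 * eps) for e in np.eye(n)])
print(np.max(np.abs(danskin - fd)))
```

Because \(\ell\) is quadratic in this example, central differences are exact up to rounding, so the two vectors should agree to many digits without ever differentiating through \(z^*\).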
Approximate optimization
- \(\ell(x) = \min_z \mathcal{L}(z, x)\), we want to compute \(\nabla \ell(x)\)
- No closed form: we assume that we only have access to a sequence \(z_t\) such that \(z_t \to z^*\)
Example: \(z_t\) produced by gradient descent on \(\mathcal{L}\)
- \(z_0 = 0\)
- \(z_{t+1} = z_t - \rho \nabla_1 \mathcal{L}(z_t, x)\)
- \(z_t\) depends on \(x\): \(z_t(x)\)
How can we approximate \(\nabla \ell(x)\) using \(z_t(x)\)? At what speed?
The analytic estimator
Very simple to compute: plug \(z_t\) into \(\nabla_2 \mathcal{L}\)!
$$g^1_t(x) = \nabla_2 \mathcal{L}(z_t(x), x)$$
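As a minimal sketch of \(g^1_t\), take the hypothetical scalar problem \(\mathcal{L}(z, x) = \frac12(z - x)^2 + \frac14 z^4\), which is strongly convex in \(z\), so that \(\nabla_1\mathcal{L} = z - x + z^3\) and \(\nabla_2\mathcal{L} = x - z\). The step `rho = 0.2`, the starting point \(z_0 = 0\) and the Newton solve used for the reference \(z^*\) are illustrative choices, not taken from the slides.

```python
# Hypothetical scalar toy problem: L(z, x) = 0.5 * (z - x)**2 + 0.25 * z**4
def grad1(z, x):          # nabla_1 L (derivative in z)
    return z - x + z ** 3

def grad2(z, x):          # nabla_2 L (derivative in x)
    return x - z

def z_star(x, n_newton=50):
    """Reference minimizer: root of z**3 + z - x = 0, found by Newton's method."""
    z = 0.0
    for _ in range(n_newton):
        z -= (z ** 3 + z - x) / (3 * z ** 2 + 1)
    return z

x, rho, n_iter = 1.0, 0.2, 30
zs = z_star(x)
g_star = grad2(zs, x)     # Danskin: grad ell(x) = nabla_2 L(z*(x), x)

z, err_z, err_g1 = 0.0, [], []
for t in range(n_iter + 1):
    g1 = grad2(z, x)      # analytic estimator g^1_t: plug z_t into nabla_2 L
    err_z.append(abs(z - zs))
    err_g1.append(abs(g1 - g_star))
    z = z - rho * grad1(z, x)   # gradient descent step on z

print(err_z[-1], err_g1[-1])
```

Here \(\nabla_2\mathcal{L} = x - z\) is 1-Lipschitz in \(z\) with equality, so the error of \(g^1_t\) is exactly \(|z_t - z^*|\): the analytic estimator is only as accurate as the iterate it is fed.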
The automatic estimator
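\(g^2_t\) is computed by automatic differentiation of \(x \mapsto \mathcal{L}(z_t(x), x)\):
$$g^2_t(x) = \frac{\partial z_t}{\partial x} \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)$$
- Easy to code (e.g. in Python, use PyTorch, TensorFlow, autograd...)
- About as costly to compute as \(z_t\)
- Memory cost grows linearly with the number of iterations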
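To see what automatic differentiation actually computes, the sketch below differentiates through the gradient descent iterations of the hypothetical scalar problem \(\mathcal{L}(z, x) = \frac12(z - x)^2 + \frac14 z^4\) with a tiny forward-mode AD based on dual numbers, instead of PyTorch or autograd, so that it runs with no dependency. The `Dual` class, the step `rho = 0.2` and \(x = 1\) are illustrative assumptions. Forward mode does not store the iterates (the linear memory cost comes from reverse mode); its cost grows with the dimension of \(x\) instead, which is harmless here since \(x\) is a scalar.

```python
class Dual:
    """Minimal forward-mode AD number val + dot * eps with eps**2 = 0 (demo only)."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

    def __pow__(self, k):  # integer exponent k >= 1
        return Dual(self.val ** k, k * self.val ** (k - 1) * self.dot)


def loss(z, x):   # hypothetical L(z, x) = 0.5 (z - x)^2 + 0.25 z^4
    return 0.5 * (z - x) ** 2 + 0.25 * z ** 4

def grad1(z, x):  # nabla_1 L, written with overloaded operators so Dual inputs work
    return z - x + z ** 3

def z_star(x, n_newton=50):
    """Reference minimizer: root of z**3 + z - x = 0 by Newton's method."""
    z = 0.0
    for _ in range(n_newton):
        z -= (z ** 3 + z - x) / (3 * z ** 2 + 1)
    return z

x0, rho = 1.0, 0.2
g_star = x0 - z_star(x0)        # exact gradient nabla_2 L(z*, x) = x - z*
x = Dual(x0, 1.0)               # seed: dx/dx = 1
z = Dual(0.0)                   # z_0 = 0 does not depend on x
err_g1, err_g2 = [], []
for t in range(31):
    g2 = loss(z, x).dot         # automatic estimator: d/dx of L(z_t(x), x)
    g1 = x0 - z.val             # analytic estimator: nabla_2 L(z_t, x)
    err_g1.append(abs(g1 - g_star))
    err_g2.append(abs(g2 - g_star))
    z = z - rho * grad1(z, x)   # the GD step is differentiated along with z
```

With these settings, `err_g2[t]` falls far below `err_g1[t]` once \(z_t\) is close to \(z^*\), even though both are computed from the same iterate.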
The implicit estimator
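Implicit function theorem: \(\frac{\partial z^*}{\partial x} = \mathcal{J}(z^*(x), x)\), where
$$\mathcal{J}(z, x) = -\nabla_{21}\mathcal{L}(z, x)\left[\nabla_{11}\mathcal{L}(z, x)\right]^{-1}$$
Plugging in \(z_t\): \(g^3_t(x) = \mathcal{J}(z_t(x), x) \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)\)
- Requires solving an \(m\times m\) linear system, which might be too costly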
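A minimal sketch of \(g^3_t\) on the hypothetical scalar problem \(\mathcal{L}(z, x) = \frac12(z - x)^2 + \frac14 z^4\): there \(\nabla_{11}\mathcal{L} = 1 + 3z^2\) and \(\nabla_{21}\mathcal{L} = -1\), so \(\mathcal{J}(z, x) = 1/(1 + 3z^2)\) and the "linear system" reduces to a division. The problem, step size and iteration count are illustrative assumptions.

```python
# Hypothetical scalar toy problem: L(z, x) = 0.5 * (z - x)**2 + 0.25 * z**4
def grad1(z, x):          # nabla_1 L
    return z - x + z ** 3

def grad2(z, x):          # nabla_2 L
    return x - z

def implicit_jac(z, x):
    # J(z, x) = -nabla_21 L / nabla_11 L, with nabla_21 L = -1 and nabla_11 L = 1 + 3 z**2
    return 1.0 / (1.0 + 3.0 * z ** 2)

def z_star(x, n_newton=50):
    """Reference minimizer: root of z**3 + z - x = 0 by Newton's method."""
    z = 0.0
    for _ in range(n_newton):
        z -= (z ** 3 + z - x) / (3 * z ** 2 + 1)
    return z

x, rho = 1.0, 0.2
zs = z_star(x)
g_star = grad2(zs, x)     # exact gradient of ell at x

z, err_z, err_g3 = 0.0, [], []
for t in range(30):
    g3 = implicit_jac(z, x) * grad1(z, x) + grad2(z, x)   # implicit estimator g^3_t
    err_z.append(abs(z - zs))
    err_g3.append(abs(g3 - g_star))
    z = z - rho * grad1(z, x)                             # gradient descent step
```

In this example the recorded errors stay below \(2|z_t - z^*|^2\) along the whole run: the error of \(g^3_t\) depends quadratically on the error of \(z_t\).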
Three gradient estimators
- Analytic: \(g^1_t(x) = \nabla_2 \mathcal{L}(z_t(x), x)\)
- Automatic: \(g^2_t(x) = \frac{\partial z_t}{\partial x} \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)\)
- Implicit: \(g^3_t(x) = \mathcal{J}(z_t(x), x) \nabla_1\mathcal{L}(z_t(x), x) + \nabla_2 \mathcal{L}(z_t(x), x)\)
These estimators are all "consistent":
If \(z_t(x)= z^*(x)\), then \(g^1_t(x) = g^2_t(x) = g^3_t(x) = \nabla \ell(x)\)
Convergence speed of the estimators
Toy example (regularized logistic regression)
- \(D \in \mathbb{R}^{n\times m}\) a rectangular matrix
- \(\mathcal{L}(z, x) = \sum_{i=1}^n\log\left(1 + \exp(-x_i[Dz]_i)\right) + \frac{\lambda}{2}\|z\|^2\)
- \(z_t\) produced by gradient descent
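The toy problem above can be sketched in NumPy as follows, comparing the three estimators along the gradient descent iterates. The sizes, the random `D` and `x`, `lam = 1.0`, the step \(1/L\) computed from the bound \(\sigma(a)\sigma(-a) \le 1/4\), and the helper `derivatives` are assumptions made for illustration; the automatic estimator is obtained by propagating \(J_t = \frac{\partial z_t}{\partial x}\) in forward mode rather than with an AD library.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

rng = np.random.default_rng(0)
n, m, lam = 20, 10, 1.0                        # hypothetical sizes and regularization
D = rng.standard_normal((n, m)) / np.sqrt(n)
x = rng.standard_normal(n)

def derivatives(z):
    """Return nabla_1 L, nabla_2 L, nabla_11 L (m x m) and nabla_21 L (n x m) at (z, x)."""
    y = D @ z
    a = x * y                                  # a_i = x_i [Dz]_i
    s_neg, s_pos = sigmoid(-a), sigmoid(a)
    w = s_pos * s_neg                          # second derivative of log(1 + exp(-a))
    d1 = -D.T @ (x * s_neg) + lam * z
    d2 = -y * s_neg
    d11 = D.T @ (D * (x ** 2 * w)[:, None]) + lam * np.eye(m)
    d21 = (a * w - s_neg)[:, None] * D         # entry (j, k) = d^2 L / dx_j dz_k
    return d1, d2, d11, d21

rho = 1.0 / (lam + np.linalg.norm(D, 2) ** 2 * np.max(x ** 2) / 4)   # step 1 / L

# Reference z*: many gradient steps (linear convergence), then Danskin for g*
z_star = np.zeros(m)
for _ in range(2000):
    z_star = z_star - rho * derivatives(z_star)[0]
g_star = derivatives(z_star)[1]

z, J = np.zeros(m), np.zeros((n, m))           # z_0 = 0 does not depend on x, so J_0 = 0
err1, err2, err3 = [], [], []
for t in range(200):
    d1, d2, d11, d21 = derivatives(z)
    g1 = d2                                    # analytic estimator
    g2 = J @ d1 + d2                           # automatic estimator (forward-mode J_t)
    g3 = -d21 @ np.linalg.solve(d11, d1) + d2  # implicit estimator
    err1.append(np.linalg.norm(g1 - g_star))
    err2.append(np.linalg.norm(g2 - g_star))
    err3.append(np.linalg.norm(g3 - g_star))
    J = J - rho * (J @ d11 + d21)              # Jacobian recursion, evaluated at z_t
    z = z - rho * d1                           # gradient step

for t in (5, 10, 20, 40):
    print(f"t={t:3d}  analytic={err1[t]:.1e}  automatic={err2[t]:.1e}  implicit={err3[t]:.1e}")
```

The analytic error decays like \(|z_t - z^*|\), while the automatic and implicit errors should decay noticeably faster once \(z_t\) is close to \(z^*\).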
Analytic estimator
If \(\nabla_2 \mathcal{L}\) is \(L\)-Lipschitz w.r.t. \(z\): \(|g^1_t(x) - g^*| \le L|z_t - z^*|\), where \(g^* = \nabla \ell(x)\)
- The analytic estimator converges at the same speed as \(z_t\)
Automatic estimator
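It is of the form \(g^2_t(x) = J \nabla_1 \mathcal{L}(z_t, x) + \nabla_2 \mathcal{L}(z_t, x)\), where \(J = \frac{\partial z_t}{\partial x}\) is a rectangular \(n\times m\) matrix
Taylor expansion around \(z^*\), using \(\nabla_1 \mathcal{L}(z^*, x) = 0\) and \(\nabla \ell(x) = \nabla_2 \mathcal{L}(z^*, x)\):
$$g^2_t(x) - \nabla \ell(x) = R(J)(z_t - z^*) + O(|z_t - z^*|^2), \qquad R(J) = J\nabla_{11}\mathcal{L}(z^*, x) + \nabla_{21}\mathcal{L}(z^*, x)$$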
Automatic estimator
Implicit function theorem: \(R(J) = J\nabla_{11}\mathcal{L}(z^*, x) + \nabla_{21}\mathcal{L}(z^*, x)\) vanishes when \(J = \frac{\partial z^*}{\partial x}\)
If \(\frac{\partial z_t}{\partial x} \to \frac{\partial z^*}{\partial x}\), then \(R\left(\frac{\partial z_t}{\partial x}\right) \to 0\) and:
$$|g^2_t(x) - \nabla \ell(x)| = o(|z_t - z^*|)$$
Super-efficiency: strictly faster than \(g^1_t\)!
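Implicit estimator
If \(\mathcal{J}\) is Lipschitz w.r.t. \(z\), then, since \(\mathcal{J}(z^*, x) = \frac{\partial z^*}{\partial x}\) cancels the first-order term:
$$|g^3_t(x) - \nabla \ell(x)| = O(|z_t - z^*|^2)$$
Twice as fast as \(g^1_t\) (the exponent doubles)!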
Example
Gradient descent in the strongly convex case
Assumptions and reminders
- We assume that \(\mathcal{L}\) is \(\mu\)-strongly convex and \(L\)-smooth w.r.t. \(z\) for all \(x\).
- \(z_t\) produced by gradient descent with step \(1/L\):
\(z_{t+1} = z_t - \frac1L \nabla_1\mathcal{L}(z_t, x)\)
Linear convergence: \(|z_t - z^*| \le \left(1-\frac{\mu}{L}\right)^t |z_0 - z^*|\)
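A quick check of this rate, as a sketch on a hypothetical quadratic \(\mathcal{L}(z) = \frac12 z^\top A z - b^\top z\) whose Hessian \(A\) has its spectrum in \([\mu, L]\) (the dependence on \(x\) is dropped since only the iterates matter here). For a quadratic, \(z_{t+1} - z^* = (I - A/L)(z_t - z^*)\) and \(\|I - A/L\| = 1 - \mu/L\), so the bound holds at every step. The values of `m`, `mu`, `L` and the random data are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
m, mu, L = 10, 0.5, 4.0                        # hypothetical dimension and constants
# Symmetric A with eigenvalues spread over [mu, L]
Q, _ = np.linalg.qr(rng.standard_normal((m, m)))
A = Q @ np.diag(np.linspace(mu, L, m)) @ Q.T
b = rng.standard_normal(m)
z_star = np.linalg.solve(A, b)

z = np.zeros(m)                                # z_0 = 0
ratios = []
for t in range(1, 51):
    z = z - (A @ z - b) / L                    # gradient step with step size 1/L
    bound = (1 - mu / L) ** t * np.linalg.norm(z_star)   # (1 - mu/L)^t |z_0 - z*|
    ratios.append(np.linalg.norm(z - z_star) / bound)

print(max(ratios))                             # stays <= 1 if the bound holds
```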
Convergence of the Jacobian
\(z_{t+1} = z_t - \frac1L \nabla_1\mathcal{L}(z_t, x)\)
Let \(J_t =\frac{\partial z_t}{\partial x}\); differentiating the iteration gives:
\(J_{t+1} =J_t - \frac1L \left(J_t\nabla_{11}\mathcal{L}(z_t, x) + \nabla_{21}\mathcal{L}(z_t, x)\right) \)
[Gilbert, 1992]: \(J_t \to \frac{\partial z^*}{\partial x}\) at a linear rate
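The recursion can be checked on a hypothetical quadratic \(\mathcal{L}(z, x) = \frac12 z^\top A z - x^\top B z\), for which \(\nabla_{11}\mathcal{L} = A\), \(\nabla_{21}\mathcal{L} = -B\) and \(\frac{\partial z^*}{\partial x} = BA^{-1}\). The sketch compares \(J_T\) with a finite-difference Jacobian of \(x \mapsto z_T(x)\) and with \(\frac{\partial z^*}{\partial x}\). Because the Hessians are constant here, the error contracts exactly: \(J_t - \frac{\partial z^*}{\partial x} = -\frac{\partial z^*}{\partial x}(I - A/L)^t\). The sizes, constants and the helper `run` are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, mu, L = 4, 6, 0.5, 4.0                   # hypothetical sizes and constants
# Hypothetical L(z, x) = 0.5 z^T A z - x^T B z:
# nabla_1 L = A z - B^T x, nabla_11 L = A, nabla_21 L = -B (an n x m matrix)
Q, _ = np.linalg.qr(rng.standard_normal((m, m)))
A = Q @ np.diag(np.linspace(mu, L, m)) @ Q.T
B = rng.standard_normal((n, m))
J_star = B @ np.linalg.inv(A)                  # dz*/dx = -nabla_21 L (nabla_11 L)^{-1}

def run(x, T):
    """T gradient steps with step 1/L, propagating J_t = dz_t/dx alongside z_t."""
    z, J = np.zeros(m), np.zeros((n, m))
    for _ in range(T):
        J = J - (J @ A - B) / L                # J_{t+1} = J_t - (J_t nabla_11 L + nabla_21 L) / L
        z = z - (A @ z - B.T @ x) / L
    return z, J

x, T, eps = rng.standard_normal(n), 30, 1e-6
z_T, J_T = run(x, T)
# Finite-difference Jacobian of x -> z_T(x), arranged as an n x m matrix like J_T
J_fd = np.array([(run(x + eps * e, T)[0] - run(x - eps * e, T)[0]) / (2 * eps)
                 for e in np.eye(n)])
print(np.max(np.abs(J_fd - J_T)))
print(np.linalg.norm(J_T - J_star, 2), (1 - mu / L) ** T * np.linalg.norm(J_star, 2))
```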
Convergence speed of the estimators
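Combining the per-estimator bounds with the linear rate of \(z_t\) gives the following sketch. Write \(\kappa = 1 - \mu/L\) and \(g^* = \nabla\ell(x)\), use \(|z_t - z^*| = O(\kappa^t)\), and assume (beyond what the slides state) Lipschitz second derivatives of \(\mathcal{L}\), under which a perturbation argument on the Jacobian recursion gives \(\|J_t - \frac{\partial z^*}{\partial x}\| = O(t\kappa^t)\). Since \(R(\frac{\partial z^*}{\partial x}) = 0\), \(\|R(J_t)\| \le L\,\|J_t - \frac{\partial z^*}{\partial x}\|\).

$$
\begin{aligned}
|g^1_t - g^*| &= O(|z_t - z^*|) = O(\kappa^t)\\
|g^2_t - g^*| &\le \|R(J_t)\|\,|z_t - z^*| + O(|z_t - z^*|^2) = O(t\kappa^t)\,O(\kappa^t) + O(\kappa^{2t}) = O(t\kappa^{2t})\\
|g^3_t - g^*| &= O(|z_t - z^*|^2) = O(\kappa^{2t})
\end{aligned}
$$

Up to the factor \(t\), the automatic estimator thus matches the squared rate of the implicit one without solving any linear system.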
In the paper...
- Analysis of the stochastic gradient descent case with step size \(\propto t^{-\alpha}\):
\(|g^1_t - g^*| = O(t^{-\alpha})\), \(|g^2_t - g^*| = O(t^{-2\alpha})\), \(|g^3_t - g^*| = O(t^{-2\alpha})\)
- Consequences for optimization: the gradients are wrong, so what?
- Practical guidelines taking time and memory into account
Thanks for your attention!
Super-efficiency autodiff
By Pierre Ablin