Dyego Araújo
Roberto I. Oliveira
Daniel Yukimura
(50% sure we can drop the i.i.d. assumption.)
Independent random initialization of weights
Layers \(\ell=0,1,\dots,L,L+1\):
Each layer \(\ell\) has \(N_\ell\) units of dimension \(d_\ell\).
"Weight" \(\theta^{(\ell)}_{i_\ell,i_{\ell+1}}\in\mathbb{R}^{D_\ell}\):
Full vector of weights: \(\vec{\theta}_N\).
Initialization:
At output layer \(L+1\), \(N_{L+1}=1\).
Weights close to input \(A_{i_1}:=\theta^{(0)}_{1,i_1}\) and output \(B_{i_L}:=\theta^{(L)}_{i_L,1}\) are not updated ("random features").
All other weights \(\theta^{(\ell)}_{i_\ell,i_{\ell+1}}\) (\(1\leq \ell\leq L-1\)) are trained by SGD with step size \(\epsilon\), using fresh samples at each iteration.
For \(\epsilon\ll 1\) and \(N\gg 1\), after a change of time scale, this is approximately the same as:
\[\frac{d}{dt}\vec{\theta}_N(t) = -N^2\, \nabla L_N(\vec{\theta}_N(t))\]
(\(N^2\) steps of SGD \(\sim 1\) time unit in the limit)
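A toy sketch of the small-step part of this approximation, assuming a hypothetical one-dimensional loss \(\theta^2/2\) whose gradient is observed with fresh unit-variance Gaussian noise at every step (the \(N^2\) rescaling is not modelled here). Running \(t/\epsilon\) SGD steps of size \(\epsilon\) should land close to the gradient-flow solution \(\theta_0e^{-t}\).

```python
import math
import random


def sgd_toy(theta0: float, eps: float, t: float, seed: int = 0) -> float:
    """Run round(t / eps) SGD steps on the hypothetical loss theta**2 / 2.

    Each step uses a fresh noisy gradient theta + xi with xi ~ N(0, 1),
    mimicking a fresh sample per iteration.
    """
    rng = random.Random(seed)
    theta = theta0
    for _ in range(round(t / eps)):
        noisy_grad = theta + rng.gauss(0.0, 1.0)
        theta -= eps * noisy_grad
    return theta


def gradient_flow(theta0: float, t: float) -> float:
    """Exact solution of d/dt theta = -theta, the noiseless limit."""
    return theta0 * math.exp(-t)


for eps in (1e-2, 1e-3, 1e-4):
    gap = abs(sgd_toy(1.0, eps, 1.0) - gradient_flow(1.0, 1.0))
    print(f"eps={eps:g}: |SGD - flow| = {gap:.4f}")
```

In this toy the gap is dominated by accumulated gradient noise, whose standard deviation scales roughly like \(\sqrt{\epsilon}\); the Euler bias is only \(O(\epsilon)\).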
Assume \(L\geq 3\) hidden layers + technicalities.
One can couple the weights \(\theta^{(\ell)}_{i_\ell,i_{\ell+1}}(t)\) to certain limiting random variables \(\overline{\theta}^{(\ell)}_{i_\ell,i_{\ell+1}}(t)\) with small error:
Limiting random variables coincide with the weights at time 0. They also satisfy a series of properties we will describe below.
Full independence except at 1st and Lth hidden layers:
The following are all independent from one another.
1st and Lth hidden layers are deterministic functions:
Distribution \(\mu_t\) of limiting weights along a path at time \(t\),
\[(A_{i_1},\overline{\theta}^{(1)}_{i_1,i_2}(t),\overline{\theta}^{(2)}_{i_2,i_3}(t),\dots,\overline{\theta}^{(L)}_{i_L,i_{L+1}}(t),B_{i_{L+1}})\sim \mu_t\]
has the following factorization into independent components:
\[\mu_t = \mu^{(0,1)}_t\otimes \mu^{(2)}_t \otimes \dots \otimes \mu_{t}^{(L-1)}\otimes \mu_t^{(L,L+1)}.\]
Contrast with time 0, when the law is a full product: \(\mu_0=\mu^{(0)}_0\otimes\mu^{(1)}_0\otimes\dots\otimes\mu^{(L+1)}_0\).
At any time \(t\), the network function \(x\mapsto\widehat{y}(x,\vec{\theta}_N(t))\) is approximately a composition of functions of the generic form \(\int\,h(x,\theta)\,dP(\theta)\), and the loss is approximately the loss of this composition. Specifically,
\[L_N(\vec{\theta}_N(t))\approx \frac{1}{2}\mathbb{E}_{(X,Y)\sim P }\,\|Y - \overline{y}(X,t)\|^2,\] where \(\overline{y}(x,t)\) denotes this limiting composition.
Ansatz:
\(\Rightarrow\) McKean-Vlasov-type behavior.
Consider:
Self-consistency: \(Z(t)\sim \mu_t\) for all times \(t\geq 0\)
Under reasonable conditions, for any initial measure \(\mu_0\) there exists a unique trajectory
\[t\geq 0\mapsto \mu_t\in M_1(\mathbb{R}^D)\]
such that any random process \(Z\) satisfying
satisfies the McKean-Vlasov consistency property:
\[Z(t)\sim \mu_t, t\geq 0.\]
[McKean'1966, Gärtner'1988, Sznitman'1991, Rachev-Ruschendorf'1998,...]
Density \(p(t,x)\) of \(\mu_t\) evolves according to a (possibly) nonlinear PDE.
The nonlinearity comes from the dependence of the dynamics on \(\mu_t\), whose density is \(p(t,\cdot)\) itself.
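For instance, assuming noiseless dynamics \(\frac{d}{dt}Z(t)=\psi(Z(t),\mu_t)\) for a generic drift \(\psi(z,\mu)\), the density solves a continuity equation in which \(\mu_t\) enters the velocity field:
\[\partial_t p(t,x) + \nabla_x\cdot\big(p(t,x)\,\psi(x,\mu_t)\big) = 0,\qquad d\mu_t(x) = p(t,x)\,dx.\]
If the particle dynamics also carried additive noise, a Laplacian term would appear; either way the equation is nonlinear in \(p\) because the velocity field depends on \(\mu_t\).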
Meta-theorem:
Consider a system of \(N\gg 1\) evolving particles with mean field interactions.
If no single particle "prevails", then the thermodynamic limit is a McKean-Vlasov process.
Important example:
Mei-Montanari-Nguyen for networks with \(L=1\).
Function computed by network:
\[\widehat{y}(x,\vec{\theta}_N) = \frac{1}{N}\sum_{i=1}^N\sigma_*(x,\theta_i).\]
Loss:
\[L_N(\vec{\theta}_N):=\frac{1}{2}\mathbb{E}_{(X,Y)\sim P}(Y-\widehat{y}(X,\vec{\theta}_N))^2\]
Evolution of particle \(i\) on the right time scale:
where
\[\widehat{\mu}_{N,t}:= \frac{1}{N}\sum_{i=1}^N \delta_{\theta_i(t)}.\]
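The particle dynamics can be made explicit from the loss \(L_N\). Assuming plain (noiseless) gradient flow on the population loss, sped up by a factor \(N\), and omitting any regularization or noise terms, differentiating \(L_N\) in \(\theta_i\) gives a drift that depends on the other particles only through \(\widehat{\mu}_{N,t}\):
\[\nabla_{\theta_i}L_N(\vec{\theta}_N) = -\frac{1}{N}\,\mathbb{E}_{(X,Y)\sim P}\Big[\big(Y-\widehat{y}(X,\vec{\theta}_N)\big)\,\nabla_\theta\sigma_*(X,\theta_i)\Big],\]
\[\frac{d}{dt}\theta_i(t) = -N\,\nabla_{\theta_i}L_N(\vec{\theta}_N(t)) = \psi\big(\theta_i(t),\widehat{\mu}_{N,t}\big),\qquad \psi(\theta,\mu) := \mathbb{E}_{(X,Y)\sim P}\Big[\Big(Y-\int\sigma_*(X,\theta')\,d\mu(\theta')\Big)\nabla_\theta\sigma_*(X,\theta)\Big].\]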
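Assume a law of large numbers (LLN) holds, so that a deterministic limiting measure \[\widehat{\mu}_{N,t}\to \mu_t\] emerges for large \(N\). Then particles should follow:
In particular, since each particle then follows its own ODE driven by the deterministic \(\mu_t\), independence at time 0 is preserved at all times. There is also a limiting self-consistency: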
Typical \(\theta_i(t)\sim \widehat{\mu}_{N,t}\approx \mu_t\).
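A small simulation of these mean-field dynamics, as a sketch under assumptions that do not come from the slides: the hypothetical unit \(\sigma_*(x,\theta)=a\tanh(wx)\) with \(\theta=(a,w)\), a fixed grid of 20 inputs on \([-2,2]\) with target \(\sin x\) in place of the population expectation, and explicit Euler steps for \(\frac{d}{dt}\theta_i=-N\nabla_{\theta_i}L_N\).

```python
import math
import random


def mean_field_two_layer(n_particles=100, n_steps=400, dt=0.05, seed=0):
    """Euler scheme for d/dt theta_i = -N * grad_{theta_i} L_N on a toy problem.

    Hypothetical unit sigma_*(x, (a, w)) = a * tanh(w * x); the network output
    is the average over particles, i.e. an integral against the empirical measure.
    Returns the loss recorded before each step and after the last one.
    """
    rng = random.Random(seed)
    xs = [-2.0 + 4.0 * k / 19 for k in range(20)]  # fixed inputs standing in for P
    ys = [math.sin(x) for x in xs]                 # hypothetical target
    a = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]
    w = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]

    def residuals():
        # r_k = y_k - yhat(x_k), with yhat(x) = (1/N) * sum_i a_i * tanh(w_i * x)
        return [y - sum(ai * math.tanh(wi * x) for ai, wi in zip(a, w)) / n_particles
                for x, y in zip(xs, ys)]

    def loss(r):
        return 0.5 * sum(rk * rk for rk in r) / len(r)

    losses = []
    for _ in range(n_steps):
        r = residuals()  # residuals are frozen for the whole Euler step
        losses.append(loss(r))
        for i in range(n_particles):
            # Drift psi(theta_i, empirical measure) = E[r * grad_theta sigma_*(X, theta_i)]
            da = dw = 0.0
            for x, rk in zip(xs, r):
                th = math.tanh(w[i] * x)
                da += rk * th
                dw += rk * a[i] * x * (1.0 - th * th)
            a[i] += dt * da / len(xs)
            w[i] += dt * dw / len(xs)
    losses.append(loss(residuals()))
    return losses


losses = mean_field_two_layer()
print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

Each particle only sees the others through the shared residuals, which is the empirical-measure interaction described above.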
Let \(\mu_t,t\geq 0\) be the McK-V solution to
\[Z(0)\sim \mu_0, \frac{d}{dt}Z(t) = \psi(Z(t),\mu_t),Z(t)\sim \mu_t.\]
Now consider trajectories \(\overline{\theta}_i[0,T]\) of the form:
\[\overline{\theta}_i(0):=\theta_i(0),\; \frac{d}{dt}\overline{\theta}_i(t) = \psi(\overline{\theta}_i(t),\mu_t).\]
Then, for all \(t\in[0,T]\):
\[\mathbb{E}|\theta_i(t) - \overline{\theta}_i(t)|\leq Ce^{CT}\,(\epsilon + (D/N)^{1/2}).\]
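A toy check of this coupling, assuming a hypothetical drift \(\psi(z,\mu)=\int z'\,d\mu(z')-z\) that pulls each particle toward the mean. For this drift the McK-V solution is explicit: the mean \(m_0\) of \(\mu_0\) is conserved and \(\overline{\theta}_i(t)=m_0+(\theta_i(0)-m_0)e^{-t}\). The \(N\)-particle system, with \(\widehat{\mu}_{N,t}\) in place of \(\mu_t\), is integrated with Euler steps of size \(\epsilon\), so the gap should be of order \(\epsilon+N^{-1/2}\).

```python
import math
import random


def coupling_gap(n, eps, t_max=1.0, m0=0.0, seed=0):
    """Average |theta_i(T) - bar_theta_i(T)| for the hypothetical drift psi(z, mu) = mean(mu) - z."""
    rng = random.Random(seed)
    theta0 = [rng.gauss(m0, 1.0) for _ in range(n)]  # i.i.d. initial weights from mu_0 = N(m0, 1)
    theta = list(theta0)
    steps = round(t_max / eps)
    for _ in range(steps):
        emp_mean = sum(theta) / n                          # mean of the empirical measure
        theta = [z + eps * (emp_mean - z) for z in theta]  # Euler step of the N-particle system
    t = steps * eps
    # Limiting trajectories: same initial values, exact flow driven by the deterministic mu_t
    theta_bar = [m0 + (z0 - m0) * math.exp(-t) for z0 in theta0]
    return sum(abs(x - y) for x, y in zip(theta, theta_bar)) / n


for n in (100, 10_000):
    print(n, round(coupling_gap(n, eps=1e-2), 4))
```

Here the finite-\(N\) error is exactly \((\widehat{m}_N-m_0)(1-e^{-t})\) plus the Euler error, where \(\widehat{m}_N\) is the empirical mean at time 0.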
At any time \(t\geq 0\),
\[|L_N(\vec{\theta}_N(t)) - L(\mu_t)| \leq C_t\,(\epsilon + N^{-1/2}),\]
where, for \(\mu\in M_1(\mathbb{R}^D)\),
\[L(\mu) = \frac{1}{2}\mathbb{E}_{(X,Y)\sim P}\left(Y - \int_{\mathbb{R}^D}\,\sigma_*(X,\theta)\,d\mu(\theta)\right)^2.\]
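One step worth making explicit: the network output is an integral against the empirical measure, so the finite-\(N\) loss is exactly \(L\) evaluated at \(\widehat{\mu}_{N,t}\), and the bound compares \(L(\widehat{\mu}_{N,t})\) with \(L(\mu_t)\):
\[\widehat{y}(x,\vec{\theta}_N(t)) = \frac{1}{N}\sum_{i=1}^N\sigma_*(x,\theta_i(t)) = \int\sigma_*(x,\theta)\,d\widehat{\mu}_{N,t}(\theta)\quad\Longrightarrow\quad L_N(\vec{\theta}_N(t)) = L(\widehat{\mu}_{N,t}).\]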
The abstract McK-V framework is still useful. However, the dependencies in our system are trickier.
Recursion at time t:
Recursion at time t for \(\ell=2,3,\dots,L+1\):
Backwards recursion:
Recall \(\widehat{y}(x,\vec{\theta}_N)=\sigma^{(L+1)}(z^{(L)}_1(x,\vec{\theta}_N))\). Define:
Limiting version:
For \(\ell\leq L-2\), define:
For \(2\leq \ell\leq L-2\), backprop formula:
is replaced in the evolution of \(\overline{\theta}^{(\ell)}_{i_\ell,i_{\ell+1}}\) by:
\[\overline{M}^{(\ell+1)}(x,t)\,D_\theta\sigma^{(\ell)}_*(\overline{z}^{(\ell)}(x,t),\overline{\theta}^{(\ell)}_{i_\ell,i_{\ell+1}}(t))\]
For \(2\leq \ell\leq L-2\), the time derivative of \(\overline{\theta}^{(\ell)}_{i_\ell,i_{\ell+1}}(t)\) should take the form
where the \(\overline{z}\) are determined by the marginals:
Find the unique McK-V process \((\mu_t)_{t\geq 0}\) that should correspond to the weights on a path from input to output:
\[\mathrm{Law}\big(A_{i_1},\theta^{(1)}_{i_1,i_2}(t),\dots,\theta^{(L-1)}_{i_{L-1},i_{L}}(t),B_{i_{L}}\big)\approx \mu_t.\]
Prove that the limiting trajectories have the right dependency structure:
\[\mu_{[0,T]} = \mu^{(0,1)}_{[0,T]}\otimes \mu^{(2)}_{[0,T]} \otimes \dots \otimes \mu_{[0,T]}^{(L-1)}\otimes \mu_{[0,T]}^{(L,L+1)}.\]
Populate the network with limiting weight trajectories:
1. I.i.d. part:
Generate independent random variables, i.i.d. within each layer: \[A_{i_1}\sim \mu_0^{(0)},\quad B_{i_L}\sim \mu^{(L+1)}_0,\quad \overline{\theta}^{(\ell)}_{i_\ell,i_{\ell+1}}([0,T])\sim \mu^{(\ell)}_{[0,T]}\;(2\leq \ell\leq L-2).\]
2. Conditionally independent part:
For each pair \(i_1,i_2\), independently given the \(A\)'s, generate \(\overline{\theta}^{(1)}_{i_1,i_2}([0,T])\) from the conditional law under \(\mu^{(0,1)}_{[0,T]}\) given the weight \(A_{i_1}\). Similarly for the weights \(\overline{\theta}^{(L-1)}_{i_{L-1},i_L}\), conditioning on \(B_{i_L}\).
By construction, the limiting weights coincide with the true weights at time 0. From the McK-V drift conditions, one can check via Gronwall's inequality that the limiting and true weights remain close.
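A sketch of this two-part sampling scheme for a single time slice, with loud assumptions: scalar weights, standard Gaussian laws for \(A\), \(B\) and the middle layers, and a hypothetical conditional law \(N(A_{i_1},1)\) for the first-layer weights (and \(N(B_{i_L},1)\) for the last trainable layer) standing in for the true McK-V conditionals. The actual construction samples whole trajectories on \([0,T]\) from the laws produced by the McK-V solution.

```python
import random


def populate(widths, rng):
    """Sample one time slice of hypothetical limiting weights.

    widths[l - 1] = N_l for hidden layers l = 1..L (scalar weights, L >= 3).
    theta[l][(i, j)] is the weight from unit i of hidden layer l to unit j of layer l + 1.
    """
    L = len(widths)
    assert L >= 3, "the construction assumes L >= 3 hidden layers"
    # 1. I.i.d. part: A, B and the middle layers 2 <= l <= L - 2 (assumed N(0, 1) laws).
    A = [rng.gauss(0.0, 1.0) for _ in range(widths[0])]
    B = [rng.gauss(0.0, 1.0) for _ in range(widths[-1])]
    theta = {}
    for l in range(2, L - 1):
        theta[l] = {(i, j): rng.gauss(0.0, 1.0)
                    for i in range(widths[l - 1]) for j in range(widths[l])}
    # 2. Conditionally independent part: layer 1 given A, layer L - 1 given B.
    #    The conditional law N(anchor, 1) is a placeholder, not the McK-V conditional.
    theta[1] = {(i, j): rng.gauss(A[i], 1.0)
                for i in range(widths[0]) for j in range(widths[1])}
    theta[L - 1] = {(i, j): rng.gauss(B[j], 1.0)
                    for i in range(widths[L - 2]) for j in range(widths[L - 1])}
    return A, B, theta


A, B, theta = populate([3, 4, 4, 4, 2], random.Random(0))
print(sorted(theta))  # [1, 2, 3, 4]
```

All first-layer weights leaving unit \(i_1\) share the anchor \(A_{i_1}\), so they are dependent through it but conditionally independent given it.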
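The Gronwall step can be sketched as follows, assuming the drift is Lipschitz in both arguments (the precise constants are not given here). If the coupling error \(\Delta(t):=\mathbb{E}|\theta_i(t)-\overline{\theta}_i(t)|\) obeys an integral inequality whose source term \(a\geq 0\) collects the SGD discretization and finite-\(N\) fluctuation errors, then
\[\Delta(t)\leq a + C\int_0^t\Delta(s)\,ds\quad(0\leq t\leq T)\quad\Longrightarrow\quad \Delta(t)\leq a\,e^{Ct}\leq a\,e^{CT}.\]
With \(a\) of order \(\epsilon+(D/N)^{1/2}\), this gives bounds of the form \(Ce^{CT}(\epsilon+(D/N)^{1/2})\), as in the one-hidden-layer case.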
Here is an illustration of what the system does to us. [Image credit: public domain, via Wikipedia]
What goes wrong if weights close to input and output are allowed to change over time?
These evolve at a faster time scale.
Cannot be dealt with via our techniques.
Requires deeper understanding of the limit.
However: if the input has \(D\approx N\) i.i.d. coordinates, there is no need for random features.
(90% probability of this statement being correct)
The limiting PDE for the measures is horrendous!
True, but this leads to an interesting problem.
Forget SGD: start from a blank slate!
\(x\mapsto h_{P}(x):= \int\,h(x,\theta)\,dP(\theta)\) where \(P\in M_1(\mathbb{R}^{D})\).
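Functions of this form can be evaluated by Monte Carlo whenever \(P\) can be sampled: \(h_P(x)\approx\frac{1}{n}\sum_{j=1}^n h(x,\theta_j)\) with \(\theta_j\) i.i.d. from \(P\), which is exactly a width-\(n\) network with mean-field scaling. A minimal sketch, assuming the hypothetical choices \(h(x,\theta)=\cos(\theta x)\) and \(P=N(0,1)\), for which \(h_P(x)=e^{-x^2/2}\) in closed form:

```python
import math
import random


def h(x: float, theta: float) -> float:
    # Hypothetical unit; any h(x, theta) could be plugged in here.
    return math.cos(theta * x)


def h_P_monte_carlo(x: float, n: int, seed: int = 0) -> float:
    """Estimate h_P(x) = E_{theta ~ P} h(x, theta) with P = N(0, 1) (assumed)."""
    rng = random.Random(seed)
    return sum(h(x, rng.gauss(0.0, 1.0)) for _ in range(n)) / n


for x in (0.0, 1.0, 2.0):
    exact = math.exp(-x * x / 2)
    print(x, round(h_P_monte_carlo(x, 100_000), 3), round(exact, 3))
```

The Monte Carlo error decays like \(n^{-1/2}\), the same rate as the \(N^{-1/2}\) terms in the mean-field bounds.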