Artëm Sobolev
Research Scientist in Machine Learning
@art_sobolev | http://artem.sobolev.name/
This talk: What if z∼p(z∣θ)?
(and we'd like to find the optimal θ as well)
Let $F(z) = L(f(x, z \mid \phi), y)$, then
$$\nabla_\theta \mathbb{E}_{p(z \mid \theta)} F(z) = \nabla_\theta \int F(z)\, p(z \mid \theta)\, dz = \int F(z)\, \nabla_\theta p(z \mid \theta)\, dz = \int F(z)\, \nabla_\theta \log p(z \mid \theta)\, p(z \mid \theta)\, dz = \mathbb{E}_{p(z \mid \theta)} \left[\nabla_\theta \log p(z \mid \theta)\, F(z)\right]$$
Intuition: push up the probabilities of good samples (as measured by $F(z)$).
Pros: very general, does not require differentiable F.
Cons: known to have large variance, sensitive to values of F.
We'll get back to this estimator later in the talk.
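A minimal sketch of this estimator in PyTorch, assuming a Gaussian $p(z \mid \theta)$ and a toy cost $F$ (both are illustrative choices, not part of the talk):

```python
import torch

def F(z):                                       # toy stand-in for L(f(x, z | phi), y)
    return (z ** 2 + 1.0).sum(-1)

theta = torch.randn(5, requires_grad=True)      # parameters of p(z | theta)
p = torch.distributions.Normal(theta, 1.0)      # illustrative choice of distribution

z = p.sample()                                  # sampling is non-differentiable
# Score-function surrogate: differentiating it w.r.t. theta gives
# grad log p(z | theta) * F(z) for this single sample.
surrogate = p.log_prob(z).sum() * F(z).detach()
surrogate.backward()                            # theta.grad now holds the estimate
```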
Pros: backprop through sample
Cons: requires ability to differentiate CDF w.r.t. θ
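A matching sketch of the reparametrised (pathwise) estimator, assuming a Gaussian $p(z \mid \theta)$ so that $z = \mu + \sigma\varepsilon$ (same toy $F$ as above):

```python
import torch

def F(z):                                 # same toy cost as in the previous sketch
    return (z ** 2 + 1.0).sum(-1)

mu = torch.randn(5, requires_grad=True)
log_sigma = torch.zeros(5, requires_grad=True)

eps = torch.randn(5)                      # parameter-free noise
z = mu + log_sigma.exp() * eps            # z = T_theta(eps), differentiable in (mu, sigma)
F(z).backward()                           # gradients flow through the sample itself
```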
If $z = T_\theta(\varepsilon)$ with $\varepsilon \sim p(\varepsilon \mid \theta)$, then
$$\nabla_\theta \mathbb{E}_{p(\varepsilon \mid \theta)} F(z) = \mathbb{E}_{p(\varepsilon \mid \theta)} \nabla_\theta F(T_\theta(\varepsilon)) + \mathbb{E}_{p(\varepsilon \mid \theta)} \nabla_\theta \log p(\varepsilon \mid \theta)\, F(T_\theta(\varepsilon))$$
This formula requires us to sample $\varepsilon \mid \theta$. With a bit of algebra we can rewrite these addends in terms of samples $z \mid \theta$:
$$\mathbb{E}_{p(z \mid \theta)} \left[\nabla_z F(z)\, h_\theta(T_\theta^{-1}(z))\right]$$
$$\mathbb{E}_{p(z \mid \theta)} F(z) \left[\nabla_\theta \log p(z \mid \theta) + \nabla_z \log p(z \mid \theta)\, h_\theta(T_\theta^{-1}(z)) + u_\theta(T_\theta^{-1}(z))\right]$$
Where $h_\theta(\varepsilon) = \nabla_\theta T_\theta(\varepsilon)$ and $u_\theta(\varepsilon) = \nabla_\theta \log \left|\det \nabla_\varepsilon T_\theta(\varepsilon)\right|$.
Pros: interpolates between reparametrisation and REINFORCE (see the quick check below)
Cons: need to come up with a differentiable $T_\theta$
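As a quick check of the interpolation claim, take the Gaussian case $z = T_\theta(\varepsilon) = \mu + \sigma\varepsilon$: with full standardisation $\varepsilon \sim \mathcal{N}(0, 1)$ does not depend on $\theta$ and the bracketed correction cancels (for $\theta = \mu$: $\frac{z - \mu}{\sigma^2} - \frac{z - \mu}{\sigma^2} \cdot 1 + 0 = 0$), leaving plain reparametrisation; with no standardisation $T_\theta = \mathrm{id}$ we get $h_\theta = 0$ and $u_\theta = 0$, so only $\mathbb{E}_{p(z \mid \theta)} F(z)\, \nabla_\theta \log p(z \mid \theta)$ survives, i.e. REINFORCE.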
$$F(\mu, \sigma) = \mathbb{E}_{z \sim \mathcal{N}(\mu, \sigma^2)}\left[z^2 + c\right] = \mathbb{E}_{\varepsilon \sim \mathcal{N}(0, 1)}\left[(\mu + \varepsilon\sigma)^2 + c\right] \to \min_{\mu, \sigma}$$
$$\hat\nabla_\mu F_{\text{rep}}(\mu, \sigma) = 2(\mu + \sigma\varepsilon) \qquad \hat\nabla_\mu F_{\text{SF}}(\mu, \sigma) = \frac{\varepsilon}{\sigma}\left((\mu + \sigma\varepsilon)^2 + c\right)$$
$$\hat\nabla_\sigma F_{\text{rep}}(\mu, \sigma) = 2\varepsilon(\mu + \sigma\varepsilon) \qquad \hat\nabla_\sigma F_{\text{SF}}(\mu, \sigma) = \frac{\varepsilon^2 - 1}{\sigma}\left((\mu + \sigma\varepsilon)^2 + c\right)$$
$$\mathbb{D}\left[\hat\nabla_\mu F_{\text{rep}}(\mu, \sigma)\right] = 4\sigma^2 \qquad \mathbb{D}\left[\hat\nabla_\sigma F_{\text{rep}}(\mu, \sigma)\right] = 4\mu^2 + 8\sigma^2$$
$$\mathbb{D}\left[\hat\nabla_\mu F_{\text{SF}}(\mu, \sigma)\right] = \frac{(\mu^2 + c)^2}{\sigma^2} + 15\sigma^2 + 14\mu^2 + 6c$$
$$\mathbb{D}\left[\hat\nabla_\sigma F_{\text{SF}}(\mu, \sigma)\right] = \frac{2(\mu^2 + c)^2}{\sigma^2} + 74\sigma^2 + 60\mu^2 + 20c$$
(Plot: variance of the estimators as a function of $\mu$.)
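These closed forms are easy to verify with a quick Monte Carlo check (the particular values of $\mu$, $\sigma$, $c$ below are arbitrary):

```python
import numpy as np

mu, sigma, c = 1.0, 0.5, 2.0
eps = np.random.randn(1_000_000)
z = mu + sigma * eps
F = z ** 2 + c

# Reparametrised estimators of dF/dmu and dF/dsigma.
g_rep_mu = 2 * z
g_rep_sigma = 2 * eps * z

# Score-function estimators: F(z) times the score of N(mu, sigma^2).
g_sf_mu = (eps / sigma) * F
g_sf_sigma = ((eps ** 2 - 1) / sigma) * F

print(g_rep_mu.var(), 4 * sigma ** 2)
print(g_rep_sigma.var(), 4 * mu ** 2 + 8 * sigma ** 2)
print(g_sf_mu.var(), (mu ** 2 + c) ** 2 / sigma ** 2 + 15 * sigma ** 2 + 14 * mu ** 2 + 6 * c)
print(g_sf_sigma.var(), 2 * (mu ** 2 + c) ** 2 / sigma ** 2 + 74 * sigma ** 2 + 60 * mu ** 2 + 20 * c)
```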
$$z = \operatorname*{arg\,max}_k \left[\gamma_k + \log p_k\right]$$
$$\zeta = \operatorname{softmax}_\tau\left(\gamma_k + \log p_k\right)$$
where $\gamma_k \sim \mathrm{Gumbel}(0, 1)$.
Pros: works in categorical case, temperature controls bias
Cons: still biased, not clear how to tune temperature
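A minimal sketch of this relaxation in PyTorch (the number of categories and the temperature are illustrative; PyTorch also ships torch.nn.functional.gumbel_softmax for the relaxed part):

```python
import torch

logits = torch.randn(4, requires_grad=True)             # log p_k up to a constant
tau = 0.5                                                # temperature

u = torch.rand(4)
gumbel = -torch.log(-torch.log(u))                       # gamma_k ~ Gumbel(0, 1)

z = torch.argmax(gumbel + logits)                        # hard sample: not differentiable
zeta = torch.softmax((gumbel + logits) / tau, dim=-1)    # relaxed sample: differentiable
```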
Many other estimators only relax the backward pass
Pros: don't see any
Cons: mathematically unsound ¯\_(ツ)_/¯
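For completeness, one common way to write the straight-through version in PyTorch (hard value on the forward pass, relaxed gradient on the backward pass):

```python
import torch

logits = torch.randn(4, requires_grad=True)
zeta = torch.softmax(logits, dim=-1)        # relaxed / soft sample

# The detach() trick: the forward value equals the hard one-hot argmax,
# but the backward pass only sees the soft zeta.
z_hard = torch.nn.functional.one_hot(zeta.argmax(), num_classes=zeta.numel()).float()
z_st = z_hard + zeta - zeta.detach()
```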
How to design a control variate?
Pros: more efficient, still easy to implement
Cons: requires training an extra model $b(x)$ (unless VIMCO), does not use $z$ in the baseline, and does not use the gradient $\nabla_z F(z)$
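A sketch of the score-function estimator with a learned baseline $b(x)$ used as a control variate (the baseline network, its least-squares loss and the toy cost are assumptions of this sketch):

```python
import torch

theta = torch.randn(5, requires_grad=True)
baseline = torch.nn.Linear(5, 1)               # b(x): does not look at z
x = torch.randn(5)

p = torch.distributions.Normal(theta, 1.0)
z = p.sample()
F_z = (z ** 2).sum()                           # toy cost F(z)

b = baseline(x).squeeze()
# Subtracting b(x) keeps the estimator unbiased since E[grad log p(z | theta)] = 0.
surrogate = p.log_prob(z).sum() * (F_z - b).detach()
# The baseline itself is trained to track F, e.g. by least squares.
baseline_loss = (F_z.detach() - b) ** 2
(surrogate + baseline_loss).backward()
```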
Pros: uses the gradient information $\nabla_z F(z)$
Cons:
$$\nabla_\theta \mathbb{E}_{p(X \mid \theta)} F(\sigma_\tau(X)) = \nabla_\theta \mathbb{E}_{p(X, z \mid \theta)} F(\sigma_\tau(X)) =$$
$$\mathbb{E}_{p(z \mid \theta)}\left[\nabla_\theta \mathbb{E}_{p(X \mid z, \theta)} F(\sigma_\tau(X \mid z))\right] + \mathbb{E}_{p(z \mid \theta)} \mathbb{E}_{p(X \mid z, \theta)}\left[F(\sigma_\tau(X \mid z))\right] \nabla_\theta \log p(z \mid \theta)$$
We arrive at the following formula
$$\nabla_\theta \mathbb{E}_{p(z \mid \theta)} F(z) = \mathbb{E}_{u, v}\left[\left(F(z) - \eta F(\zeta \mid z)\right) \nabla_\theta \log p(z \mid \theta) + \eta \nabla_\theta F(\zeta) - \eta \nabla_\theta F(\zeta \mid z)\right]$$
Where $z = H(X)$, $\zeta = \sigma_\tau(X)$, $\zeta \mid z = \sigma_\tau(X \mid z)$, and
$$X = \log \frac{u}{1 - u} + \log \frac{\mu(\theta)}{1 - \mu(\theta)}, \qquad u, v \sim U[0, 1]$$
$\eta$ and $\tau$ are tuneable parameters, optimised to reduce the variance
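A sketch of how the pieces fit together for a single Bernoulli unit with a toy differentiable $F$ (the conditional noise $u \mid z$ is obtained by restricting $u$ to the region consistent with the observed $z$; $\eta$ and $\tau$ are held fixed here, whereas in practice they are optimised as stated above):

```python
import torch

def F(val):                                   # toy differentiable cost
    return (val - 0.3) ** 2

phi = torch.zeros((), requires_grad=True)     # parametrises mu(theta) = sigmoid(phi)
eta, tau = 1.0, 0.5                           # control variate scale and temperature

mu = torch.sigmoid(phi)
u, v = torch.rand(()), torch.rand(())

X = torch.log(u / (1 - u)) + torch.log(mu / (1 - mu))   # Logistic reparametrisation
z = (X > 0).float()                                     # z = H(X), non-differentiable
zeta = torch.sigmoid(X / tau)                           # zeta = sigma_tau(X)

# Conditional relaxation zeta | z: draw the uniform noise consistently with z.
u_cond = z * (v * mu + (1 - mu)) + (1 - z) * v * (1 - mu)
zeta_cond = torch.sigmoid((torch.log(u_cond / (1 - u_cond))
                           + torch.log(mu / (1 - mu))) / tau)

log_p_z = z * torch.log(mu) + (1 - z) * torch.log(1 - mu)
surrogate = ((F(z) - eta * F(zeta_cond)).detach() * log_p_z
             + eta * F(zeta) - eta * F(zeta_cond))
surrogate.backward()                                    # phi.grad holds the estimate
```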
Pros:
Cons:
What if F(z) is not differentiable or we don't know its gradients (like in RL)?
$\tilde{F}$ is optimized to minimize the variance $\operatorname{Var} \hat{g}_i = \mathbb{E}\, \hat{g}_i^2 - \left(\mathbb{E}\, \hat{g}_i\right)^2$. Since $\hat{g}_i$ is unbiased, the second term does not depend on $\tilde{F}$.
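In code this boils down to backpropagating the squared per-sample estimates into the surrogate's parameters (the architecture of $\tilde{F}$ and the helper name are assumptions of this sketch):

```python
import torch

# A learned surrogate F_tilde; the architecture is an arbitrary choice here.
F_tilde = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

def surrogate_loss(g_hat):
    # g_hat: per-sample gradient estimates that depend differentiably on F_tilde's
    # parameters (e.g. obtained with torch.autograd.grad(..., create_graph=True)).
    # Var g = E[g^2] - (E[g])^2, and the second term is the squared true gradient,
    # independent of F_tilde, so it suffices to minimise E[g^2].
    return (g_hat ** 2).mean()
```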
My talk on stochastic computation graphs for BayesGroup seminar