Symbolic and Automatic Differentiation in Lean
A stepping stone towards scientific computing in Lean
Tomáš Skřivan
What is Automatic Differentiation?
[Diagram: program f maps x:X to y:Y.]
- Transform a program f into a program ∂f that computes the derivative of f
[Diagram: ∂f maps (x:X, dx:X) to dy:Y; ∂†f maps (x:X, dy:Y) to dx:X.]
∇f x = ∂†f x 1 for f : X → ℝ
∂†f x = (∂f x)†
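A short worked example may help connect the three operations. The sketch below assumes X and Y are Hilbert spaces with inner products ⟨·,·⟩ and uses f(x) = ½∥x∥², which is an illustrative choice and not taken from the slides.

```latex
\begin{align*}
\langle \partial f\,x\,dx,\; dy \rangle_Y
  &= \langle dx,\; \partial^\dagger f\,x\,dy \rangle_X
  && \text{(defining property of the adjoint)} \\
f(x) = \tfrac12 \|x\|^2:\quad
\partial f\,x\,dx &= \langle x, dx \rangle \\
\partial^\dagger f\,x\,dy &= dy \cdot x \\
\nabla f\,x = \partial^\dagger f\,x\,1 &= x
\end{align*}
```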
Why Scientific Computing in Lean?
Scientific computing is full of mathematics
+
Lean understands mathematics
+
Lean is a general-purpose programming language with powerful metaprogramming
(similar motivation to "Why We Created Julia")
?
PROFIT
The goal: To have an interactive computer algebra system operating directly on the source code.
- symbolic computation - no jumping between C++ and Mathematica
- source code transformation - e.g. automatic differentiation
- optimization - use mathematical arguments to produce faster code (e.g. Halide)
- specification to implementation
- you do not have to be a compiler engineer to do these
Personal Motivation: Physics Simulation in Computer Graphics
The main motivation is to reduce the time it takes to create the first working prototype!
Example: Harmonic Oscillator
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ∥p∥² + k/2 * ∥x∥²

approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                             -∇ (x':=x), H m k x' p))
by
  -- Unfold Hamiltonian and compute gradients
  unfold H
  symdiff; symdiff

  -- Apply RK4 method
  rw [odeSolve_fixed_dt runge_kutta4_step]
  approx_limit steps; simp; intro steps';

Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p', -∇ (x':=x), H m k x' p)))

Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))

Goals (1)
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step fun t x => (1 / m * x.snd, -(k * x.fst)))
(x, p) := solver m k substeps t (x, p) Δt
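As a cross-check of the goal state after symdiff, the gradients can be worked out by hand. The derivation below is a sketch that treats x and p as real numbers (as in the definition of H) and uses the standard form of Hamilton's equations; the dot notation for time derivatives is an assumption and does not appear on the slide.

```latex
\begin{align*}
H(x,p) &= \frac{1}{2m}\,\|p\|^2 + \frac{k}{2}\,\|x\|^2 \\
\dot x &= \nabla_p H = \frac{1}{m}\,p \\
\dot p &= -\nabla_x H = -k\,x
\end{align*}
```

This agrees with the simplified right-hand side (1 / m * x.snd, -(k * x.fst)), where the pair x holds position and momentum.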
Two Main Operations: Differential and Adjoint
Differential ∂ : (X → Y) → (X → X → Y)
Adjoint † : (X → Y) → (Y → X)
[Diagram: ∂ turns f, which maps x:X to y:Y, into ∂f, which maps (x:X, dx:X) to dy:Y. † turns f, which maps x:X to y:Y, into f†, which maps y:Y to x:X.]
Simplifier and Typeclass approach
Differential ∂
instance comp_is_smooth
  (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
  : IsSmoothT (λ x => f (g x)) := ...

@[simp]
theorem chain_rule
  (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
  : ∂ (λ x => f (g x)) = λ x dx => ∂ f (g x) (∂ g x dx) := ...

Adjoint †
instance comp_has_adjoint
  (f : Y → Z) (g : X → Y) [IsLinT f] [IsLinT g]
  : IsLinT (λ x => f (g x)) := ...

@[simp]
theorem adj_of_comp
  (f : Y → Z) (g : X → Y) [IsLinT f] [IsLinT g]
  : (λ x => f (g x))† = λ z => g† (f† z) := ...
Lambda calculus and SKI combinators
Differential ∂
instance I_is_smooth : IsSmoothT λ x : X => x := ...

@[simp]
theorem diff_of_I
  : ∂ (λ x : X => x) = λ x dx => dx := ...

instance K_is_smooth (x : X) : IsSmoothT λ (y : Y) => x := ...

@[simp]
theorem diff_of_K
  : ∂ (λ y : Y => x) = λ y dy => (0 : X) := ...

instance S_is_smooth
  (f : X → Y → Z) (g : X → Y) [IsSmoothNT 2 f] [IsSmoothT g]
  : IsSmoothT (λ x => f x (g x)) := ...

@[simp]
theorem diff_of_S
  (f : X → Y → Z) (g : X → Y) [IsSmoothNT 2 f] [IsSmoothT g]
  : ∂ (λ x => f x (g x))
    =
    λ x dx => ∂ f x dx (g x) + ∂ (f x) (g x) (∂ g x dx) := ...
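The S rule is the least obvious of the three, so a short derivation may help. The sketch below writes ∂₁ and ∂₂ for the partial differentials of f in its first and second argument (notation assumed here, not from the slides) and checks the rule on the illustrative choice f x y = x·y, g x = x².

```latex
\begin{align*}
h(x) &= f\,x\,(g\,x) \\
\partial h\,x\,dx
  &= \underbrace{\partial_1 f\,(x, g\,x)\,dx}_{\partial f\,x\,dx\,(g\,x)}
   + \underbrace{\partial_2 f\,(x, g\,x)\,(\partial g\,x\,dx)}_{\partial (f\,x)\,(g\,x)\,(\partial g\,x\,dx)} \\[4pt]
f\,x\,y = x y,\; g\,x = x^2:\quad
\partial h\,x\,dx &= dx \cdot x^2 + x \cdot (2x\,dx) = 3x^2\,dx
\end{align*}
```

The result agrees with differentiating h(x) = x³ directly.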
instance (priority := low) swap_is_smooth
  (f : α → X → Y) [∀ a, IsSmoothT (f a)]
  : IsSmoothT (λ x a => f a x) := ...

@[simp low]
theorem diff_of_swap
  (f : α → X → Y) [∀ a, IsSmoothT (f a)]
  : ∂ (λ x a => f a x) = λ x dx a => ∂ (f a) x dx := ...

instance swap'_is_smooth
  (f : X → α → Y) [IsSmoothT f]
  : ∀ a, IsSmoothT (λ x => f x a) := ...

@[simp]
theorem diff_of_swap'
  (f : X → α → Y) [IsSmoothT f] (a : α)
  : ∂ (λ x => f x a) = λ x dx => ∂ f x dx a := ...
Forward Mode AD and Functoriality
[Diagram: applying ∂ to the composition f∘g, which maps x:X to z:Z. ∂(f∘g) maps (x:X, dx:X) to dz:Z. How is it built from ∂g, which maps (x:X, dx:X) to dy:Y, and ∂f, which maps (y:Y, dy:Y) to dz:Z?]
𝒯 (λ x => f (g x))
=
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)
[Diagram: 𝒯(f∘g) maps (x:X, dx:X) to (z:Z, dz:Z) by chaining 𝒯g, which maps (x, dx) to (y, dy), with 𝒯f, which maps (y, dy) to (z, dz).]
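The composition rule for 𝒯 can be tried out in plain Lean 4 without SciLean. The following is a minimal sketch on Float: the names Fwd, sinT, sqrT and compT are hypothetical, and the derivatives of the two primitives are written by hand rather than derived automatically.

```lean
-- Hypothetical stand-in for 𝒯: a map (x, dx) ↦ (f x, ∂ f x dx) on Float.
abbrev Fwd := Float → Float → Float × Float

-- Hand-written primitives (assumed for this sketch, not generated).
def sinT : Fwd := fun x dx => (Float.sin x, Float.cos x * dx)
def sqrT : Fwd := fun x dx => (x * x, 2.0 * x * dx)

-- Composition rule from the slide: thread (value, tangent) pairs through.
def compT (f g : Fwd) : Fwd := fun x dx =>
  let (y, dy) := g x dx
  let (z, dz) := f y dy
  (z, dz)

-- Value and derivative of sin (x²) at x = 1, i.e. (sin 1, 2 * cos 1)
#eval compT sinT sqrT 1.0 1.0
```

Because both the value and the tangent travel together, the intermediate y is available exactly where ∂f needs it, which is the point of the functoriality argument.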
Reverse Mode AD and Functoriality
[Diagram: applying ∂† to the composition f∘g, which maps x:X to z:Z. ∂†(f∘g) maps (x:X, dz:Z) to dx:X. How is it built from ∂†g, which maps (x:X, dy:Y) to dx:X, and ∂†f, which maps (y:Y, dz:Z) to dy:Y?]
[Diagram: ℛ(f∘g) maps x:X to z:Z together with the function ∂†(f∘g) x : Z → X.]
ℛ (λ x => f (g x))
=
λ x =>
  let (y,dg') := ℛ g x
  let (z,df') := ℛ f y
  (z, λ dz => dg' (df' dz))
[Diagram: ℛg maps x:X to y:Y together with ∂†g x : Y → X; ℛf maps y:Y to z:Z together with ∂†f y : Z → Y.]
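The same experiment works for ℛ. The sketch below is plain Lean 4 on Float, with hypothetical names Rev, sinR, sqrR, compR and gradAt; each primitive returns its value together with a hand-written pullback closure.

```lean
-- Hypothetical stand-in for ℛ: x ↦ (f x, dz ↦ ∂† f x dz) on Float.
abbrev Rev := Float → Float × (Float → Float)

-- Hand-written primitives (assumed for this sketch, not generated).
def sinR : Rev := fun x => (Float.sin x, fun dz => Float.cos x * dz)
def sqrR : Rev := fun x => (x * x, fun dy => 2.0 * x * dy)

-- Composition rule from the slide: the forward pass runs g then f,
-- the returned closure runs the pullbacks in the opposite order.
def compR (f g : Rev) : Rev := fun x =>
  let (y, dg') := g x
  let (z, df') := f y
  (z, fun dz => dg' (df' dz))

-- Gradient of a scalar function: apply the pullback to 1.
def gradAt (f : Rev) (x : Float) : Float := (f x).2 1.0

-- Gradient of sin (x²) at x = 1, i.e. 2 * cos 1
#eval gradAt (compR sinR sqrR) 1.0
```

The closure captures x and y from the forward pass, which plays the role of the tape in conventional reverse mode implementations.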
Forward Mode AD and Let Bindings
∂ (λ x =>
  let y := g x
  let z := f y
  x + y + z)
rewrite_by autodiff
=
fun x dx =>
  let y := g x
  let dy := ∂ g x dx
  let z := f y
  let dz := ∂ f y dy
  dx + dy + dz
- Working prototype of differentiating through let bindings as a custom simplifier step. Currently, the step does not produce a proof.
- Ideally use 𝒯:
∂ (λ x =>
  let y := g x
  let z := f y
  x + y + z)
rewrite_by autodiff
=
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  dx + dy + dz
Differentiation Approach Overview
- Two main operations ∂ and †
  - rules for I, K, S, C, C' and primitive functions
  - their correctness needs to be proven using mathlib
- Derived operations ∂†, 𝒯, ℛ
  - their rules can be easily derived from the rules of ∂ and †
  - main purpose is efficiency: faster symbolic computation and more efficient resulting code
- Special handling of let bindings
  - works for ∂ but is not verified
  - needs to be done for the other operations
- But there are problems ...
IsLinT is Transitive Closure of IsLin
- We want:
  - IsLin f to imply IsSmooth f
  - ∂ f x dx = f dx for f linear
- Problem: Attempting to prove linearity of every function is too expensive!
- Solution: Introduce a transitive closure IsLinT f of IsLin f via rules for I, K, S, C, C'.
- The class IsLin f is used only on elementary functions like (⋅+⋅), (⋅∗⋅), (−⋅)
Multiple Arguments
Differentiation w.r.t. (x,y) can be reduced to differentiation w.r.t. x and y:
example (f : X → Y → Z) [IsSmoothNT 2 f]
  : differential (λ (x,y) => f x y)
    =
    λ (x,y) (dx,dy) => ∂ (λ x' => f x' y) x dx + ∂ (λ y' => f x y') y dy :=
by symdiff; done

Smoothness w.r.t. (x,y) can be reduced to smoothness w.r.t. y and x (with values in Y⟿Z):
example (f : X → Y → Z)
  [∀ x, IsSmoothT (λ y => f x y)]
  [IsSmoothT (λ x => λ y ⟿ f x y)]
  : IsSmoothNT 2 f := ...

- Proving IsSmoothT (λ x => λ y ⟿ f x y) can be difficult
- This can't be done with linearity. We have special variants of the I, K, S, C, C' rules for IsLinNT 2 f, IsLinNT 3 f, ...
Unification Problems
Applying the composition rule is difficult!
instance comp_is_smooth
  (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
  : IsSmoothT (λ x => f (g x)) := ...

Adding a trailing argument trips up unification:
example (f' : Y → α → Z) (g : X → Y) (a : α)
  [IsSmoothT (fun y => f' y a)] [IsSmoothT g]
  : IsSmoothT (λ x => f' (g x) a) :=
by
  try infer_instance -- fails to solve
  apply comp_is_smooth (fun y => f' y a) g

Solved by a unification hint:
unif_hint (f : Y → Z) (f' : Y → α → Z) (g : X → Y) (a : α) where
  f =?= λ y => f' y a
  |-
  IsSmoothT (λ x => f (g x)) =?= IsSmoothT (λ x => f' (g x) a)

It is not over ... and it goes on and on ...
example (f : Y → α → β → Z) (g : X → Y) (a : α) (b : β)
  [IsSmoothT (fun y => f y a b)] [IsSmoothT g]
  : IsSmoothT (λ x => f (g x) a b) :=
by
  infer_instance -- still fails!
Simp Guard
A unification hint can cause an infinite loop:
unif_hint (...) where
  f =?= λ x => f' x a
  |-
  (λ x => f (g x))† =?= (λ x => f' (g x) a)†

example (...)
  : (λ x => f (g x) a)† = λ x' => g† ((λ x => f x a)† x') :=
by
  simp -- infinite loop

[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x => f (g x) a
  ==> fun z => SciLean.adjoint (fun x => g x) (SciLean.adjoint (fun x => f x a) z)
[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x => f x a
  ==> fun z => SciLean.adjoint (fun x => x) (SciLean.adjoint (fun x => f x a) z)
[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x => f x a
  ==> fun z => SciLean.adjoint (fun x => x) (SciLean.adjoint (fun x => f x a) z)
...
Solution: simp_guard g (λ x => x)
- do not apply the simp rule if g is equal to λ x => x
@[simp, simp_guard g (λ x => x)]
theorem comp.arg_x.adj_simp
  (f : Y → Z) [HasAdjointT f]
  (g : X → Y) [HasAdjointT g]
  : (λ x => f (g x))† = λ z => g† (f† z) := ...
Problematic Adjoints
Creating constant arrays is wasteful!
Creating almost everywhere zero arrays is wasteful!
Problematic Adjoints and Unification
Matrix transposition
example {n m} (A : Fin n → Fin m → ℝ)
  : (λ (x : Fin m → ℝ) => λ i => ∑ j, A i j * x j)†
    =
    (λ y => ∑ j i', λ i => [[i'=i]] * A j i' * y j) :=
by symdiff; done

Ok, this is somewhat hard. Let's add a new simp rule:
@[simp]
theorem adjoint_sum_eval
  (f : ι → κ → X → Y) [∀ i j, HasAdjointT (f i j)]
  : (λ (x : κ → X) => λ i => ∑ j, (f i j) (x j))†
    =
    λ y => λ j => ∑ i, (f i j)† (y i) := ...
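The identity behind the matrix transposition example, ⟨A x, y⟩ = ⟨x, Aᵀ y⟩, can be checked numerically in plain Lean 4. This is only a sketch on Array Float with hypothetical helper names (dot, matVec, transpose) and made-up sample data; it does not use SciLean's † or its array types.

```lean
-- Inner product of two vectors stored as arrays.
def dot (u v : Array Float) : Float :=
  (u.zip v).foldl (fun (acc : Float) (p : Float × Float) => acc + p.1 * p.2) 0.0

-- Matrix-vector product; the matrix is an array of rows.
def matVec (A : Array (Array Float)) (x : Array Float) : Array Float :=
  A.map fun row => dot row x

-- Transpose of a matrix with `m` columns.
def transpose (A : Array (Array Float)) (m : Nat) : Array (Array Float) :=
  (Array.range m).map fun j => A.map fun row => row[j]!

-- Sample data (assumed for illustration): a 3×2 matrix and two vectors.
def A : Array (Array Float) := #[#[1.0, 2.0], #[3.0, 4.0], #[5.0, 6.0]]
def x : Array Float := #[1.0, 1.0]
def y : Array Float := #[1.0, 0.0, 2.0]

-- ⟨A x, y⟩ and ⟨x, Aᵀ y⟩; both components evaluate to 25
#eval (dot (matVec A x) y, dot x (matVec (transpose A 2) y))
```

The desired output of the simplifier is the second form, λ y => λ j => ∑ i, A i j * y i, without the intermediate almost-everywhere-zero arrays.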
Problematic Adjoints and Unification
Now we have:
Oops
Let's add a unification hint:
unif_hint
  (f? : ι → κ → X → Y)
  (f : ι → κ → X → α → Y)
  (g : ι → κ → α)
where
  f? =?= λ i j x => f i j x (g i j)
  |-
  (λ (x : κ → X) => λ i => ∑ j, (f? i j) (x j))†
  =?=
  (λ (x : κ → X) => λ i => ∑ j, f i j (x j) (g i j))†
Problematic Adjoints and Unification
Now we have:
Oops
Let's add a unification hint:
unif_hint
  (f? : ι → κ → X → Y)
  (f : ι → κ → W → α → Y)
  (g : ι → κ → α)
  (h : ι → κ → X → W)
where
  f? =?= λ i j x => f i j (h i j x) (g i j)
  |-
  (λ (x : κ → X) => λ i => ∑ j, (f? i j) (x j))†
  =?=
  (λ (x : κ → X) => λ i => ∑ j, f i j (h i j (x j)) (g i j))†
Problematic Adjoints and Unification
Now we have:
I give up
Calculus of Variations
Adjunction on Semi-Hilbert Spaces
Adjunction of A : (ℕ → ℝ) → (ℕ → ℝ)
Sensible only if the matrix has only finitely many nonzero elements in every column and row
Adjunction of d/dx : (ℝ ⟿ ℝ) → (ℝ ⟿ ℝ)
True only for f and g compactly supported and when integrating over large enough domain.
Adjoint Failure of K, C and eval rules
Adjoint failure of K:
Adjoint failure of eval:
As a consequence we do not have C, C' (can't take an adjoint of the green arrow)
but we still have the product rule
Two Main Goals for Adjunction
Convolution:
example {n k} (x : Fin n → ℝ)
  : (λ (w : Fin k → ℝ) => λ (i : Fin n) => ∑ j, w j * x (i + j))†
    =
    λ (y : Fin n → ℝ) => λ (j : Fin k) => ∑ i, x i * y (i - j) := ...

Mixing f(x) and f′(x):
example (p q : ℝ ⟿ ℝ)
  : (λ (f : ℝ ⟿ ℝ) => λ x ⟿ p x * f x + q x * ⅆ f x)†
    =
    (λ (f : ℝ ⟿ ℝ) => λ x ⟿ p x * f x - ⅆ (x':=x), q x' * f x') := ...
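Both target identities follow from short pen-and-paper arguments, sketched below. The symbols A (the convolution map w ↦ ∑ j, w j * x (i + j)) and h (a test function) are introduced here for illustration. The first line assumes that index arithmetic wraps around, so that the shift i' = i + j is a bijection; the second assumes the boundary term vanishes, for example for compactly supported functions.

```latex
\begin{align*}
\langle A w, y \rangle
  &= \sum_i \Big( \sum_j w_j\, x_{i+j} \Big) y_i
   = \sum_j w_j \sum_i x_{i+j}\, y_i
   = \sum_j w_j \sum_{i'} x_{i'}\, y_{i'-j}
   && (i' = i + j) \\[4pt]
\langle p f + q f', h \rangle
  &= \int \big( p f h + q f' h \big)\,dx
   = \int f \,\big( p h - (q h)' \big)\,dx
   && \text{(integration by parts, boundary term dropped)}
\end{align*}
```

Reading off the factor multiplying w_j and f respectively gives the right-hand sides stated in the two examples.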
Manifold Like Types
Types like ℝ⁺, Array X, X⊕Y, {x:X, ∥x∥=1}, N⟿M are not vector spaces!
We need tangent spaces:
- T_a(Array X) = {b : Array X, b.size = a.size}
- T_x(X⊕Y) = X, T_y(X⊕Y) = Y
- T_x{x:X, ∥x∥=1} = {v:X, v⊥x}
- T_f(N⟿M) = (x:N) ⇝ T_{f(x)} M
What is the good mathematical semantics? Diffeological spaces?
How to Handle Imperative Code?
∂ λ (x₀ : ℝ^{n}) => Id.run do
  let mut x := x₀
  for i in [0:10] do
    x := x + Δt * f x
  x
?
λ (x₀ dx₀: ℝ^{n}) => Id.run do
  let mut x := x₀
  let mut dx := dx₀
  for i in [0:10] do
    (x,dx) := (x,dx) + Δt * 𝒯 f x dx
  dx

λ (x₀ dx₀: ℝ^{n}) => Id.run do
  let mut x := x₀
  let mut dx := dx₀
  for i in [0:10] do
    dx := dx + Δt * ∂ f x dx
    x := x + Δt * f x
  dx
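One way to sanity-check a candidate is to run it next to a finite-difference approximation. The sketch below is plain Lean 4 on scalar Float, follows the variant that updates dx before x (so that ∂ f is evaluated at the old state), and uses a made-up right-hand side rhs x = -x² with a hand-written derivative; none of these names come from SciLean.

```lean
-- Assumed example dynamics and its hand-written differential.
def rhs (x : Float) : Float := -(x * x)
def rhsD (x dx : Float) : Float := -(2.0 * x * dx)

-- Ten explicit Euler steps.
def euler (Δt x₀ : Float) : Float := Id.run do
  let mut x := x₀
  for _ in [0:10] do
    x := x + Δt * rhs x
  return x

-- Tangent of the loop: dx is updated first, using the old value of x.
def eulerTangent (Δt x₀ dx₀ : Float) : Float × Float := Id.run do
  let mut x := x₀
  let mut dx := dx₀
  for _ in [0:10] do
    dx := dx + Δt * rhsD x dx
    x := x + Δt * rhs x
  return (x, dx)

-- Derivative of the final state w.r.t. x₀, from the tangent loop ...
#eval (eulerTangent 0.1 1.0 1.0).2
-- ... and from a forward finite difference; the two should roughly agree.
#eval (euler 0.1 (1.0 + 1e-6) - euler 0.1 1.0) / 1e-6
```

Swapping the two updates inside the loop would evaluate the differential at the already advanced state, which is the kind of ordering bug an automatic transformation of imperative code has to avoid.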
Differentiating Monadic Code
(f∘g) := λ x =>
  let (y, dg) := g x
  let (z, df) := f y
  (z, df∘dg)
?
(f∘g) := λ x => do let y ← g x; f y
ℱ f x = (f x, λ dx => ∂ f x dx)
pure
pure
(f∘g) := λ x => do
  let (y,dg) ← g x
  let (z,df) ← f y
  let dfg := λ dx => do
    let dy ← dg dx
    df dy
  pure (z, dfg)
Maybe Kan extension?
Local Differentiability and Reduced Regularity
Non-smooth functions that we want to differentiate:
Potential approaches:
- Smooth approximation
- ∥x∥ ≈ ∥x∥_ε = √(∥x∥² + ε²)
- 1/x ≈ x / ∥x∥_ε²
- Local smoothness, new predicate: IsSmoothAt f x
- type class inference would need to automatically prove, for example, that x ≠ 0
- Work in a category with Lip, C^k or C^{k,α} maps.
- Are these categories Cartesian closed?
- Does some form of chain rule hold for Lipschitz functions?
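The smooth approximation approach is easy to try in plain Lean 4. The sketch below covers only the scalar case on Float; normEps and invEps are hypothetical names and ε is a free regularisation parameter chosen by the user.

```lean
-- ε-regularised absolute value: smooth everywhere, close to |x| away from 0.
def normEps (ε x : Float) : Float := Float.sqrt (x * x + ε * ε)

-- ε-regularised reciprocal: x / ∥x∥_ε², close to 1/x away from 0.
def invEps (ε x : Float) : Float := x / (x * x + ε * ε)

-- At x = 0 the regularised norm is about ε and the regularised reciprocal is 0.
#eval normEps 1e-3 0.0
#eval invEps 1e-3 0.0

-- Away from 0 both are close to the exact values |2| = 2 and 1/2.
#eval normEps 1e-3 2.0
#eval invEps 1e-3 2.0
```

The price is a bias of order ε near the singularity, which is why the other two approaches on the slide remain of interest.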
By lecopivo