Symbolic and Automatic Differentiation in Lean
A stepping stone towards scientific computing in Lean
Tomáš Skřivan
What is Automatic Differentiation?
[Diagram: program \(f\) maps \(x:X\) to \(y:Y\).]
- Transform program \(f\) to a program \(\partial \, f\) that computes the derivative of \(f\)
[Diagram: \(\partial \, f\) maps \(x:X\), \(dx:X\) to \(dy:Y\); \(\partial^\dagger \, f\) maps \(x:X\), \(dy:Y\) to \(dx:X\).]
\(\nabla f \, x = \partial^\dagger \, f \, x \, 1\) for \(f : X \rightarrow \mathbb{R}\)
\(\partial^\dagger \, f \, x = (\partial\, f \,x)^\dagger \)
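A small worked example may help connect the three operators. It is not from the slides: it assumes \(X\) is a real inner-product space and takes \(f(x) = \tfrac12 \|x\|^2\) as an illustrative function.

```latex
\[
\begin{aligned}
\partial f \, x \, dx &= \langle x, dx \rangle \\
\partial^\dagger f \, x \, dy &= (\partial f \, x)^\dagger \, dy = dy \cdot x
  \quad\text{since } \langle \langle x, dx \rangle, dy \rangle_{\mathbb{R}} = \langle dx, dy \cdot x \rangle_X \\
\nabla f \, x &= \partial^\dagger f \, x \, 1 = x
\end{aligned}
\]
```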
Why Scientific Computing in Lean?
Scientific computing is full of mathematics
Lean understands mathematics
+
The goal: To have an interactive computer algebra system operating directly on the source code.
- symbolic computation - no jumping between C++ and Mathematica
- source code transformation - e.g. automatic differentiation
- optimization - use mathematical arguments to produce faster code (e.g. Halide)
- specification to implementation
- you do not have to be a compiler engineer to do these
Lean is a general-purpose programming language with powerful metaprogramming
(similar motivation to "Why We Created Julia")
+
PROFIT
?
Personal Motivation: Physics Simulation in Computer Graphics
Main motivation is to reduce the time it takes to create the first working prototype!
Example: Harmonic Oscillator
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ∥p∥² + k/2 * ∥x∥²
approx solver (m k : ℝ) (steps : Nat)
:= odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
-∇ (x':=x), H m k x' p))
by
-- Unfold Hamiltonian and compute gradients
unfold H
symdiff; symdiff
-- Apply RK4 method
rw [odeSolve_fixed_dt runge_kutta4_step]
approx_limit steps; simp; intro steps';
Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))
Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
-∇ (x':=x), H m k x' p)))
Goals (1)
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
fun t x => (1 / m * x.snd, -(k * x.fst)))
(x, p) := solver m k substeps t (x, p) Δt
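To make the final goal concrete, here is a minimal sketch in plain Lean 4 of what a fixed-step RK4 solver for the simplified right-hand side `(1 / m * x.snd, -(k * x.fst))` could look like. It uses `Float` instead of `ℝ` and drops the unused time argument. The names `oscillatorRhs`, `addPair`, `smulPair`, `rk4Step` and `solveOscillator` are hypothetical and are not SciLean's `odeSolve_fixed_dt_impl` or `runge_kutta4_step`.

```lean
/-- Componentwise addition of two states `(x, p)`. -/
def addPair (a b : Float × Float) : Float × Float := (a.1 + b.1, a.2 + b.2)

/-- Scaling of a state `(x, p)` by a scalar. -/
def smulPair (c : Float) (a : Float × Float) : Float × Float := (c * a.1, c * a.2)

/-- Right-hand side left after `symdiff`: `(dx/dt, dp/dt) = (p / m, -(k * x))`. -/
def oscillatorRhs (m k : Float) (s : Float × Float) : Float × Float :=
  (1 / m * s.2, -(k * s.1))

/-- One classical fourth-order Runge-Kutta step of size `dt`. -/
def rk4Step (f : Float × Float → Float × Float) (dt : Float)
    (s : Float × Float) : Float × Float :=
  let k1 := f s
  let k2 := f (addPair s (smulPair (dt / 2) k1))
  let k3 := f (addPair s (smulPair (dt / 2) k2))
  let k4 := f (addPair s (smulPair dt k3))
  -- weighted average of the four slopes: (k1 + 2 k2 + 2 k3 + k4) / 6
  let slope := addPair (addPair k1 (smulPair 2 k2)) (addPair (smulPair 2 k3) k4)
  addPair s (smulPair (dt / 6) slope)

/-- Advance the oscillator by `steps` RK4 steps of size `dt`. -/
def solveOscillator (m k dt : Float) (steps : Nat)
    (s0 : Float × Float) : Float × Float := Id.run do
  let mut s := s0
  for _ in [0:steps] do
    s := rk4Step (oscillatorRhs m k) dt s
  return s
```

With `m = k = 1` the exact flow is a rotation in the `(x, p)` plane, so `x² + p²` should stay nearly constant over many steps, which gives a quick sanity check for such a prototype.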
Two Main Operations: Differential and Adjoint
Differential \(\partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)\)
Adjoint \(\dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X )\)
[Diagram: \(\partial\) turns \(f\), which maps \(x:X\) to \(y:Y\), into \(\partial f\), which maps \(x:X\), \(dx:X\) to \(dy:Y\).]
[Diagram: \(\dagger\) turns \(f\), which maps \(x:X\) to \(y:Y\), into \(f^\dagger\), which maps \(y:Y\) to \(x:X\).]
Simplifier and Typeclass approach
Differential \(\partial\)
Adjoint \(\dagger\)
instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
: IsSmoothT (λ x => f (g x)) := ...
@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
: ∂ (λ x => f (g x))
=
λ x dx => ∂ f (g x) (∂ g x dx) := ...
instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLinT f] [IsLinT g]
: IsLinT (λ x => f (g x)) := ...
@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLinT f] [IsLinT g]
: (λ x => f (g x))†
=
λ z => g† (f† z) := ...
Lambda calculus and SKI combinators
instance I_is_smooth
: IsSmoothT λ x : X => x := ...
@[simp]
theorem diff_of_I
: ∂ (λ x : X => x)
=
λ x dx => dx := ...
instance K_is_smooth (x : X)
: IsSmoothT λ (y : Y) => x := ...
@[simp]
theorem diff_of_K (x : X)
: ∂ (λ y : Y => x)
=
λ y dy => (0 : X) := ...
instance S_is_smooth (f : X → Y → Z) (g : X → Y) [IsSmoothNT 2 f] [IsSmoothT g]
: IsSmoothT (λ x => f x (g x)) := ...
@[simp]
theorem diff_of_S (f : X → Y → Z) (g : X → Y) [IsSmoothNT 2 f] [IsSmoothT g]
: ∂ (λ x => f x (g x))
=
λ x dx =>
∂ f x dx (g x)
+
∂ (f x) (g x) (∂ g x dx) := ...
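As a quick check of the S rule, consider the illustrative choice \(f \, x \, y = x \cdot y\) and \(g \, x = x\) on \(\mathbb{R}\) (an assumption made here, not an example from the slides). The two summands are the derivative through the first argument of \(f\) and the derivative through \(g\):

```latex
\[
\begin{aligned}
\partial f \, x \, dx \, (g \, x) &= dx \cdot x \\
\partial (f \, x) \, (g \, x) \, (\partial g \, x \, dx) &= x \cdot dx \\
\partial (\lambda x \Rightarrow x \cdot x) \, x \, dx &= dx \cdot x + x \cdot dx = 2 \, x \, dx
\end{aligned}
\]
```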
Differential \(\partial\)
Lambda calculus and SKI combinators
Differential \(\partial\)
instance (priority := low) swap_is_smooth (f : α → X → Y) [∀ a, IsSmoothT (f a)]
: IsSmoothT (λ x a => f a x) := ...
@[simp low]
theorem diff_of_swap (f : α → X → Y) [∀ a, IsSmoothT (f a)]
: ∂ (λ x a => f a x)
=
λ x dx a => ∂ (f a) x dx := ...
instance swap'_is_smooth (f : X → α → Y) [IsSmoothT f]
: ∀ a, IsSmoothT (λ x => f x a) := ...
@[simp]
theorem diff_of_swap' (f : X → α → Y) [IsSmoothT f] (a : α)
: ∂ (λ x => f x a)
=
λ x dx => ∂ f x dx a := ...
Lambda calculus and SKI combinators
Differential \(\partial\)
Adjoint \(\dagger\)
Forward Mode AD and Functoriality
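[Diagram: \(\partial\) applied to \(f \circ g\), which maps \(x:X\) to \(z:Z\), gives \(\partial (f\circ g)\), which maps \(x:X\), \(dx:X\) to \(dz:Z\). Building it from \(\partial g\) (inputs \(x:X\), \(dx:X\); output \(dy:Y\)) and \(\partial f\) (inputs \(y:Y\), \(dy:Y\); output \(dz:Z\)) is marked with "?".]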
Forward Mode AD and Functoriality
𝒯 (λ x => f (g x))
=
λ x dx =>
let (y,dy) := 𝒯 g x dx
let (z,dz) := 𝒯 f y dy
(z,dz)
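The composition rule above can be tried out in plain Lean 4 on `Float`. The following is only a toy sketch: the names `Fwd`, `compFwd`, `sqFwd` and `sinFwd` are hypothetical, and the derivatives of the primitives are written by hand rather than generated by SciLean.

```lean
/-- Toy forward-mode map: sends `(x, dx)` to `(f x, ∂ f x dx)`. -/
abbrev Fwd := Float → Float → Float × Float

/-- Composition of forward-mode maps, mirroring the rule for `𝒯 (λ x => f (g x))`. -/
def compFwd (Tf Tg : Fwd) : Fwd := fun x dx =>
  let (y, dy) := Tg x dx
  let (z, dz) := Tf y dy
  (z, dz)

/-- Hand-written forward-mode map of `x ↦ x * x`. -/
def sqFwd : Fwd := fun x dx => (x * x, 2 * x * dx)

/-- Hand-written forward-mode map of `x ↦ sin x`. -/
def sinFwd : Fwd := fun x dx => (Float.sin x, Float.cos x * dx)

-- Value and derivative of `x ↦ sin (x * x)` at `x = 1` in direction `dx = 1`,
-- i.e. the pair `(sin 1, 2 * cos 1)` up to floating-point rounding.
#eval compFwd sinFwd sqFwd 1.0 1.0
```

Because each stage returns both the value and the tangent, the stages compose without recomputing `g x`.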
[Diagram: \(\mathcal{T} (f\circ g)\) maps \(x:X\), \(dx:X\) to \(z:Z\), \(dz:Z\). It is the composition of \(\mathcal{T} g\), which maps \(x:X\), \(dx:X\) to \(y:Y\), \(dy:Y\), and \(\mathcal{T} f\), which maps \(y:Y\), \(dy:Y\) to \(z:Z\), \(dz:Z\).]
Reverse Mode AD and Functoriality
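[Diagram: \(\partial^\dagger\) applied to \(f \circ g\), which maps \(x:X\) to \(z:Z\), gives \(\partial^\dagger (f\circ g)\), which maps \(x:X\), \(dz:Z\) to \(dx:X\). Building it from \(\partial^\dagger g\) (inputs \(x:X\), \(dy:Y\); output \(dx:X\)) and \(\partial^\dagger f\) (inputs \(y:Y\), \(dz:Z\); output \(dy:Y\)) is marked with "?".]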
Reverse Mode AD and Functoriality
[Diagram: \(\mathcal{R} (f\circ g)\) maps \(x:X\) to \(z:Z\) together with \(\partial^\dagger (f\circ g) \, x : Z \rightarrow X \).]
ℛ (λ x => f (g x))
=
λ x =>
let (y,dg') := ℛ g x
let (z,df') := ℛ f y
(z, λ dz => dg' (df' dz))
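The same kind of toy experiment works for the reverse-mode rule. This sketch is in plain Lean 4 on `Float`; the names `Rev`, `compRev`, `sqRev`, `sinRev` and `gradSinSq` are hypothetical, and the adjoint derivatives of the primitives are written by hand.

```lean
/-- Toy reverse-mode map: sends `x` to `(f x, ∂† f x)`. -/
abbrev Rev := Float → Float × (Float → Float)

/-- Composition of reverse-mode maps, mirroring the rule for `ℛ (λ x => f (g x))`. -/
def compRev (Rf Rg : Rev) : Rev := fun x =>
  let (y, dg') := Rg x
  let (z, df') := Rf y
  -- the pullbacks compose in the opposite order to the functions
  (z, fun dz => dg' (df' dz))

/-- Hand-written reverse-mode map of `x ↦ x * x`. -/
def sqRev : Rev := fun x => (x * x, fun dy => 2 * x * dy)

/-- Hand-written reverse-mode map of `x ↦ sin x`. -/
def sinRev : Rev := fun x => (Float.sin x, fun dy => Float.cos x * dy)

/-- Gradient of `x ↦ sin (x * x)`: run the pullback on the seed `1`. -/
def gradSinSq (x : Float) : Float := (compRev sinRev sqRev x).2 1.0
```

Mathematically `gradSinSq x` is `2 * x * cos (x * x)`. The forward pass builds the pullback closure, and only the final application to `1.0` walks the chain backwards.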
[Diagram: \(\mathcal{R} g\) maps \(x:X\) to \(y:Y\) together with \(\partial^\dagger g \, x : Y \rightarrow X \); \(\mathcal{R} f\) maps \(y:Y\) to \(z:Z\) together with \(\partial^\dagger f \, y : Z \rightarrow Y \).]
Forward Mode AD and Let Bindings
∂ (λ x =>
let y := g x
let z := f y
x + y + z)
rewrite_by autodiff
=
fun x dx =>
let y := g x
let dy := ∂ g x dx
let z := f y
let dz := ∂ f y dy
dx + dy + dz
- Working prototype of differentiating through let bindings as a custom simplifier step. Currently, the step does not produce a proof.
- Ideally use \(\mathcal{T}\)
∂ (λ x =>
let y := g x
let z := f y
x + y + z)
rewrite_by autodiff
=
λ x dx =>
let (y,dy) := 𝒯 g x dx
let (z,dz) := 𝒯 f y dy
dx + dy + dz
Differentiation Approach Overview
- Two main operations \(\partial\) and \(\dagger\)
  - rules for I, K, S, C, C' and primitive functions
  - their correctness needs to be proven using mathlib
- Derived operations \(\partial^\dagger, \mathcal{T}, \mathcal{R} \)
  - their rules can be easily derived from rules of \(\partial\) and \(\dagger\)
  - main purpose is efficiency: faster symbolic computation and more efficient resulting code
- Special handling of let bindings
  - works for \(\partial\) but not verified
  - needs to be done for other operations
- But there are problems ...
\(\texttt{IsLinT}\) is Transitive Closure of \(\texttt{IsLin}\)
- We want:
- \(\texttt{IsLin } f\) to imply \(\texttt{IsSmooth } f\)
- \(\partial f \, x \, dx = f \, dx\) for \(f\) linear
- Problem: Attempting to prove linearity on every function is too expensive!
- Solution: Introduce a transitive closure \(\texttt{IsLinT } f\) of \(\texttt{IsLin } f\) via rules for I, K, S, C, C'.
- The class \(\texttt{IsLin } f\) is used only on elementary functions like \( (\cdot + \cdot), (\cdot * \cdot), (- \cdot)\)
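The identity \(\partial f \, x \, dx = f \, dx\) for linear \(f\) follows in one line, assuming \(\partial\) is the usual directional derivative (the slides do not spell out its definition, so this is an assumption):

```latex
\[
\partial f \, x \, dx
  = \lim_{h \to 0} \frac{f(x + h \, dx) - f(x)}{h}
  = \lim_{h \to 0} \frac{h \, f(dx)}{h}
  = f \, dx
\]
```

This is why knowing that \(f\) is linear lets the simplifier skip differentiation entirely, and why it pays to infer linearity cheaply.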
Multiple Arguments
example (f : X → Y → Z) [IsSmoothNT 2 f]
: differential (λ (x,y) => f x y)
=
λ (x,y) (dx,dy) =>
∂ (λ x' => f x' y) x dx
+
∂ (λ y' => f x y') y dy
:= by symdiff; done
- Proving \(\texttt{IsSmoothT (λ x => λ y ⟿ f x y)}\) can be difficult
- This can't be done with linearity. We have special variants of I,K,S,C,C' rules for \(\texttt{IsLinNT 2 f}\), \(\texttt{IsLinNT 3 f}\), ...
example (f : X → Y → Z)
[∀ x, IsSmoothT (λ y => f x y)] [IsSmoothT (λ x => λ y ⟿ f x y)]
: IsSmoothNT 2 f := ...
Smoothness w.r.t. \((x,y)\) can be reduced to smoothness w.r.t. \(y\) and \(x\) (with values in \(Y ⟿ Z\))
Differentiation w.r.t. \((x,y)\) can be reduced to differentiation w.r.t. \(x\) and \(y\)
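For instance, with the illustrative choice \(f \, x \, y = x \cdot y\) on \(\mathbb{R}\) (an assumption, not an example from the slides), the reduction to the two partial differentials gives the familiar product rule:

```latex
\[
\partial \big(\lambda (x,y) \Rightarrow x \cdot y\big) \, (x,y) \, (dx,dy)
  = \underbrace{dx \cdot y}_{\partial (\lambda x' \Rightarrow x' \cdot y) \, x \, dx}
  + \underbrace{x \cdot dy}_{\partial (\lambda y' \Rightarrow x \cdot y') \, y \, dy}
\]
```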
Unification Problems
Applying the composition rule is difficult!
instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmoothT f] [IsSmoothT g]
: IsSmoothT (λ x => f (g x)) := ...
Adding a trailing argument trips up unification
example (f' : Y → α → Z) (g : X → Y) (a : α) [IsSmoothT (fun y => f' y a)] [IsSmoothT g]
: IsSmoothT (λ x => f' (g x) a) := by try infer_instance -- fails to solve
apply comp_is_smooth (fun y => f' y a) g
unif_hint (f : Y → Z) (f' : Y → α → Z) (g : X → Y) (a : α) where
f =?= λ y => f' y a
|-
IsSmoothT (λ x => f (g x)) =?= IsSmoothT (λ x => f' (g x) a)
Solved by unification hint
It is not over ... and it goes on and on ...
example (f : Y → α → β → Z) (g : X → Y) (a : α) (b : β)
[IsSmoothT (fun y => f y a b)] [IsSmoothT g]
: IsSmoothT (λ x => f (g x) a b) := by infer_instance -- still fails!
Simp Guard
[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x =>
f (g x) a ==> fun z => SciLean.adjoint (fun x => g x) (SciLean.adjoint (fun x => f x a) z)
[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x =>
f x a ==> fun z => SciLean.adjoint (fun x => x) (SciLean.adjoint (fun x => f x a) z)
[Meta.Tactic.simp.rewrite] @SciLean.comp.arg_x.adj_simp:100, SciLean.adjoint fun x =>
f x a ==> fun z => SciLean.adjoint (fun x => x) (SciLean.adjoint (fun x => f x a) z)
...
example (...) :
(λ x => f (g x) a)†
=
λ x' => g† ((λ x => f x a)† x')
:= by simp -- infinite loop
unif_hint (...) where
f =?= λ x => f' x a
|-
(λ x => f (g x))† =?= (λ x => f' (g x) a)†
A unification hint can cause an infinite loop
Simp Guard
@[simp, simp_guard g (λ x => x)]
theorem comp.arg_x.adj_simp
(f : Y → Z) [HasAdjointT f]
(g : X → Y) [HasAdjointT g]
: (λ x => f (g x))† = λ z => g† (f† z) := ...
Solution: \( \texttt{simp\_guard g (λ x => x)} \)
- do not apply simp rule if \(\texttt{g}\) is equal to \(\texttt{λ x => x}\)
Problematic Adjoints
Creating constant arrays is wasteful!
Creating almost everywhere zero arrays is wasteful!
Problematic Adjoints and Unification
Matrix transposition
example {n m} (A : Fin n → Fin m → ℝ) :
(λ (x : Fin m → ℝ) => λ i => ∑ j, A i j * x j)†
=
(λ y => ∑ j i', λ i => [[i'=i]] * A j i' * y j) := by symdiff; done
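The expected answer can be derived by hand from the defining property of the adjoint, \(\langle A x, y \rangle = \langle x, A^\dagger y \rangle\). This short derivation assumes the standard inner product on \(\texttt{Fin n} \rightarrow \mathbb{R}\):

```latex
\[
\langle A x, y \rangle
  = \sum_i \Big( \sum_j A_{ij} \, x_j \Big) y_i
  = \sum_j x_j \Big( \sum_i A_{ij} \, y_i \Big)
  = \langle x, A^\dagger y \rangle
  \quad\Longrightarrow\quad
  (A^\dagger y)_j = \sum_i A_{ij} \, y_i
\]
```

The result is the transposed matrix applied to \(y\), without any Kronecker-delta sums.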
Ok, this is somewhat hard. Let's add a new simp rule:
@[simp]
theorem adjoint_sum_eval
(f : ι → κ → X → Y) [∀ i j, HasAdjointT (f i j)]
: (λ (x : κ → X) => λ i => ∑ j, (f i j) (x j))†
=
λ y => λ j => ∑ i, (f i j)† (y i) := ...
Problematic Adjoints and Unification
Now we have:
Oops
Let's add a unification hint:
unif_hint
(f? : ι → κ → X → Y)
(f : ι → κ → X → α → Y) (g : ι → κ → α)
where
f? =?= λ i j x => f i j x (g i j)
|-
(λ (x : κ → X) => λ i => ∑ j, (f? i j) (x j))†
=?=
(λ (x : κ → X) => λ i => ∑ j, f i j (x j) (g i j))†
Problematic Adjoints and Unification
Now we have:
Oops
Let's add a unification hint:
unif_hint
(f? : ι → κ → X → Y)
(f : ι → κ → W → α → Y) (g : ι → κ → α) (h : ι → κ → X → W)
where
f? =?= λ i j x => f i j (h i j x) (g i j)
|-
(λ (x : κ → X) => λ i => ∑ j, (f? i j) (x j))†
=?=
(λ (x : κ → X) => λ i => ∑ j, f i j (h i j (x j)) (g i j))†
Problematic Adjoints and Unification
Now we have:
I give up
Calculus of Variations
Adjunction on Semi-Hilbert Spaces
Adjunction of \(A : (\mathbb{N}\rightarrow \mathbb{R}) \rightarrow (\mathbb{N}\rightarrow \mathbb{R})\)
Sensible only if the matrix has only finitely many nonzero elements in every column and row
Adjunction of \(\frac{d}{dx} :(\mathbb{R}⟿ \mathbb{R}) \rightarrow (\mathbb{R}⟿ \mathbb{R})\)
True only for \(f\) and \(g\) compactly supported and when integrating over a large enough domain.
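The compact-support condition is exactly what makes the boundary term vanish in integration by parts. As a sketch, assuming the \(L^2\) inner product \(\langle f, g \rangle = \int f \, g \, dx\) over an interval \([a,b]\) that contains the supports of \(f\) and \(g\):

```latex
\[
\Big\langle \frac{d}{dx} f, \, g \Big\rangle
  = \int_a^b f'(x) \, g(x) \, dx
  = \underbrace{\big[ f \, g \big]_a^b}_{=\,0}
    - \int_a^b f(x) \, g'(x) \, dx
  = \Big\langle f, \, -\frac{d}{dx} g \Big\rangle
  \quad\Longrightarrow\quad
  \Big( \frac{d}{dx} \Big)^\dagger = -\frac{d}{dx}
\]
```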
Adjoint Failure of K, C and eval rules
Adjoint failure of K:
Adjoint failure of eval:
As a consequence we do not have C, C' (can't take an adjoint of the green arrow)
but we still have the product rule
Two Main Goals for Adjunction
example {n k} (x : Fin n → ℝ)
: (λ (w : Fin k → ℝ) => λ (i : Fin n) => ∑ j, w j * x (i + j) )†
=
λ (y : Fin n → ℝ) => λ (j : Fin k) => ∑ i, x i * y (i - j) := ...
Convolution:
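The stated adjoint can be verified by swapping the order of summation and re-indexing. This sketch assumes the index arithmetic \(i + j\) and \(i - j\) is cyclic (modulo \(n\)), so that \(i \mapsto i + j\) is a bijection on the index set, and it uses the standard inner products:

```latex
\[
\Big\langle \lambda i \Rightarrow \sum_j w_j \, x_{i+j}, \; y \Big\rangle
  = \sum_i \sum_j w_j \, x_{i+j} \, y_i
  = \sum_j w_j \sum_{i'} x_{i'} \, y_{i'-j}
  = \Big\langle w, \; \lambda j \Rightarrow \sum_i x_i \, y_{i-j} \Big\rangle
\]
```

The middle step substitutes \(i' = i + j\).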
Mixing \(f(x)\) and \(f'(x)\)
example (p q : ℝ ⟿ ℝ)
: (λ (f : ℝ ⟿ ℝ) => λ x ⟿ p x * f x + q x * ⅆ f x)†
=
(λ (f : ℝ ⟿ ℝ) => λ x ⟿ p x * f x - ⅆ (x':=x), q x' * f x') := ...
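This identity is again integration by parts applied to the derivative term. As a sketch under the same assumptions as for \(\frac{d}{dx}\) (the \(L^2\) inner product and vanishing boundary terms), with \(g\) denoting the argument of the adjoint:

```latex
\[
\langle p \, f + q \, f', \; g \rangle
  = \int p \, f \, g \, dx + \int q \, f' \, g \, dx
  = \int f \, (p \, g) \, dx - \int f \, (q \, g)' \, dx
  = \langle f, \; p \, g - (q \, g)' \rangle
\]
```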
Manifold Like Types
Types like \(\mathbb{R}^+, \texttt{Array X}, \texttt{X} \oplus \texttt{Y}, \{ x : X, \|x\| = 1 \}, N \leftrightsquigarrow M\) are not vector spaces!
We need tangent spaces:
- \(\mathcal{T}_a (\texttt{Array X}) = \{ \texttt{b : Array X, b.size = a.size}\}\)
- \(\mathcal{T}_{x} (\texttt{X} \oplus \texttt{Y}) = \texttt{X}\) \(\mathcal{T}_{y} (\texttt{X} \oplus \texttt{Y}) = \texttt{Y}\)
- \(\mathcal{T}_x \{ x : X, \|x\| = 1 \} = \{ v : X, v \perp x \}\)
- \(\mathcal{T}_f (N ⟿ M) = (x : N) \rightsquigarrow \mathcal{T}_{f(x)} M\)
What is a good mathematical semantics? Diffeological spaces?
How to Handle Imperative Code?
∂ λ (x₀ : ℝ^{n}) => Id.run do
let mut x := x₀
for i in [0:10] do
x := x + Δt * f x
x
λ (x₀ dx₀: ℝ^{n}) => Id.run do
let mut x := x₀
let mut dx := dx₀
for i in [0:10] do
(x,dx) := (x,dx) + Δt * 𝒯 f x dx
dx
?
λ (x₀ dx₀: ℝ^{n}) => Id.run do
let mut x := x₀
let mut dx := dx₀
for i in [0:10] do
dx := dx + Δt * ∂ f x dx
x := x + Δt * f x
dx
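The last variant can be run in plain Lean 4 to see why the order of the two updates matters. This is a toy sketch on `Float` with a scalar state. The name `eulerFwd` is hypothetical, and the differential `df` is passed in by hand instead of being produced by `∂`.

```lean
/-- Forward-mode explicit Euler: propagates the state `x` and its tangent `dx`
    through ten steps of `x := x + dt * f x`. -/
def eulerFwd (f : Float → Float) (df : Float → Float → Float)
    (dt x0 dx0 : Float) : Float × Float := Id.run do
  let mut x := x0
  let mut dx := dx0
  for _ in [0:10] do
    -- update the tangent first, while `x` still holds the state of this step
    dx := dx + dt * df x dx
    x := x + dt * f x
  return (x, dx)

-- Example with `f x = -x`, whose differential is `fun _ dx => -dx`.
#eval eulerFwd (fun x => -x) (fun _ dx => -dx) 0.1 1.0 1.0
```

For this linear `f` both components are multiplied by `1 - dt` in every step, so the two returned numbers should agree up to rounding. Swapping the two assignments would evaluate `df` at the already updated `x`, which is a different (and wrong) derivative for nonlinear `f`.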
Differentiating Monadic Code
(f∘g) := λ x =>
let (y, dg) := g x
let (z, df) := f y
(z, df∘dg)
?
(f∘g) := λ x => do let y ← g x; f y
ℱ f x = (f x, λ dx => ∂ f x dx)
pure
pure
(f∘g) := λ x => do
let (y,dg) ← g x
let (z,df) ← f y
let dfg := λ dx => do
let dy ← dg dx
df dy
pure (z, dfg)
Maybe Kan extension?
Local Differentiability and Reduced Regularity
Non-smooth functions that we want to differentiate:
Potential approaches:
- Smooth approximation
- \( \|x\| \approx \|x\|_\epsilon = \sqrt{ \|x\|^2 + \epsilon^2 }\)
- \( \frac1x \approx \frac{x}{\|x\|^2_\epsilon} \)
- Local smoothness, new predicate: \(\texttt{IsSmoothAt} f \, \, x\)
- type class inference would need to automatically prove for example that \(x \neq 0\)
- Work in a category with maps \(\text{Lip}\), \(\text{C}^k\) or \(\text{C}^{k,\alpha}\).
- Are these categories Cartesian closed?
- Does some form of chain rule hold for Lipschitz functions?
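For the smooth-approximation approach, the benefit is visible in the differential of the regularized norm. As a sketch, assuming \(X\) is a real inner-product space and \(\epsilon > 0\):

```latex
\[
\partial \big( \|\cdot\|_\epsilon \big) \, x \, dx
  = \frac{\langle x, dx \rangle}{\sqrt{\|x\|^2 + \epsilon^2}}
  = \frac{\langle x, dx \rangle}{\|x\|_\epsilon}
\]
```

The denominator is bounded below by \(\epsilon\), so the expression is defined at \(x = 0\), where the differential of the exact norm \(\|x\|\) does not exist.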
By lecopivo