Automatic Differentiation in Lean
Tomáš Skřivan
Carnegie Mellon University
9.1.2024
What is Automatic Differentiation?
- automatic differentiation (AD) synthesizes the derivative of a program; for example, AD turns foo into foo':
def foo (x : Float) : Float := x^2
def foo' (x : Float) : Float := 2*x
- applications in machine learning, optimization, and scientific computing
- SciLean - library for scientific computing
- https://github.com/lecopivo/SciLean
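As a rough illustration of what AD does mechanically, the foo/foo' pair above can be reproduced with forward-mode dual numbers. This is a minimal sketch in plain Lean 4 and not SciLean's implementation; the names Dual, fooDual and fooDeriv are hypothetical and introduced only for this example.

```lean
/-- A dual number: a value together with its derivative (tangent). Hypothetical helper type. -/
structure Dual where
  val : Float
  der : Float

instance : Mul Dual where
  -- product rule: (a * b)' = a' * b + a * b'
  mul a b := ⟨a.val * b.val, a.der * b.val + a.val * b.der⟩

/-- `foo` re-expressed over dual numbers: x ↦ x * x. -/
def fooDual (x : Dual) : Dual := x * x

/-- Derivative of `foo` at `x`, obtained by seeding the tangent with 1. -/
def fooDeriv (x : Float) : Float := (fooDual ⟨x, 1⟩).der

#eval fooDeriv 3.0  -- should agree with foo' 3.0 = 2 * 3.0
```

The tangent component carries the derivative alongside the value, so no symbolic expression for foo' is ever built.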
Why Lean for Automatic Differentiation?
- in scientific computing we work a lot with approximations and they do not interact well with automatic differentiation
- traditional AD systems mostly ignore the problem of differentiability
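def fast_sin (x : Float) : Float :=
  if x == 0 then
    0
  else
    Float.sin x
def fast_sin' (x : Float) : Float :=
  if x == 0 then
    0
  else
    Float.cos x
- AD turns fast_sin into fast_sin', which is incorrect: the derivative of sin at x = 0 is cos 0 = 1, but fast_sin' returns 0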
- nice overview article: Understanding Automatic Differentiation Pitfalls, 2023, arXiv:2305.07546
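The pitfall can be observed directly with a branch-wise forward-mode sketch. This assumes plain Lean 4 with a hypothetical Dual type and helper names (sinDual, fastSinDual); it only mimics how a traditional AD system differentiates each branch separately.

```lean
/-- Value and tangent pair. Hypothetical helper type for this sketch. -/
structure Dual where
  val : Float
  der : Float

/-- Forward-mode rule for sin: (sin x)' = cos x * x'. -/
def sinDual (x : Dual) : Dual := ⟨Float.sin x.val, Float.cos x.val * x.der⟩

/-- `fast_sin` differentiated branch by branch, as a traditional AD system would do. -/
def fastSinDual (x : Dual) : Dual :=
  if x.val == 0 then
    ⟨0, 0⟩        -- the constant branch has zero tangent
  else
    sinDual x

#eval (fastSinDual ⟨0, 1⟩).der  -- derivative reported by branch-wise AD at x = 0
#eval Float.cos 0               -- true derivative of sin at x = 0
```

At x = 0 the first expression takes the constant branch and reports a zero derivative, while the second shows that the mathematically correct value is cos 0 = 1.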
Why Lean for Automatic Differentiation?
- in Lean we can work with ideal mathematical objects and treat any code symbolically, thus we can mix techniques from automatic and symbolic differentiation
noncomputable
def odeSolve (f : X → X) (t : ℝ) (x₀ : X) : X :=
  if h : ∃ x : ℝ → X,
      (∀ t', deriv x t' = f (x t'))
      ∧
      x 0 = x₀
  then Classical.choose h t
  else 0
theorem odeSolve_deriv
  (...)
  : let x := fun t' => odeSolve f t' x₀
    deriv x t
    =
    f (x t) := ...
theorem odeSolve_approx
  (...)
  : odeSolve f t x₀
    =
    limit n → ∞,
      let Δt := t/n
      let mut x := x₀
      for i in [0:n] do
        x := x + Δt * f x
      x
\begin{align*}
x'(t) &= f(x(t)) \\
x(0) &= x_0
\end{align*}
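The loop inside odeSolve_approx is the forward Euler method. A minimal executable sketch of that loop for a fixed n, assuming a scalar state of type Float instead of the abstract X (eulerSolve is a hypothetical name, not part of SciLean):

```lean
/-- Forward Euler with `n` steps for x' = f x, x(0) = x₀, evaluated at time `t`. -/
def eulerSolve (f : Float → Float) (t : Float) (x₀ : Float) (n : Nat) : Float := Id.run do
  let Δt := t / n.toFloat
  let mut x := x₀
  for _ in [0:n] do
    x := x + Δt * f x
  return x

-- x' = x with x(0) = 1 has the exact solution exp t, so this approaches e as n grows
#eval eulerSolve (fun x => x) 1.0 1.0 1000
```

The theorem states that the ideal object odeSolve is the limit of this computation, which is what lets the specification and the approximation coexist in one language.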
Bigger goal: Lean for Scientific Computing
m_i \ddot x_i = \sum_j G \frac{m_i m_j}{r_{ij}^2}\hat r_{ij}
\begin{align*}
v_i^{n+1} &= v_i^n + \Delta t \, \sum_j F(x_i^n,x_j^n) \\
x_i^{n+1} &= x_i^n + \Delta t \, v_i^{n+1}
\end{align*}
def update (x v : Array Vec3) :=
  let v := v + Δt * force x
  let x := x + Δt * v
  (x, v)
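The update above depends on Vec3, force and Δt, which the slide leaves undefined. A scalar version that runs as-is, assuming a single particle in one dimension and a hypothetical name update1D, might look like this:

```lean
/-- One time step for a single 1D particle: velocity first, then position. -/
def update1D (Δt : Float) (force : Float → Float) (x v : Float) : Float × Float :=
  let v := v + Δt * force x   -- new velocity from the force at the old position
  let x := x + Δt * v         -- new position from the new velocity
  (x, v)

-- spring force F(x) = -x, starting at x = 1, v = 0 with Δt = 0.1
#eval update1D 0.1 (fun x => -x) 1.0 0.0  -- x ≈ 0.99, v ≈ -0.1
```

Updating the velocity before the position matches the discretized equations above, where the position update uses the already updated velocity.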
Back to Automatic Differentiation
How to implement AD in Lean?
FTrans: General Function Transformation
- ftrans - tactic for generic function transformation \(T\)
T (f \circ g) = T f \circ T g
- differentiation
(f \circ g)'(x) = f'(g(x)) \circ g'(x)
- adjoint
(f \circ g)^\dagger = g^\dagger \circ f^\dagger
- forward/reverse mode AD, vectorization, isomorphism, ...
example (x : ℝ)
: fderiv ℝ (fun x : ℝ => x^2) x
=
fun dx =>L[ℝ] 2 * dx * x := by ftrans
example
: (adjoint (fun x : ℂ =>L[ℂ] I*x))
=
fun y =>L[ℂ] -I * y := by ftrans
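A worked instance of the two composition rules, assuming g(x) = x^2 and f(y) = \sin y on the reals (both functions are chosen here for illustration and do not come from the slides):

\begin{align*}
g'(x)\,dx &= 2x\,dx, \qquad f'(y)\,dy = \cos(y)\,dy \\
(f \circ g)'(x)\,dx &= f'(g(x))\,\bigl(g'(x)\,dx\bigr) = \cos(x^2)\,2x\,dx \\
\bigl((f \circ g)'(x)\bigr)^\dagger dz &= g'(x)^\dagger\bigl(f'(g(x))^\dagger\,dz\bigr) = 2x\,\cos(x^2)\,dz
\end{align*}

The derivative pushes dx through g and then f, while the adjoint pulls dz back through f and then g. This reversal of order is the difference between forward and reverse mode.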
FTrans: General Function Transformation
theorem fderiv_comp
  {g : E → F} {f : F → G}
  (hg : Differentiable K g) (hf : Differentiable K f)
  : fderiv K (fun x => (f (g x)))
    =
    fun x => fun dx =>L[K] fderiv K f (g x) (fderiv K g x dx) := ...
- three steps to define new function transformation T
- basic lambda calculus rules
theorem fderiv_id
: (fderiv K fun x : X => x) = fun _ => fun dx =>L[K] dx := ...
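T (fun x => x) = ...                      -- identity
T (fun x => y) = ...                      -- constant
T (fun f => f i) = ...                    -- projection (evaluation at i)
T (fun x => f (g x)) = ...                -- composition
T (fun x => let y := g x; f x y) = ...    -- let binding
T (fun x i => f i x) = ...                -- pi (function-valued result)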
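For orientation, here is what the six rules look like when T is the derivative. This is a sketch written with \partial f(x)\,dx for the derivative of f at x applied to dx, assuming every function involved is differentiable and the index i ranges over a finite set; it is not a transcription of SciLean's theorem statements.

\begin{align*}
\partial(x \mapsto x)(x)\,dx &= dx \\
\partial(x \mapsto y)(x)\,dx &= 0 \\
\partial(f \mapsto f\,i)(f)\,df &= df\,i \\
\partial(x \mapsto f(g\,x))(x)\,dx &= \partial f(g\,x)\,\bigl(\partial g(x)\,dx\bigr) \\
\partial(x \mapsto \mathrm{let}\ y := g\,x;\ f\,x\,y)(x)\,dx &= \partial \tilde f(x, g\,x)\,\bigl(dx,\ \partial g(x)\,dx\bigr), \quad \tilde f(x,y) = f\,x\,y \\
\partial(x \mapsto (i \mapsto f\,i\,x))(x)\,dx &= \bigl(i \mapsto \partial(f\,i)(x)\,dx\bigr)
\end{align*}

Every lambda term is built from these six shapes, so once they are proved for a given T the tactic only needs per-function rules such as the one for addition.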
- register new function transformation
- mark rules with @[ftrans]
@[ftrans]
theorem fderiv_add
(f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
: (fderiv K fun x => f x + g x)
=
fun x => fderiv K f x + fderiv K g x := ...
- example: SciLean.Core.FunctionTransformations.FDeriv.lean
FProp: Proving General Function Property
P f \rightarrow P g \rightarrow P (f \circ g)
- fprop - tactic for proving function property \(P\)
example : Continuous fun x : ℝ => x^2 := by fprop
example : Differentiable ℝ fun x : ℝ => x^2 := by fprop
example : Continuous fun x : ℝ =>
let x1 := x * x
let x2 := x1 * x1
let x3 := x2 * x2
let x4 := x3 * x3
let x5 := x4 * x4
x5 := by fprop
- differentiability, linearity, continuity, ...
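For comparison, the composition rule and the first example can be stated against Mathlib alone. This is a sketch assuming a project that depends on Mathlib; it uses Mathlib's Continuous.comp and its continuity tactic, not fprop.

```lean
import Mathlib

-- The composition rule P f → P g → P (f ∘ g), instantiated with P = Continuous.
example (f g : ℝ → ℝ) (hf : Continuous f) (hg : Continuous g) :
    Continuous (f ∘ g) := hf.comp hg

-- The same goal as the first fprop example, discharged by Mathlib's `continuity`.
example : Continuous fun x : ℝ => x ^ 2 := by continuity
```

A tactic like fprop automates the repeated application of rules of the first kind across a whole lambda term.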
- similar to mathlib's continuity tactic, but the focus is on:
  - integration with ftrans, which caches the results of fprop
  - speed: continuity ~ 1 s, fprop ~ 3 ms
example
  : Continuous
    fun x : ℝ =>
      let x1 := x + x
      let x2 := x1 + x1
      let x3 := x2 + x2
      let x4 := x3 + x3
      let x5 := x4 + x4
      let x6 := x5 + x5
      x6 := by ...
- functions with multiple arguments
example (f : ℝ → ℝ → ℝ) (hf : Continuous (fun (x,y) => f x y))
: Continuous (λ x => f x x) := by fprop
example {n} (f : ℝ → (Fin n → ℝ)) (g : ℝ → ℝ)
(hf : Continuous f) (hg : Continuous g)
: Continuous (λ x => (f (g x))) := by fprop
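The first example above reduces to the composition rule once the two-argument function is uncurried. A short derivation, with \tilde f and \Delta as notation introduced here for illustration:

\begin{align*}
\tilde f(x, y) &= f\,x\,y, \qquad \Delta(x) = (x, x) \\
(x \mapsto f\,x\,x) &= \tilde f \circ \Delta
\end{align*}

The hypothesis hf states that \tilde f is continuous, \Delta is continuous because both of its components are the identity, and so the composition rule applies.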
FProp: Proving General Function Property
- three steps to define new function property P
- basic lambda calculus rules
- register new function property
- mark rules with @[fprop]
@[fprop]
theorem Differentiable.add
(f g : X → Y) (hf : Differentiable R f) (hg : Differentiable R g)
: Differentiable R fun x => f x + g x := ...
Demo
Conclusion
- SciLean - library for scientific computing https://github.com/lecopivo/SciLean
- forward and reverse mode AD implemented using general function transformation framework
- ftrans - tactic to perform function transformation
- fprop - tactic to prove function properties
- Future work
- documentation
- better user experience
- better integration with mathlib - fderiv depends on norm
- prove all transformation rules
- mutable variables, manifolds/dependent functions, variational calculus, ...