Automatic Differentiation in Lean

Tomáš Skřivan

Carnegie Mellon University

9.1.2024

What is Automatic Differentiation?

  • automatic differentiation (AD) synthesizes derivatives of a program
def foo (x : Float) : Float := x^2
def foo' (x : Float) : Float := 2*x
  • applications in machine learning, optimization, and scientific computing

(Diagram: AD maps foo to foo'.)

  • SciLean - library for scientific computing 
    • https://github.com/lecopivo/SciLean
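  • To make the idea concrete, here is a minimal sketch of forward-mode AD with dual numbers in plain Lean 4 (no Mathlib). The type `Dual`, its instances, and `fooD`/`fooDeriv` are hypothetical illustrations, not SciLean's API. Each number carries its value together with a tangent, and arithmetic propagates tangents by the sum and product rules.
-- Hypothetical dual-number type (illustration only, not SciLean's API)
structure Dual where
  val : Float  -- primal value
  dot : Float  -- tangent, i.e. derivative with respect to the input

-- sum rule
instance : Add Dual := ⟨fun a b => ⟨a.val + b.val, a.dot + b.dot⟩⟩
-- product rule
instance : Mul Dual := ⟨fun a b => ⟨a.val * b.val, a.dot * b.val + a.val * b.dot⟩⟩

-- foo x = x^2, written as x * x so that only `Mul Dual` is needed
def fooD (x : Dual) : Dual := x * x

-- derivative of foo at x: seed the tangent with 1 and read it back
def fooDeriv (x : Float) : Float := (fooD ⟨x, 1⟩).dot

#eval fooDeriv 3.0  -- should give 6, agreeing with foo' 3 = 2 * 3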

Why Lean for Automatic Differentiation?

  • scientific computing relies heavily on approximations, and these do not interact well with automatic differentiation
  • traditional AD systems mostly ignore the problem of differentiability
def fast_sin (x : Float) : Float :=
  if x == 0 then
    0
  else
    Float.sin x
def fast_sin' (x : Float) : Float :=
  if x == 0 then
    0
  else
    Float.cos x

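(Diagram: AD maps fast_sin to fast_sin', which is incorrect: at x = 0 it returns 0, while the true derivative is cos 0 = 1.)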

  • nice overview article: Understanding Automatic Differentiation Pitfalls, arXiv:2305.07546, 2023
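  • The fast_sin pitfall can be reproduced with a small dual-number sketch in plain Lean 4 (`Dual`, `Dual.sin`, and `fastSinD` are hypothetical names). AD differentiates only the branch that is actually executed, and the test on the primal value is invisible to it, so at x = 0 it returns the derivative of the constant 0.
structure Dual where
  val : Float  -- primal value
  dot : Float  -- tangent

-- chain rule for sin: (sin u)' = cos u * u'
def Dual.sin (x : Dual) : Dual := ⟨Float.sin x.val, Float.cos x.val * x.dot⟩

-- fast_sin lifted to dual numbers, branching on the primal value only
def fastSinD (x : Dual) : Dual :=
  if x.val == 0 then ⟨0, 0⟩ else Dual.sin x

#eval (fastSinD ⟨0, 1⟩).dot  -- 0: derivative of the constant branch
#eval (Dual.sin ⟨0, 1⟩).dot  -- 1: the true derivative cos 0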

Why Lean for Automatic Differentiation?

  • in Lean we can work with ideal mathematical objects and treat any code symbolically, so we can mix techniques from automatic and symbolic differentiation
open Classical in
noncomputable
def odeSolve (f : X → X) (t : ℝ) (x₀ : X) : X :=
  if h : ∃ x : ℝ → X,
    (∀ t', deriv x t' = f (x t'))
    ∧
    x 0 = x₀
  then choose h t
  else 0
theorem odeSolve_deriv
  (...)
  : let x := fun t' => odeSolve f t' x₀
    deriv x t
    =
    f (x t) := ...
theorem odeSolve_approx 
  (...)
  : odeSolve f t x₀
    =
    limit n → ∞, 
      let mut Δt := t/n
      let mut x := x₀
      for i in [0:n] do
        x := x + Δt * f x
      x
\begin{align*} x'(t) &= f(x(t)) \\ x(0) &= x_0 \end{align*}
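  • The limit in odeSolve_approx is the explicit Euler method. A computable sketch for scalar Float ODEs (the name eulerSolve and the restriction to Float are assumptions made for illustration) shows what a single finite n of that limit computes:
-- Explicit Euler: n steps of size t/n for x' = f x, x 0 = x₀ (hypothetical helper)
def eulerSolve (f : Float → Float) (t x₀ : Float) (n : Nat) : Float := Id.run do
  let Δt := t / n.toFloat
  let mut x := x₀
  for _ in [0:n] do
    x := x + Δt * f x
  return x

-- x' = x, x 0 = 1 has solution exp t; with n = 1000 this computes (1.001)^1000 ≈ 2.7169
#eval eulerSolve (fun x => x) 1.0 1000

Increasing n moves the result toward exp 1 ≈ 2.71828, which is the limit odeSolve_approx states.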

Bigger goal: Lean for Scientific Computing

m_i \ddot x_i = \sum_{j \neq i} G \frac{m_i m_j}{r_{ij}^2}\hat r_{ij}
\begin{align*} v_i^{n+1} &= v_i^n + \Delta t \, \sum_j F(x_i^n,x_j^n) \\ x_i^{n+1} &= x_i^n + \Delta t \, v_i^{n+1} \end{align*}
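-- semi-implicit Euler step; `Δt`, `force`, and `Vec3` are assumed in scope
def update (x v : Array Vec3) : Array Vec3 × Array Vec3 := Id.run do
  let mut x := x
  let mut v := v
  v := v + Δt * force x
  x := x + Δt * v
  return (x, v)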
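  • The update above is the semi-implicit (symplectic) Euler step: velocities are updated first, and the new velocity then moves the positions. A scalar, computable sketch for one unit-mass particle, assuming a hypothetical spring force `force x = -x` in place of gravity:
-- hypothetical force for illustration: unit spring, F(x) = -x
def force (x : Float) : Float := -x

-- one semi-implicit Euler step
def update (Δt x v : Float) : Float × Float :=
  let v := v + Δt * force x  -- velocity from the old position
  let x := x + Δt * v        -- position from the new velocity
  (x, v)

-- from x = 1, v = 0 with Δt = 0.1: v = -0.1, then x = 1 + 0.1 * (-0.1) ≈ 0.99
#eval update 0.1 1.0 0.0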

Back to Automatic Differentiation

How to implement AD in Lean?

FTrans: General Function Transformation

T (f \circ g) = T f \circ T g
  • ftrans - tactic for a generic function transformation \(T\)
  • differentiation: (f \circ g)'(x) = f'(g(x)) \circ g'(x)
  • adjoint: (f \circ g)^\dagger = g^\dagger \circ f^\dagger
  • forward/reverse mode AD, vectorization, isomorphism, ...
example (x : ℝ)
  : fderiv ℝ (fun x : ℝ => x^2) x
    =
    fun dx =>L[ℝ] 2 * dx * x := by ftrans
example 
  : (adjoint (fun x : ℂ =>L[ℂ] I*x))
    =
    fun y =>L[ℂ] -I * y := by ftrans
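  • Both results can be checked by hand. For the derivative, expand the square and keep the part linear in dx; for the adjoint, assume the inner product on ℂ is ⟨a, b⟩ = \overline{a}\, b (conjugate-linear in the first argument, the mathlib convention):
\begin{align*} (x + dx)^2 &= x^2 + 2\,x\,dx + dx^2 \quad\Rightarrow\quad dx \mapsto 2\,dx\,x \\ \langle I x, y \rangle &= \overline{I x}\, y = -I\,\overline{x}\, y = \overline{x}\,(-I y) = \langle x, -I y \rangle \end{align*}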

FTrans: General Function Transformation

theorem fderiv_comp
  {g : E → F} {f : F → G}
  (hg : Differentiable 𝕜 g) (hf : Differentiable 𝕜 f)
  : fderiv 𝕜 (fun x => f (g x))
      =
      fun x => fun dx =>L[𝕜] fderiv 𝕜 f (g x) (fderiv 𝕜 g x dx) := ...
  • three steps to define a new function transformation T
    • basic lambda calculus rules
theorem fderiv_id 
  : (fderiv K fun x : X => x) = fun _ => fun dx =>L[K] dx := ...
T fun x => x = ...
T fun x => y = ...
T fun f => f i = ...
T fun x => f (g x) = ...
T fun x => let y := g x; f x y = ...
T fun x i => f i x = ...
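  • A toy analogy of these rules in plain Lean 4. The type `Fn` and its functions are hypothetical, and this is not how ftrans works internally (ftrans rewrites Lean terms using theorems such as fderiv_comp). The transformation is defined by one rule per syntactic form, and the composite cases apply it recursively to their parts.
-- Hypothetical first-order language of functions of one variable x
inductive Fn where
  | var                 -- fun x => x
  | const (c : Float)   -- fun x => c
  | add (f g : Fn)      -- fun x => f x + g x
  | mul (f g : Fn)      -- fun x => f x * g x
  | sin (f : Fn)        -- fun x => sin (f x)
  | cos (f : Fn)        -- fun x => cos (f x)

def Fn.eval : Fn → Float → Float
  | .var, x => x
  | .const c, _ => c
  | .add f g, x => f.eval x + g.eval x
  | .mul f g, x => f.eval x * g.eval x
  | .sin f, x => Float.sin (f.eval x)
  | .cos f, x => Float.cos (f.eval x)

-- T = differentiation, one rule per form
def Fn.deriv : Fn → Fn
  | .var => .const 1
  | .const _ => .const 0
  | .add f g => .add f.deriv g.deriv
  | .mul f g => .add (.mul f.deriv g) (.mul f g.deriv)
  | .sin f => .mul (.cos f) f.deriv                        -- chain rule
  | .cos f => .mul (.mul (.const (-1)) (.sin f)) f.deriv   -- chain rule

#eval (Fn.mul .var .var).deriv.eval 3.0  -- should give 6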

FTrans: General Function Transformation

  • three steps to define a new function transformation T
    • basic lambda calculus rules
    • register new function transformation
    • mark rules with @[ftrans]
@[ftrans]
theorem fderiv_add
  (f g : X → Y) (hf : Differentiable K f) (hg : Differentiable K g)
  : (fderiv K fun x => f x + g x)
    =
    fun x => fderiv K f x + fderiv K g x := ...
  • example: SciLean.Core.FunctionTransformations.FDeriv.lean

FProp: Proving General Function Property

P f \rightarrow P g \rightarrow P (f \circ g)
  • fprop - tactic for proving function property \(P\)
example : Continuous fun x : ℝ => x^2 := by fprop
example : Differentiable ℝ fun x : ℝ => x^2 := by fprop
example : Continuous fun x : ℝ =>
        let x1 := x * x
        let x2 := x1 * x1
        let x3 := x2 * x2
        let x4 := x3 * x3
        let x5 := x4 * x4
        x5 := by fprop
  • differentiability, linearity, continuity, ...
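  • For comparison, the rule P f → P g → P (f ∘ g) is what one applies by hand when chaining mathlib's composition lemmas; mathlib's `fun_prop` tactic automates the same search. A sketch, assuming a recent mathlib in which `fun_prop` is available:
import Mathlib

-- by hand: combine continuity of x ↦ x ^ 2 and of sin with the lemma for sums
example : Continuous fun x : ℝ => x ^ 2 + Real.sin x :=
  (continuous_pow 2).add Real.continuous_sin

-- automatically, by decomposing the function and applying registered rules
example : Continuous fun x : ℝ => x ^ 2 + Real.sin x := by fun_prop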

FProp: Proving General Function Property

P f \rightarrow P g \rightarrow P (f \circ g)
  • fprop - tactic for proving function property \(P\)
  • differentiability, linearity, continuity, ...
  • similar to mathlib's continuity, but the focus is on
    • integration with ftrans, which caches the results of fprop
    • speed: continuity ~ 1 s, fprop ~ 3 ms
example 
  : Continuous 
      fun x : ℝ => 
        let x1 := x + x
        let x2 := x1 + x1
        let x3 := x2 + x2
        let x4 := x3 + x3
        let x5 := x4 + x4
        let x6 := x5 + x5
        x6 := by ...

    • functions with multiple arguments
example (f : ℝ → ℝ → ℝ) (hf : Continuous (fun (x,y) => f x y)) 
  : Continuous (λ x => f x x) := by fprop

example {n} (f : ℝ → (Fin n → ℝ)) (g : ℝ → ℝ)
  (hf : Continuous f) (hg : Continuous g) 
  : Continuous (λ x => (f (g x))) := by fprop

FProp: Proving General Function Property

  • three steps to define a new function property P
    • basic lambda calculus rules
    • register new function property
    • mark rules with @[fprop]
@[fprop]
theorem Differentiable.add
  (f g : X → Y) (hf : Differentiable R f) (hg : Differentiable R g)
  : Differentiable R fun x => f x + g x := ...
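  • The same pattern carries over to mathlib's `fun_prop`. Since mathlib already registers Continuous itself, only the last step, marking rules, is needed for a new function. A sketch with a hypothetical function `myFun`, assuming a recent mathlib:
import Mathlib

-- hypothetical user-defined function
noncomputable def myFun (x : ℝ) : ℝ := x ^ 2 + 1

-- register its continuity as a rule for fun_prop
@[fun_prop]
theorem myFun_continuous : Continuous myFun := by
  show Continuous fun x : ℝ => x ^ 2 + 1
  fun_prop

-- the rule is now used automatically when myFun appears inside a composite
example : Continuous fun x : ℝ => myFun (Real.sin x) + x := by fun_prop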

Demo

Conclusion

  • SciLean - library for scientific computing: https://github.com/lecopivo/SciLean
  • forward and reverse mode AD implemented using the general function transformation framework
  • ftrans - tactic to perform function transformations
  • fprop - tactic to prove function properties
  • Future work
    • documentation
    • better user experience
    • better integration with mathlib - mathlib's fderiv depends on a norm
    • prove all transformation rules
    • mutable variables, manifolds/dependent functions, variational calculus, ...