Scientific Computing in Lean 4

Tomáš Skřivan

Carnegie Mellon University

Why Scientific Computing in Lean?

  • scientific computing is full of mathematics + Lean understands mathematics
  • Lean 4 is an efficient general-purpose programming language
  • Lean metaprogramming is very powerful and easy to use
  • mathlib - a big library of formalized mathematics
  • something interesting has to come out of this

SciLean

library for scientific computing in Lean 4

Motivation: Physics Simulation in Computer Graphics

  work done at

Motivation

Sources of frustration:

  • the mathematics is not in the code
    • data often form algebraic structures
  • composability and working with libraries
    • libraries often impose data structures on the input and output
    • lack of well-specified interfaces
  • prototyping
    • mathematical description is often quite easy but implementation is not
  • multiphysics problems
    • very difficult to make two solvers talk to each other
    • usually ends up in extremely complicated monolithic code

This video is from the SIGGRAPH 2022 Technical Paper: ‘Loki: a unified multiphysics simulation framework for production’.

Scientific Computing and Lean

What can Lean offer?

  • the mathematics is not in the code
    • the code is the math and the math is the code
    • interactivity can effectively offer a computer algebra system operating directly on the code
  • composability and working with libraries
    • interfaces can be given at the mathematical level rather than at the data level
    • approximations compose badly
    • interfaces can be specified very precisely
  • prototyping
    • interactively transform specification into code
  • multiphysics problems
    • combining specifications is likely easier than combining implementations

Example: Harmonic Oscillator

\begin{aligned} \dot x &= \frac{\partial H}{\partial p} \\ \dot p &= - \frac{\partial H}{\partial x} \end{aligned}
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ∥p∥² + k/2 * ∥x∥²

approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                             -∇ (x':=x), H m k x' p))
by
  -- Unfold Hamiltonian and compute gradients
  unfold H
  symdiff; symdiff

  -- Apply RK4 method
  rw [odeSolve_fixed_dt runge_kutta4_step]
  approx_limit steps; simp; intro steps';

Goal at the start:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                                  -∇ (x':=x), H m k x' p)))

Goal after unfold H and symdiff:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))

Goal after applying RK4 and approx_limit:
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
    fun t x => (1 / m * x.snd, -(k * x.fst)))

Usage:
(x, p) := solver m k substeps t (x, p) Δt

\texttt{∇ (x':=x), f x'} = \frac{d f}{d x'} \big|_{x'=x}

\(x(t) = \texttt{odeSolve f t₀ x₀ t} \quad\Leftrightarrow\quad \dot x(t) = f(t,x(t)), \quad x(t_0) = x_0\)
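
For comparison, the goal reached after applying RK4 corresponds roughly to the plain Lean 4 program below. It is a minimal sketch that uses only core Lean, with Float standing in for ℝ. The names State, rhs, axpy, rk4Step and solve are hypothetical and are not SciLean's odeSolve_fixed_dt_impl or runge_kutta4_step.

```lean
-- Core Lean 4 only (no SciLean, no Mathlib); Float stands in for ℝ.
abbrev State := Float × Float   -- (x, p)

/-- Right-hand side after differentiating H: ẋ = p / m, ṗ = -k x. -/
def rhs (m k : Float) (s : State) : State :=
  (1 / m * s.2, -(k * s.1))

/-- Computes s + h • d componentwise on pairs. -/
def axpy (h : Float) (d s : State) : State :=
  (s.1 + h * d.1, s.2 + h * d.2)

/-- One classical fourth-order Runge-Kutta step of size `dt`. -/
def rk4Step (f : State → State) (dt : Float) (s : State) : State :=
  let k1 := f s
  let k2 := f (axpy (dt / 2) k1 s)
  let k3 := f (axpy (dt / 2) k2 s)
  let k4 := f (axpy dt k3 s)
  -- s + dt/6 * (k1 + 2 k2 + 2 k3 + k4), accumulated term by term
  let s := axpy (dt / 6) k1 s
  let s := axpy (dt / 3) k2 s
  let s := axpy (dt / 3) k3 s
  axpy (dt / 6) k4 s

/-- Integrate over a time span `T` using `steps` equal RK4 steps. -/
def solve (m k : Float) (steps : Nat) (T : Float) (s₀ : State) : State := Id.run do
  let dt := T / steps.toFloat
  let mut s := s₀
  for _ in [0:steps] do
    s := rk4Step (rhs m k) dt s
  return s

-- With m = k = 1 the exact solution is (cos t, -sin t),
-- so integrating to t = π should land close to (-1, 0).
#eval solve 1.0 1.0 100 3.141592653589793 (1.0, 0.0)
```

Unlike the `approx` definition on the slide, this version carries no statement relating the result to the exact solution of the ODE; it only shows the shape of the code that the tactic script arrives at.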

Two Main Operations of AD: Differential and Adjoint

Differential \(\partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)\)

Adjoint \(\dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X )\)

\partial \, f\, x \, dx = \lim_{h\rightarrow 0} \frac{f (x + h \cdot dx) - f(x)}{h}

[Diagram: \(f\) maps \(x:X\) to \(y:Y\). The differential \(\partial f\) takes \(x:X\) and \(dx:X\) to \(dy:Y\); the adjoint \(f^\dagger\) takes \(y:Y\) back to \(x:X\).]

\forall \, x \, y, \left\langle f(x), y \right\rangle = \left\langle x, f^\dagger (y) \right\rangle
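
As a small worked example of both definitions, take a linear map \(f(x) = A x\) on \(\mathbb{R}^n\) with the standard inner product. The matrix \(A\) is an illustrative choice and does not come from the talk.

```latex
\begin{aligned}
\partial \, f \, x \, dx
  &= \lim_{h\rightarrow 0} \frac{A (x + h \cdot dx) - A x}{h} = A \, dx \\
\left\langle A x, y \right\rangle
  &= \left\langle x, A^T y \right\rangle
  \quad\Rightarrow\quad f^\dagger (y) = A^T y
\end{aligned}
```

For a linear map the differential is the map itself applied to \(dx\), and the adjoint is the transpose.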

Simplifier and Typeclass approach

Differential \(\partial\)

instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : IsSmooth (λ x => f (g x)) := ...

@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ∂ (λ x => f (g x))
    =
    λ x dx => ∂ f (g x) (∂ g x dx) := ...

Adjoint \(\dagger\)

instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : IsLin (λ x => f (g x)) := ...

@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : (λ x => f (g x))†
    =
    λ z => g† (f† z) := ...
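
To see what the two simp rules do on concrete terms, here is a hand-worked instance. The functions are illustrative choices, not taken from the talk: \(\sin\) composed with \(x \mapsto x^2\) for the chain rule, and two matrices \(A\), \(B\) for the adjoint of a composition.

```latex
\begin{aligned}
\partial \, (\lambda x \Rightarrow \sin (x^2)) \, x \, dx
  &= \partial \sin \, (x^2) \, \bigl(\partial \, (\lambda x \Rightarrow x^2) \, x \, dx\bigr)
   = \cos (x^2) \cdot (2 x \cdot dx) \\
(\lambda x \Rightarrow A (B x))^\dagger
  &= \lambda z \Rightarrow B^T (A^T z)
\end{aligned}
```

The second line is the familiar identity \((A B)^T = B^T A^T\): adjoints compose in the reverse order.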

Other operators

  • derivative \(f'\)  for \(f : \mathbb{R} \rightarrow X\)
    • \(f'(x) = \texttt{∂ f x 1}\)   
  • gradient \(\nabla f\)  for \(f : X \rightarrow \mathbb{R}\)
    • \(\nabla f(x) = \texttt{(∂ f x)† 1}\)
  • adjoint differential \(\partial^\dagger\)
    • \(\texttt{∂† f x dy = (∂ f x)† dy}\)
  • tangent map \(\mathcal{T}\) - forward mode AD
    • \(\mathcal{T}\texttt{ f x dx = (f x, ∂ f x dx)}\)
  • reverse differential \(\mathcal{R}\) - reverse mode AD
    •  \(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
  • complexification \(f^c\) 
    • \(f^c : \mathbb{C} \rightarrow \mathbb{C}\)  for  \(f : \mathbb{R} \rightarrow \mathbb{R}\)
  • inverse \(f^{-1}\)
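
The gradient and tangent map formulas can be checked by hand on \(f(x) = \langle x, x \rangle\), an illustrative choice that is not from the talk. Here \(s\) is a real number and the inner product on \(\mathbb{R}\) is ordinary multiplication.

```latex
\begin{aligned}
\partial \, f \, x \, dx &= 2 \left\langle x, dx \right\rangle \\
\left\langle \partial \, f \, x \, dx, \, s \right\rangle
  &= \left\langle dx, \, 2 s \cdot x \right\rangle
  \quad\Rightarrow\quad (\partial \, f \, x)^\dagger \, s = 2 s \cdot x \\
\nabla f(x) &= (\partial \, f \, x)^\dagger \, 1 = 2 x \\
\mathcal{T} \, f \, x \, dx &= \bigl(\left\langle x, x \right\rangle, \; 2 \left\langle x, dx \right\rangle\bigr)
\end{aligned}
```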

Forward Mode AD and Functoriality

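[Diagram: the composition \(x:X \to g \to f \to z:Z\). Applying \(\partial\) to the whole composition gives \(\partial (f\circ g)\), which takes \(x:X\), \(dx:X\) to \(dz:Z\). The "?" marks the question of how to assemble it from \(\partial g\) (inputs \(x:X\), \(dx:X\); output \(dy:Y\)) and \(\partial f\) (inputs \(y:Y\), \(dy:Y\); output \(dz:Z\)).]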

Forward Mode AD and Functoriality

𝒯 (λ x => f (g x)) 
= 
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)
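
The composition rule can be tried out on scalar functions in core Lean 4. This is a minimal sketch with Float in place of ℝ. The names TangentMap, tSin, tSquare and tComp are hypothetical, and the tangent maps of the two primitives are written by hand rather than derived by SciLean.

```lean
/-- Tangent map of a scalar function: point and direction in,
    value and directional derivative out. -/
abbrev TangentMap := Float → Float → Float × Float

def tSin : TangentMap := fun x dx => (Float.sin x, Float.cos x * dx)
def tSquare : TangentMap := fun x dx => (x * x, 2 * x * dx)

/-- Chains 𝒯 g and 𝒯 f exactly as in the rule for 𝒯 (λ x => f (g x)). -/
def tComp (tf tg : TangentMap) : TangentMap := fun x dx =>
  let (y, dy) := tg x dx
  let (z, dz) := tf y dy
  (z, dz)

-- Value and derivative of sin (x²) at x = 1.5 in direction dx = 1 ...
#eval tComp tSin tSquare 1.5 1.0
-- ... compared with the closed form (sin (x²), cos (x²) · 2x).
#eval (Float.sin (1.5 * 1.5), Float.cos (1.5 * 1.5) * (2 * 1.5))
```

Because each tangent map returns the value together with the derivative, the second stage receives everything it needs from the first, which is what makes the rule compositional.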

[Diagram: \(\mathcal{T} g\) takes \(x:X\), \(dx:X\) to \(y:Y\), \(dy:Y\), and \(\mathcal{T} f\) takes these to \(z:Z\), \(dz:Z\). Chained together they give \(\mathcal{T} (f\circ g)\), which takes \(x:X\), \(dx:X\) to \(z:Z\), \(dz:Z\).]

Reverse Mode AD and Functoriality

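[Diagram: the composition \(x:X \to g \to f \to z:Z\). Applying \(\partial^\dagger\) gives \(\partial^\dagger (f\circ g)\), which takes \(x:X\), \(dz:Z\) to \(dx:X\). The "?" marks the question of how to assemble it from \(\partial^\dagger f\) (inputs \(y:Y\), \(dz:Z\); output \(dy:Y\)) and \(\partial^\dagger g\) (inputs \(x:X\), \(dy:Y\); output \(dx:X\)).]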

Reverse Mode AD and Functoriality

[Diagram: \(\mathcal{R} (f\circ g)\) takes \(x:X\) to \(z:Z\) together with the function \(\partial^\dagger (f\circ g) \, x : Z \rightarrow X\).]

ℛ (λ x => f (g x)) 
    = 
    λ x => 
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))

[Diagram: \(\mathcal{R} g\) takes \(x:X\) to \(y:Y\) and \(\partial^\dagger g \, x : Y \rightarrow X\); \(\mathcal{R} f\) takes \(y:Y\) to \(z:Z\) and \(\partial^\dagger f \, y : Z \rightarrow Y\).]

 \(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
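
The same experiment works for the reverse differential. Again this is a minimal sketch in core Lean 4 with Float in place of ℝ. The names RevDiff, rSin, rSquare, rComp and demo are hypothetical, and the pullbacks of the primitives are written by hand.

```lean
/-- Reverse differential of a scalar function: the value together with
    a pullback mapping an output sensitivity to an input sensitivity. -/
abbrev RevDiff := Float → Float × (Float → Float)

def rSin : RevDiff := fun x => (Float.sin x, fun dy => Float.cos x * dy)
def rSquare : RevDiff := fun x => (x * x, fun dy => 2 * x * dy)

/-- Forward pass runs g then f; the pullbacks are chained in reverse order,
    as in the rule for ℛ (λ x => f (g x)). -/
def rComp (rf rg : RevDiff) : RevDiff := fun x =>
  let (y, dg') := rg x
  let (z, df') := rf y
  (z, fun dz => dg' (df' dz))

/-- Value and derivative of sin (x²) at x = 1.5, seeding the pullback with dz = 1. -/
def demo : Float × Float :=
  let (z, pull) := rComp rSin rSquare 1.5
  (z, pull 1.0)

#eval demo
-- Closed form for comparison: (sin (x²), cos (x²) · 2x).
#eval (Float.sin (1.5 * 1.5), Float.cos (1.5 * 1.5) * (2 * 1.5))
```

The closures capture the intermediate values \(x\) and \(y\) from the forward pass, so the backward pass needs no recomputation.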

Example: Ballistic motion

\begin{aligned} \dot x &= v \\ \dot v &= g - \gamma f(v) \\ x(0) &= 0 \qquad v(0) = v_0 \end{aligned}

For \(T\in \mathbb{R}, x_T \in \mathbb{R}^2\) find \(v_0\) such that \(x(T) = x_T\)

We will solve the necessary condition

   \(v_0 = \argmin_{v_0^*} \|x(T) - x_T\|^2 \qquad \text{where } \dot x(0) = v_0^*\)

Problem:

To compute \(\frac{\partial x(T)}{\partial v_0}\) we have to solve the adjoint problem

\begin{aligned} \dot x^* &= 0 \\ \dot v^* &= x^* - \gamma f'(v) v^* \\ x^*_T(T) &= x(T) - x_T \qquad v^*(T) = 0 \end{aligned}
\frac{\partial x(T)}{\partial v_0} = 2 v^*(0)
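
The link between the objective and the adjoint solve is the chain rule. Writing \(L(v_0) = \|x(T) - x_T\|^2\), where the name \(L\) is introduced here only for illustration, a short derivation in the notation of the earlier slides is:

```latex
\begin{aligned}
\partial \, L \, v_0 \, dv_0
  &= 2 \left\langle x(T) - x_T, \; \frac{\partial x(T)}{\partial v_0} \, dv_0 \right\rangle \\
\nabla L(v_0)
  &= 2 \left( \frac{\partial x(T)}{\partial v_0} \right)^\dagger (x(T) - x_T)
\end{aligned}
```

Only the action of the adjoint \((\partial x(T) / \partial v_0)^\dagger\) on the residual \(x(T) - x_T\) is needed, which is what integrating the adjoint problem backwards from \(t = T\) provides without forming the full Jacobian.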

Example: Ballistic motion

def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)

approx aimToTarget (init_v : ℝ×ℝ) (optimizationRate : ℝ) := 
  λ (T : ℝ) (target : ℝ×ℝ) =>
    let shoot := λ (t : ℝ) (v : ℝ×ℝ) =>
                   odeSolve (t₀ := 0) (x₀ := ((0:ℝ×ℝ),v)) (t := t)
                     (f := λ (t : ℝ) (x,v) => ballisticMotion x v)
    (λ v => (shoot T v).1)⁻¹ target

SciLean summary

  • library for scientific computing
  • aim is to help users with the mathematics
  • provide guarantees when possible
  • current focus is on automatic differentiation
  • priorities:
    1. user experience
    2. performance
    3. correctness
    4. formal correctness

Talk at Certified and Symbolic-Numeric Computation

By lecopivo
