Using Interactive Theorem Prover for Scientific Computing

Tomáš Skřivan

Carnegie Mellon University

Hoskinson Center for Formal Mathematics

The Road to Differentiable and Probabilistic Programming in Fundamental Physics

27.06.2023

Overview

  • What is an interactive theorem prover?
  • Lean 4 as a programming language for scientific computing
  • Harmonic oscillator example
  • Approach to automatic differentiation
  • Ballistic motion optimization example

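Functional programming notation primer

Function application

f(x,y) = f x y
f(g(x),y) = f (g x) y

Type ascription

x : X
f : X → Y → Z

Function types associate to the right, and a curried function is equivalent to one taking a pair:

    X → Y → Z = X → (Y → Z) ≠ (X → Y) → Z
    X → Y → Z ≅ X×Y → Z

Lambda function

λ x => x + sin x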
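The notation above can be tried directly in plain Lean 4. A minimal sketch, assuming `Float` as a stand-in for real numbers; the names `f`, `g`, `fTwo`, `fPair` and `h` are illustrative only:

```lean
-- f : Float → Float → Float is curried: `f x y` means `(f x) y`
def f (x y : Float) : Float := x * y + 1.0
def g (x : Float) : Float := x + 1.0

-- f(g(x), y) is written `f (g x) y`
#eval f (g 2.0) 3.0

-- partial application: `f 2.0 : Float → Float`
def fTwo : Float → Float := f 2.0

-- the uncurried version, of type Float × Float → Float
def fPair : Float × Float → Float := fun p => f p.1 p.2

-- λ x => x + sin x
def h : Float → Float := fun x => x + Float.sin x

#eval (fTwo 3.0, fPair (2.0, 3.0), h 0.0)
```

Partial application is what makes the curried form convenient: `f 2.0` is already a function, with no tuple to build.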

What is an interactive theorem prover?

An interactive theorem prover:

  • allows one to state and prove mathematical theorems
  • proofs are done via human-computer interaction

Example: Peano numbers

inductive nat where
  | zero : nat
  | succ (n : nat) : nat

def add (n m : nat) : nat :=
  match m with
  | .zero    => n                  -- n + 0 = n
  | .succ m' => .succ (add n m')   -- n + (m + 1) = (n + m) + 1

instance : Add nat := ⟨add⟩        -- enables the notation n + m

theorem add_comm (m n : nat) : m + n = n + m := by
  sorry  -- <proof>
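The slide leaves the proof of `add_comm` elided. As an illustration only (not the proof from the talk), here is one possible induction proof in plain Lean 4. To stay self-contained it states the lemmas with `add` directly instead of `+`, and wraps everything in a hypothetical namespace `Peano`; the helper lemma names are also assumptions:

```lean
namespace Peano

inductive nat where
  | zero : nat
  | succ (n : nat) : nat

def add (n m : nat) : nat :=
  match m with
  | .zero    => n
  | .succ m' => .succ (add n m')

-- 0 + n = n, by induction on n
theorem zero_add (n : nat) : add .zero n = n := by
  induction n with
  | zero => rfl
  | succ n ih => simp only [add, ih]

-- (n + 1) + m = (n + m) + 1, by induction on m
theorem succ_add (n m : nat) : add (.succ n) m = .succ (add n m) := by
  induction m with
  | zero => rfl
  | succ m ih => simp only [add, ih]

theorem add_comm (m n : nat) : add m n = add n m := by
  induction n with
  | zero => simp only [add, zero_add]
  | succ n ih => simp only [add, succ_add]; rw [ih]

end Peano
```

Each `simp only [add, ...]` step rewrites with the two defining equations of `add`; the interaction consists of choosing the induction variable and the auxiliary lemmas.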

Lean 4 as a programming language for scientific computing

  • interactive theorem prover
  • functional programming language
  • supports imperative style and mutable variables
  • custom notation
  • ability to talk about mathematical objects
  • write programs as mathematical specifications, then interactively apply approximations and code transformations (like optimization and AD)

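Jacobi iteration for A x = y, with A = D + L + U (diagonal, strictly lower and strictly upper part):

x^{(0)} = x_0, \qquad x^{(n+1)} = D^{-1}\left( y - (L + U)\, x^{(n)} \right)

def jacobi_method (A : ℝ^(n,n)) (x₀ y : ℝ^n) (max_steps : ℕ) (ε : ℝ) : ℝ^n := Id.run do
  let iD := inv_diag A          -- D⁻¹
  let U  := upper_triangular A  -- U
  let L  := lower_triangular A  -- L

  let mut x := x₀
  for _ in [0:max_steps] do
    x := iD (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x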
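For readers without SciLean, the same imperative pattern (`let mut`, `for`, `break`) runs in plain Lean 4. A sketch under stated assumptions: `Float` replaces ℝ, a dense `Array (Array Float)` replaces `ℝ^(n,n)`, and `rowDot`, `jacobiStep`, `residualNorm` and `jacobi` are hypothetical helper names:

```lean
/-- Dot product of a matrix row with x. -/
def rowDot (row x : Array Float) : Float :=
  (List.range x.size).foldl (fun s j => s + row[j]! * x[j]!) 0.0

/-- One Jacobi sweep: x ← D⁻¹ (y - (L + U) x). -/
def jacobiStep (A : Array (Array Float)) (x y : Array Float) : Array Float :=
  Array.ofFn fun (i : Fin x.size) =>
    let row := A[i.val]!
    -- Σ_{j ≠ i} Aᵢⱼ xⱼ, i.e. the i-th entry of (L + U) x
    let offDiag := rowDot row x - row[i.val]! * x[i.val]!
    (y[i.val]! - offDiag) / row[i.val]!

/-- Euclidean norm of the residual A x - y. -/
def residualNorm (A : Array (Array Float)) (x y : Array Float) : Float :=
  let r := (List.range y.size).map fun i => rowDot (A[i]!) x - y[i]!
  Float.sqrt (r.foldl (fun acc v => acc + v * v) 0.0)

def jacobi (A : Array (Array Float)) (x₀ y : Array Float)
    (maxSteps : Nat) (ε : Float) : Array Float := Id.run do
  let mut x := x₀
  for _ in [0:maxSteps] do
    x := jacobiStep A x y
    if residualNorm A x y < ε then
      break
  return x

-- A = [[4, 1], [2, 5]] is diagonally dominant, so the iteration converges
-- towards the exact solution (1/6, 1/3)
#eval jacobi #[#[4.0, 1.0], #[2.0, 5.0]] #[0.0, 0.0] #[1.0, 2.0] 100 1e-10
```

Diagonal dominance of A is what guarantees convergence of the Jacobi iteration here; for a general matrix the loop may simply run out of steps.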

macro " ∂ " x:ident "," b:term : term => `(deriv fun $x => $b)



deriv fun x => f x                         =   ∂ x, f x            =   \frac{\partial f(x)}{\partial x}

integral Ω fun x => norm2 (gradient u x)   =   ∫ x ∈ Ω, ‖∇ u x‖²   =   \int_\Omega \| \nabla u(x) \|^2 \,dx
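The notation mechanism itself needs nothing beyond core Lean 4. A sketch, assuming a hypothetical numeric stand-in `numDeriv` (a central finite difference on `Float → Float`) in place of the noncomputable `deriv`:

```lean
/-- Hypothetical stand-in for `deriv` on Float → Float:
    central finite difference with a fixed step. -/
def numDeriv (f : Float → Float) (x : Float) : Float :=
  let h : Float := 1e-6
  (f (x + h) - f (x - h)) / (2 * h)

-- `∂ x, b` expands to `numDeriv fun x => b`, mirroring the macro on the slide
macro "∂ " x:ident ", " b:term : term => `(numDeriv fun $x => $b)

#eval (∂ x, x * x) 3.0          -- approximately 2 · 3 = 6
#eval (∂ x, Float.sin x) 0.0    -- approximately cos 0 = 1
```

The macro only rewrites syntax; which function the notation unfolds to (a finite difference here, `deriv` on the slide) is a separate choice.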


noncomputable 
def deriv {X Y : Type} [Vec X] [Vec Y] (f : X → Y) (x dx : X) : Y := 
  if h : ∃ dy : Y, 
    -- limit of the difference quotient as t → 0 with t ≠ 0
    Tendsto (fun t : ℝ => (1/t) • (f (x + t • dx) - f x)) (nhdsWithin 0 {0}ᶜ) (nhds dy)
  then Classical.choose h
  else 0

noncomputable 
def adjoint {X Y : Type} [Hilbert X] [Hilbert Y] (f : X → Y) : Y → X := 
  if h : ∃ f' : Y → X, 
    IsLin f ∧ ∀ x y, ⟪f x, y⟫ = ⟪x, f' y⟫
  then Classical.choose h
  else 0

noncomputable
def odeSolve {X : Type} [Vec X] (f : ℝ → X → X) (t₀ : ℝ) (x₀ : X) (t : ℝ) : X :=
  if h : ∃ x : ℝ → X, 
           (∀ s, ⅆ x s = f s (x s)) 
           ∧ 
           x t₀ = x₀
  then Classical.choose h t
  else 0
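The same specification-versus-implementation split can be seen on a tiny example in core Lean 4, without Mathlib. A sketch; `sqrt16`, `sqrt16_spec` and `sqrtSearch` are hypothetical names:

```lean
-- Specification: "some n with n * n = 16", chosen non-constructively.
-- Like `deriv`, `adjoint` and `odeSolve`, it cannot be executed.
noncomputable def sqrt16 : Nat :=
  Classical.choose (⟨4, rfl⟩ : ∃ n : Nat, n * n = 16)

-- ...but its defining property can still be used in proofs
theorem sqrt16_spec : sqrt16 * sqrt16 = 16 :=
  Classical.choose_spec (⟨4, rfl⟩ : ∃ n : Nat, n * n = 16)

-- An executable implementation that satisfies the same specification
def sqrtSearch (a : Nat) : Option Nat :=
  (List.range (a + 1)).find? fun n => n * n == a

#eval sqrtSearch 16   -- some 4
```

`Classical.choose` gives a value to reason about but no algorithm; an executable version has to be supplied (or derived) separately and can then be related to the specification.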

Example: Harmonic Oscillator

\begin{aligned} \dot x &= \frac{\partial H}{\partial p} \\ \dot p &= - \frac{\partial H}{\partial x} \end{aligned}
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ‖p‖² + k/2 * ‖x‖²

approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                             -∇ (x':=x), H m k x' p))
by
  -- Unfold Hamiltonian and compute gradients
  unfold H
  symdiff

  -- Apply RK4 method
  rw [odeSolve_fixed_dt runge_kutta4_step]
  approx_limit steps; simp; intro steps'

Goal before the tactics:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                                  -∇ (x':=x), H m k x' p)))

Goal after `unfold H` and `symdiff`:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))

Goal after switching to RK4 (`rw`, `approx_limit`, `intro steps'`):
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step 
            fun t x => (1 / m * x.snd, -(k * x.fst)))

Usage:
(x, p) := solver m k substeps t (x, p) Δt
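To see what the final goal amounts to, here is a hand-written counterpart in plain Lean 4, as a sketch only: `Float` replaces ℝ, the right-hand side (p/m, -k x) is taken from the goal after `symdiff`, and `oscRhs`, `rk4Step` and `solveOsc` are hypothetical names, not SciLean's `runge_kutta4_step` or `odeSolve_fixed_dt_impl`:

```lean
/-- Right-hand side after `unfold H; symdiff`: (ẋ, ṗ) = (p / m, -k x). -/
def oscRhs (m k : Float) (s : Float × Float) : Float × Float :=
  (s.2 / m, -(k * s.1))

/-- One classical Runge-Kutta (RK4) step of size dt for an autonomous system. -/
def rk4Step (f : Float × Float → Float × Float) (dt : Float)
    (s : Float × Float) : Float × Float :=
  let add (a b : Float × Float) : Float × Float := (a.1 + b.1, a.2 + b.2)
  let scale (c : Float) (a : Float × Float) : Float × Float := (c * a.1, c * a.2)
  let k1 := f s
  let k2 := f (add s (scale (dt / 2) k1))
  let k3 := f (add s (scale (dt / 2) k2))
  let k4 := f (add s (scale dt k3))
  add s (scale (dt / 6) (add (add k1 (scale 2 k2)) (add (scale 2 k3) k4)))

/-- Integrate from t = 0 to t = T with `steps` equal RK4 steps. -/
def solveOsc (m k T : Float) (steps : Nat) (s₀ : Float × Float) : Float × Float :=
  let dt := T / steps.toFloat
  (List.range steps).foldl (fun s _ => rk4Step (oscRhs m k) dt s) s₀

-- m = k = 1, x(0) = 1, p(0) = 0: the exact solution is x(t) = cos t, p(t) = -sin t
#eval solveOsc 1.0 1.0 3.14159265 100 (1.0, 0.0)
```

At T ≈ π the result should be close to (cos π, -sin π) = (-1, 0). In the tactic version, the gradient computation and the choice of integrator are separate, inspectable steps instead of being fused by hand as here.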

Two Main Operations of AD: Differential and Adjoint

Differential \(\partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)\)

Adjoint \(\dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X )\), for linear maps \(f\) between Hilbert spaces

\partial \, f\, x \, dx = \lim_{h\rightarrow 0} \frac{f (x + h \cdot dx) - f(x)}{h}

[Diagram: ∂ turns f : X → Y into ∂ f, which takes a point x : X and a direction dx : X to dy : Y; † turns f : X → Y into f† : Y → X, which maps y : Y back to x : X.]

\forall \, x \, y, \left\langle f(x), y \right\rangle = \left\langle x, f^\dagger (y) \right\rangle
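Both definitions can be checked numerically on ℝ². A sketch under stated assumptions: `Float` pairs replace a Hilbert space, `dot` is the inner product, the matrix A = [[1, 2], [3, 4]] is an arbitrary example, and `dirDeriv`, `f`, `fAdj` are hypothetical names:

```lean
abbrev Vec2 := Float × Float

def dot (a b : Vec2) : Float := a.1 * b.1 + a.2 * b.2

/-- Forward-difference approximation of ∂ g x dx, following the limit definition. -/
def dirDeriv (g : Vec2 → Float) (x dx : Vec2) : Float :=
  let h : Float := 1e-6
  (g (x.1 + h * dx.1, x.2 + h * dx.2) - g x) / h

/-- A linear map f x = A x with A = [[1, 2], [3, 4]]. -/
def f (x : Vec2) : Vec2 := (x.1 + 2 * x.2, 3 * x.1 + 4 * x.2)

/-- Its adjoint with respect to `dot`: f† y = Aᵀ y. -/
def fAdj (y : Vec2) : Vec2 := (y.1 + 3 * y.2, 2 * y.1 + 4 * y.2)

-- ⟨f x, y⟩ and ⟨x, f† y⟩ agree (both equal -50 here)
#eval (dot (f (1, -2)) (5, 7), dot (1, -2) (fAdj (5, 7)))

-- ∂ g x dx for g x = x₁ x₂ at x = (2, 3), dx = (1, 0) is x₂ = 3 (up to rounding)
#eval dirDeriv (fun x => x.1 * x.2) (2, 3) (1, 0)
```

For matrices with the standard inner product the adjoint is the transpose, which is why `fAdj` reads off the columns of A.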

Other operators

  • derivative \(f'\)  for \(f : \mathbb{R} \rightarrow X\)
    • \(f'(x) = \texttt{∂ f x 1}\)   
  • gradient \(\nabla f\)  for \(f : X \rightarrow \mathbb{R}\)
    • \(\nabla f(x) = \texttt{(∂ f x)† 1}\)
  • adjoint differential \(\partial^\dagger\)
    • \(\texttt{∂† f x dy = (∂ f x)† dy}\)
  • tangent map \(\mathcal{T}\) - forward mode AD
    • \(\mathcal{T}\texttt{ f x dx = (f x, ∂ f x dx)}\)
  • reverse differential \(\mathcal{R}\) - reverse mode AD
    •  \(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
  • complexification \(f^c\) 
    • \(f^c : \mathbb{C} \rightarrow \mathbb{C}\)  for  \(f : \mathbb{R} \rightarrow \mathbb{R}\)
  • inverse \(f^{-1}\)
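The gradient formula ∇ f x = (∂ f x)† 1 can be made concrete on ℝ². A sketch, assuming `Float` pairs and a hand-written differential; `f`, `df` and `grad` are hypothetical names:

```lean
abbrev Vec2 := Float × Float

-- f x = x₁² + 3 x₁ x₂
def f (x : Vec2) : Float := x.1 * x.1 + 3 * x.1 * x.2

-- its differential, written out by hand: ∂ f x dx = (2 x₁ + 3 x₂) dx₁ + 3 x₁ dx₂
def df (x dx : Vec2) : Float := (2 * x.1 + 3 * x.2) * dx.1 + 3 * x.1 * dx.2

-- ∇ f x = (∂ f x)† 1. For a linear functional L on ℝ², L† 1 has components
-- L e₁ and L e₂, since ⟨L† 1, eᵢ⟩ = ⟨1, L eᵢ⟩ = L eᵢ.
def grad (x : Vec2) : Vec2 := (df x (1, 0), df x (0, 1))

-- mathematically (2·1 + 3·2, 3·1) = (8, 3)
#eval grad (1, 2)
```

Probing basis vectors costs one evaluation of `df` per input dimension; reverse mode obtains (∂ f x)† 1 in a single backward pass instead.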

Forward Mode AD and Functoriality

[Diagram: x : X is mapped by g to y : Y and then by f to z : Z; the question (?) is how to obtain ∂ (f ∘ g), which maps x : X and dx : X to dz : Z, from ∂ g and ∂ f.]


𝒯 (λ x => f (g x)) 
= 
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)

[Diagram: 𝒯 g maps (x, dx) to (y, dy), and 𝒯 f maps (y, dy) to (z, dz); chaining them yields 𝒯 (f ∘ g), which computes z : Z and dz : Z in a single forward pass along x → g → y → f → z.]
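The composition rule for 𝒯 translates directly into executable code for scalar functions. A sketch, assuming 𝒯 is specialised to `Float → Float` and the tangent maps of the primitives are written by hand; `Tangent`, `tSin`, `tSquare` and `tComp` are hypothetical names:

```lean
/-- 𝒯 f x dx = (f x, ∂ f x dx), specialised to f : Float → Float. -/
abbrev Tangent := Float → Float → Float × Float

-- tangent maps of two primitives, written by hand
def tSin : Tangent := fun x dx => (Float.sin x, Float.cos x * dx)
def tSquare : Tangent := fun x dx => (x * x, 2 * x * dx)

/-- 𝒯 (f ∘ g), built exactly like the rule above. -/
def tComp (f g : Tangent) : Tangent := fun x dx =>
  let (y, dy) := g x dx
  let (z, dz) := f y dy
  (z, dz)

-- value sin (x²) and derivative cos (x²) · 2x at x = 1.5, in one forward pass
#eval tComp tSin tSquare 1.5 1.0
```

Values and tangents travel together, so no intermediate result has to be stored; this is the defining property of forward mode.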

Reverse Mode AD and Functoriality

[Diagram: x : X is mapped by g to y : Y and then by f to z : Z; ∂† (f ∘ g) x maps dz : Z back to dx : X, and the question (?) is how to build it from ∂† f, which maps dz : Z to dy : Y, and ∂† g, which maps dy : Y to dx : X.]


[Diagram: ℛ (f ∘ g) takes x : X and returns both z : Z and the pullback ∂† (f ∘ g) x : Z → X.]

ℛ (λ x => f (g x)) 
    = 
    λ x => 
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))

[Diagram: ℛ g x returns y : Y and ∂† g x : Y → X; ℛ f y returns z : Z and ∂† f y : Z → Y; composing the pullbacks in reverse order gives ∂† (f ∘ g) x, where \(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\).]
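The reverse-mode rule can be sketched the same way for scalar functions. Assumptions: ℛ is specialised to `Float → Float`, the pullbacks of the primitives are written by hand, and `Reverse`, `rSin`, `rSquare` and `rComp` are hypothetical names:

```lean
/-- ℛ f x = (f x, (∂ f x)†), specialised to f : Float → Float. -/
abbrev Reverse := Float → Float × (Float → Float)

def rSin : Reverse := fun x => (Float.sin x, fun dy => Float.cos x * dy)
def rSquare : Reverse := fun x => (x * x, fun dy => 2 * x * dy)

/-- ℛ (f ∘ g): forward pass through g then f, pullbacks composed in reverse. -/
def rComp (f g : Reverse) : Reverse := fun x =>
  let (y, dg') := g x
  let (z, df') := f y
  (z, fun dz => dg' (df' dz))

-- value, pullback applied to dz = 1, and the hand-computed derivative cos (x²) · 2x
#eval
  let (z, pullback) := rComp rSin rSquare 1.5
  (z, pullback 1.0, Float.cos (1.5 * 1.5) * 2 * 1.5)
```

The last two components should agree up to floating-point rounding. The closures `dg'` and `df'` capture the intermediate values x and y, which is where reverse mode pays its memory cost.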

Approach to AD in Lean 4

Derivative \(\partial\)

instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : IsSmooth (λ x => f (g x)) := ...

@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ∂ x, f (g x)
    = 
    λ x dx => ∂ f (g x) (∂ g x dx) := ...

Adjoint \(\dagger\)

instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : IsLin (λ x => f (g x)) := ...

@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g] 
  : (λ x => f (g x))† 
    =
    λ z => g† (f† z) := ...


Forward mode \( \mathcal{T} \)

noncomputable
def 𝒯 (f : X → Y) (x dx : X) : Y×Y := (f x, ∂ f x dx)

@[simp]
theorem fd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : 𝒯 (λ x => f (g x)) 
    = 
    λ x dx => 
      let (y,dy) := 𝒯 g x dx
      𝒯 f y dy 
  := ...

Reverse mode \(\mathcal{R}\)

noncomputable
def ℛ (f : X → Y) (x : X) : Y×(Y→X) := (f x, (∂ f x)†) 

@[simp]
theorem rd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ℛ (λ x => f (g x)) 
    = 
    λ x => 
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))
  := ...
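The rules above work compositionally: each one handles a single way of building a function, and repeated rewriting reaches the whole program. A toy illustration of that idea in plain Lean 4, not SciLean's mechanism: a hypothetical expression type `Sym` with one differentiation rule per constructor, implemented by structural recursion rather than `simp` rewriting:

```lean
/-- A tiny expression language in one variable `x` (hypothetical, for illustration). -/
inductive Sym where
  | x
  | const (c : Float)
  | add (a b : Sym)
  | mul (a b : Sym)
  | sin (a : Sym)
  | cos (a : Sym)

/-- Evaluate an expression at x = v. -/
def Sym.eval (e : Sym) (v : Float) : Float :=
  match e with
  | .x => v
  | .const c => c
  | .add a b => Sym.eval a v + Sym.eval b v
  | .mul a b => Sym.eval a v * Sym.eval b v
  | .sin a => Float.sin (Sym.eval a v)
  | .cos a => Float.cos (Sym.eval a v)

/-- One rule per constructor; `sin` and `cos` use the chain rule. -/
def Sym.diff : Sym → Sym
  | .x => .const 1
  | .const _ => .const 0
  | .add a b => .add (Sym.diff a) (Sym.diff b)
  | .mul a b => .add (.mul (Sym.diff a) b) (.mul a (Sym.diff b))  -- product rule
  | .sin a => .mul (.cos a) (Sym.diff a)
  | .cos a => .mul (.mul (.const (-1)) (.sin a)) (Sym.diff a)

-- d/dx (x · sin x) at x = 2, next to the hand-derived sin 2 + 2 cos 2
#eval ((Sym.mul .x (.sin .x)).diff.eval 2.0, Float.sin 2.0 + 2.0 * Float.cos 2.0)
```

Unlike this closed expression type, the simp-lemma approach on the slides acts on ordinary Lean functions, so new rules can be added for any user-defined function.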

Difficulty with traditional automatic differentiation

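Two routes from \(e^x\) to an approximation of \(\frac{\partial e^x}{\partial x}\):

  • differentiation, then approximation: \(\frac{\partial e^x}{\partial x} = e^x \approx \sum_{n=0}^{N} \frac{x^n}{n!}\)
  • approximation, then differentiation: \(e^x \approx \sum_{n=0}^{N} \frac{x^n}{n!}\), and \(\frac{\partial}{\partial x} \sum_{n=0}^{N} \frac{x^n}{n!} = \sum_{n=0}^{N-1} \frac{x^n}{n!}\)

The two routes do not commute: differentiating the approximation loses the last term of the series.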
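A numeric check of the two routes, as a sketch in plain Lean 4 with `Float`; `powNat`, `fact`, `expTaylor` and `derivOfTaylor` are hypothetical names:

```lean
/-- xⁿ for a natural exponent. -/
def powNat (x : Float) : Nat → Float
  | 0 => 1.0
  | k + 1 => x * powNat x k

/-- n! as a Float. -/
def fact : Nat → Float
  | 0 => 1.0
  | k + 1 => (k + 1).toFloat * fact k

/-- Approximation: Σ_{n=0}^{N} xⁿ / n! -/
def expTaylor (N : Nat) (x : Float) : Float :=
  (List.range (N + 1)).foldl (fun acc n => acc + powNat x n / fact n) 0.0

/-- Differentiate the approximation term by term: Σ_{n=1}^{N} n xⁿ⁻¹ / n! -/
def derivOfTaylor (N : Nat) (x : Float) : Float :=
  (List.range (N + 1)).foldl
    (fun acc n => if n == 0 then acc else acc + n.toFloat * powNat x (n - 1) / fact n) 0.0

-- approximate-then-differentiate, differentiate-then-approximate, exact eˣ
#eval (derivOfTaylor 5 1.0, expTaylor 5 1.0, Float.exp 1.0)
```

With N = 5 and x = 1 the first value is the degree-4 sum (about 2.708), noticeably further from e ≈ 2.718 than the degree-5 sum (about 2.717).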


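\(\frac{\partial (A^{-1}y)}{\partial y}\, y' = A^{-1} y'\), where \(A^{-1} y\) is approximated by Jacobi iteration with \(A = D + L + U\).

Approximation of \(A^{-1} y\):

fun y => Id.run do
  let mut x := 0
  while true do
    x := D⁻¹ (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x

Approximation, then differentiation \(\left(\frac{\partial \cdot}{\partial y}\, y'\right)\):

fun y y' => Id.run do
  let mut x  := 0
  let mut x' := 0
  while true do
    x  := D⁻¹ (y  - (L + U) x)
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x - y‖ < ε then      -- stops when x converges; x' is never checked
      break
  return x'

Differentiation, then approximation (Jacobi iteration applied to \(A^{-1} y'\) directly):

fun y y' => Id.run do          -- does not depend on y, since A⁻¹ is linear
  let mut x' := 0
  while true do
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x' - y'‖ < ε then
      break
  return x'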
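The difference between the two differentiated loops shows up concretely when y = 0. A sketch in plain Lean 4 with `Float` pairs, assuming the example matrix A = [[4, 1], [2, 5]]; `applyA`, `jacobiStep`, `residual`, `approxThenDiff` and `diffThenApprox` are hypothetical names:

```lean
abbrev V2 := Float × Float

-- A = [[4, 1], [2, 5]], split as A = D + (L + U)
def applyA (x : V2) : V2 := (4 * x.1 + x.2, 2 * x.1 + 5 * x.2)
-- one Jacobi step x ← D⁻¹ (y - (L + U) x)
def jacobiStep (x y : V2) : V2 := ((y.1 - x.2) / 4, (y.2 - 2 * x.1) / 5)
-- ‖A x - y‖
def residual (x y : V2) : Float :=
  let r := applyA x
  Float.sqrt ((r.1 - y.1) * (r.1 - y.1) + (r.2 - y.2) * (r.2 - y.2))

-- approximation, then differentiation: returns (x', iterations used)
def approxThenDiff (ε : Float) (y y' : V2) : V2 × Nat := Id.run do
  let mut x : V2 := (0, 0)
  let mut x' : V2 := (0, 0)
  let mut iters := 0
  while true do
    x := jacobiStep x y
    x' := jacobiStep x' y'
    iters := iters + 1
    if residual x y < ε then   -- only x is checked
      break
  return (x', iters)

-- differentiation, then approximation: x' has its own stopping test
def diffThenApprox (ε : Float) (y' : V2) : V2 × Nat := Id.run do
  let mut x' : V2 := (0, 0)
  let mut iters := 0
  while true do
    x' := jacobiStep x' y'
    iters := iters + 1
    if residual x' y' < ε then
      break
  return (x', iters)

#eval approxThenDiff 1e-10 (0, 0) (1, 2)
#eval diffThenApprox 1e-10 (1, 2)
```

With y = 0 the primal iterate is exact after one step, so the first loop stops immediately and returns x' = D⁻¹ y' = (0.25, 0.4) instead of A⁻¹ y' = (1/6, 1/3); the second loop keeps iterating until x' itself converges.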

Example: Ballistic motion

\begin{aligned} \dot x &= v \\ \dot v &= g - \gamma f(v) \\ x(0) &= 0 \qquad v(0) = v_0 \end{aligned}

For \(T\in \mathbb{R}, x_T \in \mathbb{R}^2\) find \(v_0\) such that \(x(T) = x_T\)

This can be posed as a minimization problem:

   \(v_0 = \operatorname{argmin}_{\bar v_0} \|x(T) - x_T\|^2 \qquad \text{where } \dot x(0) = \bar v_0\)

Problem:

To compute the gradient of \(\|x(T) - x_T\|^2\) with respect to \(v_0\), we have to solve the adjoint problem backward in time from \(t = T\):

\begin{aligned} \dot x^* &= 0 \\ \dot v^* &= -x^* + \gamma f'(v)^\top v^* \\ x^*(T) &= x(T) - x_T \qquad v^*(T) = 0 \end{aligned}
\frac{\partial \|x(T) - x_T\|^2}{\partial v_0} = 2 v^*(0)
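A quick consistency check for any formulation of the adjoint problem (a worked special case, not taken from the talk): without drag, γ = 0, the trajectory is explicit and the gradient can be computed directly.

```latex
\begin{aligned}
x(T) &= v_0 T + \tfrac{1}{2} g T^2, \qquad \frac{\partial x(T)}{\partial v_0} = T\, I, \\
\frac{\partial \|x(T) - x_T\|^2}{\partial v_0}
  &= 2 \left(\frac{\partial x(T)}{\partial v_0}\right)^{\!\top} (x(T) - x_T)
   = 2T\,(x(T) - x_T).
\end{aligned}
```

Any sign convention chosen for the adjoint variables must reproduce this value when γ = 0.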

With the concrete drag term \(\gamma f(v) = (5+\|v\|)\, v\):

\begin{aligned} \dot x &= v \\ \dot v &= g - (5+\|v\|) v \\ x(0) &= 0 \qquad v(0) = v_0 \end{aligned}

def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)

approx aimToTarget (T : ℝ) (target : ℝ×ℝ) (init_v : ℝ×ℝ) (optimizationRate : ℝ) := 
    let shoot := λ (v : ℝ×ℝ) =>
                   odeSolve (t₀ := 0) (x₀ := (0,v)) (t := T)
                     (f := λ (t : ℝ) (x,v) => ballisticMotion x v) |>.fst
    shoot⁻¹ target
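What `aimToTarget` specifies can be approximated by hand in plain Lean 4. A sketch under several assumptions: `Float` pairs replace ℝ×ℝ, g = (0, -9.81) is an assumed value, explicit Euler replaces `odeSolve`, a central finite-difference gradient stands in for the adjoint method, and plain gradient descent stands in for `shoot⁻¹`. All names and the target (0.4, -1.0) are hypothetical:

```lean
abbrev V2 := Float × Float

def vadd (a b : V2) : V2 := (a.1 + b.1, a.2 + b.2)
def vscale (c : Float) (a : V2) : V2 := (c * a.1, c * a.2)
def vnorm (a : V2) : Float := Float.sqrt (a.1 * a.1 + a.2 * a.2)

-- assumed value of g (the slides leave it unspecified)
def gravity : V2 := (0.0, -9.81)

-- right-hand side of ẋ = v, v̇ = g - (5 + ‖v‖) v
def ballisticRhs (v : V2) : V2 × V2 :=
  (v, vadd gravity (vscale (-(5 + vnorm v)) v))

-- x(T) by explicit Euler with `steps` steps, starting from x(0) = 0, v(0) = v₀
def shoot (T : Float) (steps : Nat) (v₀ : V2) : V2 := Id.run do
  let dt := T / steps.toFloat
  let mut x : V2 := (0, 0)
  let mut v := v₀
  for _ in [0:steps] do
    let (dx, dv) := ballisticRhs v
    x := vadd x (vscale dt dx)
    v := vadd v (vscale dt dv)
  return x

-- ‖x(T) - target‖²
def loss (T : Float) (target v₀ : V2) : Float :=
  let d := vadd (shoot T 200 v₀) (vscale (-1) target)
  d.1 * d.1 + d.2 * d.2

-- central finite-difference gradient: a stand-in for the adjoint method
def lossGrad (T : Float) (target v₀ : V2) : V2 :=
  let h : Float := 1e-5
  let l := loss T target
  ((l (v₀.1 + h, v₀.2) - l (v₀.1 - h, v₀.2)) / (2 * h),
   (l (v₀.1, v₀.2 + h) - l (v₀.1, v₀.2 - h)) / (2 * h))

-- plain gradient descent on v₀, in the role of aimToTarget
def aim (T : Float) (target initV : V2) (rate : Float) (iters : Nat) : V2 := Id.run do
  let mut v := initV
  for _ in [0:iters] do
    v := vadd v (vscale (-rate) (lossGrad T target v))
  return v

-- tuned v₀ and the resulting landing point, to compare with the target
#eval
  let v := aim 1.0 (0.4, -1.0) (0.0, 0.0) 10.0 100
  (v, shoot 1.0 200 v)
```

For a nearby target like this one the loss is close to quadratic in v₀, so the landing point should end up close to (0.4, -1.0). Finite differences cost two extra trajectory solves per component of v₀; the adjoint problem obtains the whole gradient from one backward solve.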

Summary

  • SciLean - experimental library for scientific computing in Lean 4
    https://github.com/lecopivo/SciLean
  • decoupling what (specification) from how (implementation)
    • better code readability/documentation
    • better composability and faster development
    • reduction of bugs
  • interactive "computer algebra system" working on source code
  • aim is to help users with the mathematics
  • provide formal guarantees when possible
  • current focus is on automatic and symbolic differentiation
    • imperative code - for loops, mutable variables etc.
    • variational calculus and differential geometry
    • formally verified
  • probabilistic programming at some point in the future

By lecopivo
