Using Interactive Theorem Prover for Scientific Computing

Tomáš Skřivan

Carnegie Mellon University

Hoskinson Center for Formal Mathematics

The Road to Differentiable and Probabilistic Programming in Fundamental Physics

27.06.2023

Overview

  • What is an interactive theorem prover?
  • Lean 4 as a programming language for scientific computing
  • Harmonic oscillator example
  • Approach to automatic differentiation
  • Ballistic motion optimization example

Functional programming notation primer

Function application

f(x,y) = f x y
f(g(x),y) = f (g x) y

Type ascription

x : X
f : X → Y → Z

    X → Y → Z ≅ X×Y → Z
              = X → (Y → Z)
              ≠ (X → Y) → Z

Lambda function

λ x => x + sin x
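The same notation can be tried directly in plain Lean 4 (no libraries). A minimal sketch; `f` and `fPair` are illustrative names, not from the slides:

```lean
-- `f : Nat → Nat → Nat` is read as `Nat → (Nat → Nat)`
def f (x y : Nat) : Nat := x + 2 * y

-- the isomorphic uncurried version, taking a pair `Nat × Nat`
def fPair (p : Nat × Nat) : Nat := p.1 + 2 * p.2

#eval f 1 2          -- f(1, 2) is written `f 1 2`; gives 5
#eval (f 1) 2        -- `f 1 : Nat → Nat` is itself a function; also 5
#eval fPair (1, 2)   -- uncurried form; also 5
#eval f (f 1 1) 2    -- f(g(x), y) pattern: the inner call needs parentheses; gives 7
#eval (λ x => x + Float.sin x) 0.0   -- lambda function applied to 0.0
```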

What is an interactive theorem prover?

An interactive theorem prover:

  • allows one to state and prove mathematical theorems
  • proofs are done via human-computer interaction
inductive nat where
  | zero : nat
  | succ (n : nat) : nat

def add (n m : nat) : nat :=
  match m with
  | .zero    => n                 -- n + 0 = n
  | .succ m' => .succ (add n m')  -- n + (m + 1) = (n + m) + 1

instance : Add nat := ⟨add⟩       -- lets `m + n` stand for `add m n`

theorem add_comm (m n : nat) : m + n = n + m := by
  sorry  -- proof omitted on the slide

Example: Peano numbers
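To illustrate the interactive part, here is a sketch of two small proofs about the `nat` type and `add` function above (the theorem names `add_zero'` and `zero_add'` are illustrative). The first holds by unfolding the definition; the second needs induction, and after each tactic Lean shows the remaining goal:

```lean
-- n + 0 = n holds by computation: `add n .zero` reduces to `n`
theorem add_zero' (n : nat) : add n .zero = n := rfl

-- 0 + n = n needs induction on n
theorem zero_add' (n : nat) : add .zero n = n := by
  induction n with
  | zero => rfl
  | succ m ih =>
    -- goal: add .zero (.succ m) = .succ m, which unfolds to the form below
    show nat.succ (add nat.zero m) = nat.succ m
    rw [ih]
```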

Lean 4 as a programming language for scientific computing

  • interactive theorem prover
  • functional programming language
  • supports imperative style and mutable variables
  • custom notation
  • ability to talk about mathematical objects
  • write programs as mathematical specifications, then interactively apply approximations and code transformations (such as optimization and AD)

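def jacobi_method (A : ℝ^(n,n)) (x₀ y : ℝ^n) (max_steps : Nat) (ε : ℝ) : ℝ^n := Id.run do
  -- split A = D + L + U
  let iD := inv_diag A          -- D⁻¹, inverse of the diagonal part
  let U  := upper_triangular A  -- strictly upper triangular part
  let L  := lower_triangular A  -- strictly lower triangular part

  let mut x := x₀
  for _ in [0:max_steps] do
    x := iD (y - (L + U) x)     -- x⁽ⁿ⁺¹⁾ = D⁻¹ (y - (L + U) x⁽ⁿ⁾)
    if ‖A x - y‖ < ε then
      break
  return x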
x^{(0)} = x_0, \qquad x^{(n+1)} = D^{-1}\left( y - (L + U)\, x^{(n)} \right)
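For readers without SciLean, the same iteration can be sketched in plain Lean 4 with dense matrices stored as `Array (Array Float)`. The helper names (`matVec`, `jacobiStep`, `residual`, `jacobi`) and the matrix representation are assumptions of this sketch, not SciLean's API:

```lean
-- Dense matrix-vector product; `A` is stored row by row (assumed representation).
def matVec (A : Array (Array Float)) (x : Array Float) : Array Float :=
  A.map fun row => (row.zip x).foldl (fun acc p => acc + p.1 * p.2) 0.0

-- One Jacobi update: x_i ← (y_i - Σ_{j ≠ i} A_ij x_j) / A_ii, i.e. x ← D⁻¹ (y - (L + U) x)
def jacobiStep (A : Array (Array Float)) (y x : Array Float) : Array Float :=
  (List.range x.size).toArray.map fun i =>
    let row := A[i]!
    let offDiag := (List.range x.size).foldl
      (fun acc j => if j == i then acc else acc + row[j]! * x[j]!) 0.0
    (y[i]! - offDiag) / row[i]!

-- Euclidean norm of the residual A x - y
def residual (A : Array (Array Float)) (x y : Array Float) : Float :=
  let r := ((matVec A x).zip y).map fun p => p.1 - p.2
  Float.sqrt (r.foldl (fun acc ri => acc + ri * ri) 0.0)

def jacobi (A : Array (Array Float)) (x₀ y : Array Float)
    (maxSteps : Nat) (ε : Float) : Array Float := Id.run do
  let mut x := x₀
  for _ in [0:maxSteps] do
    x := jacobiStep A y x
    if residual A x y < ε then
      break
  return x

-- 4 x₁ + x₂ = 1, x₁ + 3 x₂ = 2 has the solution (1/11, 7/11)
#eval jacobi #[#[4.0, 1.0], #[1.0, 3.0]] #[0.0, 0.0] #[1.0, 2.0] 100 1e-10
```

This example matrix is strictly diagonally dominant, which guarantees that the Jacobi iteration converges.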

macro " ∂ " x:ident "," b:term : term => `(deriv fun $x => $b)

deriv fun x => f x   =   ∂ x, f x   =   \frac{\partial f(x)}{\partial x}

integral Ω fun x => norm2 (gradient u x)   =   ∫ x ∈ Ω, ‖∇ u x‖²   =   \int_\Omega \| \nabla u(x) \|^2 \,dx
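The `∂ x, b` macro only changes how a term is written; it expands to an ordinary function application. A minimal runnable sketch in plain Lean 4, where the hypothetical `nderiv` is a central finite difference standing in for the noncomputable `deriv` (it is not SciLean's definition):

```lean
-- Central finite difference, a computable stand-in for the exact derivative.
def nderiv (f : Float → Float) (x : Float) : Float :=
  let h := 1e-6
  (f (x + h) - f (x - h)) / (2 * h)

-- Same shape as the slide's macro: `∂ x, b` expands to `nderiv fun x => b`.
macro "∂ " x:ident ", " b:term : term => `(nderiv fun $x => $b)

#eval (∂ x, x * x) 3.0         -- approximately 6, the derivative of x² at 3
#eval (∂ t, Float.sin t) 0.0   -- approximately 1 = cos 0
```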

noncomputable
def deriv {X Y : Type} [Vec X] [Vec Y] (f : X → Y) (x dx : X) : Y :=
  -- directional derivative of f at x in direction dx; the limit is taken over t ≠ 0
  if h : ∃ dy : Y,
    Tendsto (fun t : ℝ => (1/t) • (f (x + t • dx) - f x)) (nhdsWithin 0 {0}ᶜ) (nhds dy)
  then Classical.choose h
  else 0

noncomputable
def adjoint {X Y : Type} [Hilbert X] [Hilbert Y] (f : X → Y) : Y → X :=
  -- for linear f, the adjoint f' satisfies ⟪f x, y⟫ = ⟪x, f' y⟫ for all x, y
  if h : ∃ f' : Y → X, IsLin f ∧ ∀ x y, ⟪f x, y⟫ = ⟪x, f' y⟫
  then Classical.choose h
  else 0

noncomputable
def odeSolve {X : Type} [Vec X] (f : ℝ → X → X) (t₀ : ℝ) (x₀ : X) (t : ℝ) : X :=
  -- value at time t of a solution of ẋ s = f s (x s) with x t₀ = x₀, if one exists
  if h : ∃ x : ℝ → X, (∀ s, ⅆ x s = f s (x s)) ∧ x t₀ = x₀
  then Classical.choose h t
  else 0

Example: Harmonic Oscillator

\begin{aligned} \dot x &= \frac{\partial H}{\partial p} \\ \dot p &= - \frac{\partial H}{\partial x} \end{aligned}
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ‖p‖² + k/2 * ‖x‖²

approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                             -∇ (x':=x), H m k x' p))
by
  -- Unfold Hamiltonian and compute gradients
  unfold H
  symdiff

  -- Apply RK4 method
  rw [odeSolve_fixed_dt runge_kutta4_step]
  approx_limit steps; simp; intro steps';
Goal before the tactic block:
Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x  p',
                                  -∇ (x':=x), H m k x' p)))

After unfold H and symdiff:
Goals (1)
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))

After rw [odeSolve_fixed_dt runge_kutta4_step] and approx_limit steps; simp; intro steps':
Goals (1)
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
            fun t x => (1 / m * x.snd, -(k * x.fst)))
(x, p) := solver m k substeps t (x, p) (t + Δt)  -- advance the state from time t to t + Δt
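What the tactic block above produces can be sketched by hand. Below, the right-hand side obtained after symdiff is integrated with a fixed-step classical Runge-Kutta 4 method in plain Lean 4 over `Float`; `oscillator`, `rk4Step`, and `solveFixedDt` are hypothetical names mirroring `runge_kutta4_step` and `odeSolve_fixed_dt`, not their actual SciLean definitions:

```lean
-- Right-hand side after `symdiff`: (ẋ, ṗ) = (p / m, -k x)
def oscillator (m k : Float) (_t : Float) (s : Float × Float) : Float × Float :=
  (s.2 / m, -(k * s.1))

-- One classical Runge-Kutta 4 step of size h for ṡ = f t s.
def rk4Step (f : Float → Float × Float → Float × Float)
    (t h : Float) (s : Float × Float) : Float × Float :=
  let vadd (a b : Float × Float) : Float × Float := (a.1 + b.1, a.2 + b.2)
  let vsmul (c : Float) (a : Float × Float) : Float × Float := (c * a.1, c * a.2)
  let k1 := f t s
  let k2 := f (t + h / 2) (vadd s (vsmul (h / 2) k1))
  let k3 := f (t + h / 2) (vadd s (vsmul (h / 2) k2))
  let k4 := f (t + h) (vadd s (vsmul h k3))
  vadd s (vsmul (h / 6) (vadd (vadd k1 (vsmul 2 k2)) (vadd (vsmul 2 k3) k4)))

-- Integrate from t₀ to t with `steps` equal substeps.
def solveFixedDt (f : Float → Float × Float → Float × Float) (steps : Nat)
    (t₀ : Float) (s₀ : Float × Float) (t : Float) : Float × Float := Id.run do
  let h := (t - t₀) / steps.toFloat
  let mut s := s₀
  for i in [0:steps] do
    s := rk4Step f (t₀ + i.toFloat * h) h s
  return s

-- m = k = 1 starting from (x, p) = (1, 0): the exact solution is (cos t, -sin t),
-- so the result should be close to (cos 1, -sin 1)
#eval solveFixedDt (oscillator 1.0 1.0) 100 0.0 (1.0, 0.0) 1.0
```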

Two Main Operations of AD: Differential and Adjoint

Differential: \partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)

Adjoint (of a linear map): \dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X)

\partial \, f\, x \, dx = \lim_{h\rightarrow 0} \frac{f (x + h \cdot dx) - f(x)}{h}

[Diagram: f maps x : X to y : Y. The differential ∂ f takes a point x : X and a direction dx : X and returns dy : Y. For linear f, the adjoint f† maps y : Y back to x : X.]

\forall \, x \, y, \left\langle f(x), y \right\rangle = \left\langle x, f^\dagger (y) \right\rangle
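As a worked example of both definitions, take a linear map f(x) = A x given by a matrix A (an illustrative choice, with the standard inner products on ℝⁿ and ℝᵐ):

```latex
% f(x) = A x with A \in \mathbb{R}^{m \times n}, X = \mathbb{R}^n, Y = \mathbb{R}^m
\partial \, f\, x \, dx = \lim_{h\rightarrow 0} \frac{A(x + h \cdot dx) - A x}{h} = A \, dx
\left\langle A x, y \right\rangle = x^\top A^\top y = \left\langle x, A^\top y \right\rangle
\quad\Longrightarrow\quad f^\dagger(y) = A^\top y
```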

Other operators

  • derivative f' for f : ℝ → X
    • f'(x) = ∂ f x 1
  • gradient ∇f for f : X → ℝ
    • ∇f(x) = (∂ f x)† 1
  • adjoint differential ∂†
    • ∂† f x dy = (∂ f x)† dy
  • tangent map 𝒯 (forward mode AD)
    • 𝒯 f x dx = (f x, ∂ f x dx)
  • reverse differential ℛ (reverse mode AD)
    • ℛ f x = (f x, λ dy => (∂ f x)† dy)
  • complexification f^c
    • f^c : ℂ → ℂ for f : ℝ → ℝ
  • inverse f⁻¹
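For instance, the formula ∇f(x) = (∂ f x)† 1 recovers the familiar gradient of f(x) = ‖x‖² on a real Hilbert space (a worked example, not from the slides):

```latex
% f(x) = \|x\|^2 on a real Hilbert space X
\partial \, f\, x \, dx = \lim_{h\rightarrow 0} \frac{\|x + h \cdot dx\|^2 - \|x\|^2}{h} = 2 \left\langle x, dx \right\rangle
% (\partial f x) : X \to \mathbb{R} is linear, and for every s \in \mathbb{R}:
\left\langle 2 \left\langle x, dx \right\rangle, s \right\rangle_{\mathbb{R}} = \left\langle dx, 2 s x \right\rangle
\quad\Longrightarrow\quad \nabla f(x) = (\partial \, f\, x)^\dagger \, 1 = 2 x
```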

Forward Mode AD and Functoriality

[Diagram: x : X ↦ g ↦ y : Y ↦ f ↦ z : Z. How is ∂ (f ∘ g) x dx : Z obtained from ∂ g and ∂ f? ∂ g x dx gives dy : Y, and ∂ f y dy gives dz : Z, which also needs the value y = g x.]

𝒯 (λ x => f (g x)) 
= 
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)

[Diagram: 𝒯 g maps (x, dx) to (y, dy), and 𝒯 f maps (y, dy) to (z, dz); chaining them gives 𝒯 (f ∘ g), which maps (x, dx) to (z, dz).]
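The composition rule for 𝒯 can be tried out on scalars in plain Lean 4. In this sketch a tangent map is encoded as a function (x, dx) ↦ (f x, ∂ f x dx); the names `Tangent`, `tSin`, `tSq`, and `tComp` are hypothetical:

```lean
-- A tangent map on scalars: (x, dx) ↦ (f x, ∂ f x dx)
abbrev Tangent := Float → Float → Float × Float

-- tangent maps of two primitives, written out by hand
def tSin : Tangent := fun x dx => (Float.sin x, Float.cos x * dx)
def tSq  : Tangent := fun x dx => (x * x, 2 * x * dx)

-- composition rule from the slide: run 𝒯 g first, then feed (y, dy) to 𝒯 f
def tComp (tf tg : Tangent) : Tangent := fun x dx =>
  let (y, dy) := tg x dx
  tf y dy

-- sin(x²) at x = 1 with dx = 1: expect (sin 1, 2 · cos 1)
#eval tComp tSin tSq 1.0 1.0
```

Values and derivatives travel together in a single forward pass, so nothing has to be stored for later.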

Reverse Mode AD and Functoriality

[Diagram: x : X ↦ g ↦ f ↦ z : Z. How is ∂† (f ∘ g) x, which maps dz : Z back to dx : X, obtained? ∂† f at y = g x maps dz : Z to dy : Y, then ∂† g at x maps dy : Y to dx : X, so the values x and y must be available before the backward pass.]

[Diagram: ℛ (f ∘ g) maps x : X to z : Z together with the pullback ∂† (f ∘ g) x : Z → X.]

ℛ (λ x => f (g x)) 
    = 
    λ x => 
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))

[Diagram: ℛ g maps x : X to y : Y together with ∂† g x : Y → X; ℛ f maps y : Y to z : Z together with ∂† f y : Z → Y.]

ℛ f x = (f x, λ dy => (∂ f x)† dy)
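The reverse rule can be sketched the same way on scalars in plain Lean 4, where the adjoint of "multiply by a" is again "multiply by a". The encoding and the names `Reverse`, `rSin`, `rSq`, and `rComp` are hypothetical:

```lean
-- ℛ on scalars: x ↦ (f x, pullback dy ↦ (∂ f x)† dy)
abbrev Reverse := Float → Float × (Float → Float)

-- for scalars, the adjoint of "multiply by a" is again "multiply by a"
def rSin : Reverse := fun x => (Float.sin x, fun dy => Float.cos x * dy)
def rSq  : Reverse := fun x => (x * x, fun dy => 2 * x * dy)

-- composition rule from the slide: the pullbacks are applied in reverse order
def rComp (rf rg : Reverse) : Reverse := fun x =>
  let (y, dg') := rg x
  let (z, df') := rf y
  (z, fun dz => dg' (df' dz))

-- gradient of sin(x²) at x = 1: pull back dz = 1; expect 2 · cos 1
#eval (rComp rSin rSq 1.0).2 1.0
```

The closures `dg'` and `df'` capture x and y from the forward pass, which is exactly the data the backward pass needs.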

Approach to AD in Lean 4

Derivative ∂

Adjoint †

instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : IsSmooth (λ x => f (g x)) := ...

@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ∂ x, f (g x)
    = 
    λ x dx => ∂ f (g x) (∂ g x dx) := ...
instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : IsLin (λ x => f (g x)) := ...

@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g] 
  : (λ x => f (g x))† 
    =
    λ z => g† (f† z) := ...

Forward mode 𝒯

Reverse mode ℛ

noncomputable
def 𝒯 (f : X → Y) (x dx : X) : Y×Y := (f x, ∂ f x dx)

@[simp]
theorem fd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : 𝒯 (λ x => f (g x)) 
    = 
    λ x dx => 
      let (y,dy) := 𝒯 g x dx
      𝒯 f y dy 
  := ...
noncomputable
def ℛ (f : X → Y) (x : X) : Y×(Y→X) := (f x, (∂ f x)†) 

@[simp]
theorem rd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ℛ (λ x => f (g x)) 
    = 
    λ x => 
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))
  := ...

Difficulty with traditional automatic differentiation

Two routes from \frac{\partial e^x}{\partial x} to a computable approximation:

approximation, then differentiation:
\frac{\partial}{\partial x} \sum_{n=0}^{N} \frac{x^n}{n!} = \sum_{n=0}^{N-1} \frac{x^n}{n!}

differentiation, then approximation:
\frac{\partial e^x}{\partial x} = e^x \approx \sum_{n=0}^{N} \frac{x^n}{n!}
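The difference between the two routes can be measured numerically. A plain Lean 4 sketch; `expTaylor`, `approxThenDiff`, and `diffThenApprox` are illustrative names:

```lean
-- Truncated Taylor series of eˣ: Σ_{n=0}^{N} xⁿ / n!
def expTaylor (N : Nat) (x : Float) : Float := Id.run do
  let mut term := 1.0   -- current term xⁿ / n!, starting at n = 0
  let mut sum := 0.0
  for n in [0:N+1] do
    sum := sum + term
    term := term * x / (n + 1).toFloat
  return sum

-- approximate, then differentiate: ∂/∂x Σ_{n=0}^{N} xⁿ/n! = Σ_{n=0}^{N-1} xⁿ/n!  (assumes N ≥ 1)
def approxThenDiff (N : Nat) (x : Float) : Float := expTaylor (N - 1) x

-- differentiate, then approximate: ∂eˣ/∂x = eˣ ≈ Σ_{n=0}^{N} xⁿ/n!
def diffThenApprox (N : Nat) (x : Float) : Float := expTaylor N x

-- errors against the exact derivative e¹; roughly (0.0099, 0.0016)
#eval (Float.exp 1.0 - approxThenDiff 5 1.0, Float.exp 1.0 - diffThenApprox 5 1.0)
```

With the same N, differentiating the approximation loses one term and is therefore noticeably less accurate.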

Two routes from \frac{\partial (A^{-1}y)}{\partial y} \, y', where A = D + L + U:

Jacobi approximation of A^{-1} y:

fun y => Id.run do
  let mut x : ℝ^n := 0
  while true do
    x := D⁻¹ (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x

approximation, then differentiation \left(\frac{\partial \cdot}{\partial y} y'\right):

fun y y' => Id.run do
  let mut x  : ℝ^n := 0
  let mut x' : ℝ^n := 0
  while true do
    x  := D⁻¹ (y  - (L + U) x)
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x - y‖ < ε then      -- the stopping test only checks x, not x'
      break
  return x'

differentiation, then approximation of A^{-1} y':

fun y' => Id.run do
  let mut x' : ℝ^n := 0
  while true do
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x' - y'‖ < ε then
      break
  return x'
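The two routes end up close because of linearity. A short derivation (a worked step, not from the slides):

```latex
% y \mapsto A^{-1} y is linear, so its differential in direction y' is
\frac{\partial (A^{-1} y)}{\partial y} \, y' = A^{-1} y'
% differentiating the Jacobi update in direction y', writing {x'}^{(n)} for the derivative of x^{(n)}:
x^{(n+1)} = D^{-1}\left( y - (L + U)\, x^{(n)} \right)
\quad\Longrightarrow\quad
{x'}^{(n+1)} = D^{-1}\left( y' - (L + U)\, {x'}^{(n)} \right)
```

So the differentiated loop runs the same Jacobi iteration on y'. The difference lies in the stopping test: the differentiated program checks ‖A x - y‖, that is, convergence of the primal x, not of x'.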

Example: Ballistic motion

\begin{aligned} \dot x &= v \\ \dot v &= g - \gamma f(v) \\ x(0) &= 0 \qquad v(0) = v_0 \end{aligned}

For T ∈ ℝ and x_T ∈ ℝ², find v₀ such that x(T) = x_T.

This can be solved as a minimization problem:

   v_0 = \operatorname*{arg\,min}_{\bar v_0} \|x(T) - x_T\|^2 \qquad \text{where } \dot x(0) = \bar v_0

Problem:

To compute the gradient of ‖x(T) - x_T‖² with respect to v₀, we have to solve the adjoint problem backwards in time, from t = T to t = 0:

\begin{aligned} \dot x^* &= 0 \\ \dot v^* &= -x^* + \gamma f'(v)^\top v^* \\ x^*(T) &= x(T) - x_T \qquad v^*(T) = 0 \end{aligned}

\nabla_{v_0} \|x(T) - x_T\|^2 = 2\, v^*(0)
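The drag-free case γ = 0 offers a quick consistency check for the adjoint system, since everything can be computed by hand:

```latex
% no drag (\gamma = 0): x(T) = v_0 T + \tfrac{1}{2} g T^2, so \partial x(T) / \partial v_0 = T\, I and
\nabla_{v_0} \|x(T) - x_T\|^2 = 2\, T \left( x(T) - x_T \right)
% hence a correct adjoint solution must satisfy v^*(0) = T \left( x(T) - x_T \right)
```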

\begin{aligned} \dot x &= v \\ \dot v &= g - (5+\|v\|) v \\ x(0) &= 0 \qquad v(0) = v_0 \end{aligned}


def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)   -- g : gravity vector, defined elsewhere

approx aimToTarget (T : ℝ) (target : ℝ×ℝ) (init_v : ℝ×ℝ) (optimizationRate : ℝ) :=
    let shoot := λ (v : ℝ×ℝ) =>
                   odeSolve (t₀ := 0) (x₀ := (0,v)) (t := T)
                     (f := λ (t : ℝ) (x,v) => ballisticMotion x v) |>.fst
    shoot⁻¹ target
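For experimentation without SciLean, the shooting problem can be sketched in plain Lean 4. This sketch swaps in simpler techniques than the adjoint method on the slides: explicit Euler integration, central finite differences for the gradient, and plain gradient descent on ‖x(T) - x_T‖². The gravity vector, step size, iteration count, and initial guess are untuned assumptions, and all names (`V2`, `shoot`, `loss`, `lossGrad`, `demo`) are hypothetical:

```lean
-- 2D vectors as pairs of floats (assumed representation)
abbrev V2 := Float × Float

def V2.add (a b : V2) : V2 := (a.1 + b.1, a.2 + b.2)
def V2.smul (s : Float) (a : V2) : V2 := (s * a.1, s * a.2)
def V2.norm (a : V2) : Float := Float.sqrt (a.1 * a.1 + a.2 * a.2)

-- assumed gravity vector; the slide leaves `g` unspecified
def gravity : V2 := (0.0, -9.81)

-- right-hand side from the slide: ẋ = v, v̇ = g - (5 + ‖v‖) v
def ballisticMotion (_x v : V2) : V2 × V2 :=
  (v, V2.add gravity (V2.smul (-(5.0 + V2.norm v)) v))

-- explicit Euler from x(0) = 0, v(0) = v₀ up to time T; returns x(T)
def shoot (T : Float) (steps : Nat) (v₀ : V2) : V2 := Id.run do
  let dt := T / steps.toFloat
  let mut x : V2 := (0.0, 0.0)
  let mut v : V2 := v₀
  for _ in [0:steps] do
    let (dx, dv) := ballisticMotion x v
    x := V2.add x (V2.smul dt dx)
    v := V2.add v (V2.smul dt dv)
  return x

-- objective ‖x(T) - x_T‖²
def loss (T : Float) (target v₀ : V2) : Float :=
  let d := V2.add (shoot T 100 v₀) (V2.smul (-1.0) target)
  d.1 * d.1 + d.2 * d.2

-- gradient by central finite differences (stands in for the adjoint method)
def lossGrad (T : Float) (target v₀ : V2) : V2 :=
  let h := 1e-4
  let f := loss T target
  ((f (v₀.1 + h, v₀.2) - f (v₀.1 - h, v₀.2)) / (2 * h),
   (f (v₀.1, v₀.2 + h) - f (v₀.1, v₀.2 - h)) / (2 * h))

-- plain gradient descent on v₀
def aimToTarget (T : Float) (target initV : V2) (rate : Float) (iters : Nat) : V2 :=
  Id.run do
    let mut v := initV
    for _ in [0:iters] do
      v := V2.add v (V2.smul (-rate) (lossGrad T target v))
    return v

def demo : V2 × Float :=
  let v₀ := aimToTarget 1.0 (1.0, 0.5) (5.0, 10.0) 20.0 500
  (v₀, loss 1.0 (1.0, 0.5) v₀)

#eval demo   -- the v₀ reached after 500 steps and the remaining loss
```

Finite differences cost two extra simulations per component of v₀. The adjoint approach on the slides obtains the whole gradient from a single backward solve.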

Summary

  • SciLean: an experimental library for scientific computing in Lean 4
    https://github.com/lecopivo/SciLean
  • decoupling of what (specification) from how (implementation)
    • better code readability/documentation
    • better composability and faster development
    • reduction of bugs
  • interactive "computer algebra system" working on source code
  • the aim is to help users with the mathematics
  • provide formal guarantees when possible
  • current focus is on automatic and symbolic differentiation
    • imperative code (for loops, mutable variables, etc.)
    • variational calculus and differential geometry
    • formally verified
  • probabilistic programming at some point in the future