Using Interactive Theorem Prover for Scientific Computing
Tomáš Skřivan
Carnegie Mellon University
Hoskinson Center for Formal Mathematics
The Road to Differentiable and Probabilistic Programming in Fundamental Physics
27.06.2023
Overview
- What is an interactive theorem prover?
- Lean 4 as a programming language for scientific computing
- Harmonic oscillator example
- Approach to automatic differentiation
- Ballistic motion optimization example
Functional programming notation primer
Function application
  f(x,y) = f x y
  f(g(x),y) = f (g x) y
Type ascription
  x : X
  f : X → Y → Z
  X → Y → Z = X → (Y → Z)
            ≅ X×Y → Z
            ≠ (X → Y) → Z
Lambda function
  λ x => x + sin x
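The notation above can be tried directly in Lean 4. The following is a minimal sketch using only core Lean; the names `f` and `fPair` are made up for illustration and do not come from the talk.

```lean
-- `f` takes its arguments one at a time (curried form).
def f (x y : Nat) : Nat := x + 2 * y

#check (f : Nat → Nat → Nat)   -- the same type as Nat → (Nat → Nat)
#check (f 1 : Nat → Nat)       -- partial application

-- Uncurried version: a single argument that is a pair.
def fPair : Nat × Nat → Nat := fun (x, y) => f x y

-- Both forms compute the same value.
example : f 1 2 = fPair (1, 2) := rfl

-- A lambda function applied by juxtaposition.
example : (fun (x : Nat) => x + 1) 2 = 3 := rfl
```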
What is an interactive theorem prover?
An interactive theorem prover:
- lets you state and prove mathematical theorems
- proofs are done via human-computer interaction
Example: Peano numbers
inductive nat where
  | zero : nat
  | succ (n : nat) : nat
open nat -- bring zero and succ into scope for the patterns below
def add (n m : nat) : nat :=
  match m with
  | zero => n -- n + 0 = n
  | succ m' => succ (add n m') -- n + (m + 1) = (n + m) + 1
infix:1000 "+" => add
theorem add_comm (m n : nat) : m + n = n + m := by <proof>
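To give a feel for what an interactive proof looks like, here is a small self-contained sketch in core Lean 4. It proves two simpler facts than `add_comm` (whose proof the slide leaves out). The type name `Peano` and the theorem names are chosen for this illustration only.

```lean
inductive Peano where
  | zero : Peano
  | succ (n : Peano) : Peano

namespace Peano

/-- Addition by recursion on the second argument. -/
def add (n m : Peano) : Peano :=
  match m with
  | .zero => n
  | .succ m' => succ (add n m')

/-- `n + 0 = n` holds by unfolding the definition. -/
theorem add_zero (n : Peano) : add n zero = n := rfl

/-- `0 + n = n` needs induction, because `add` recurses on its second argument. -/
theorem zero_add (n : Peano) : add zero n = n := by
  induction n with
  | zero => rfl
  | succ n ih =>
    -- After unfolding one step of `add`, the induction hypothesis applies.
    show succ (add zero n) = succ n
    rw [ih]

end Peano
```

Each tactic line changes the goal that Lean displays, which is the human-computer interaction referred to above.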
Lean 4 as a programming language for scientific computing
- interactive theorem prover
- functional programming language
- supports imperative style and mutable variables
- custom notation
- ability to talk about mathematical objects
- write programs as mathematical specifications and then interactively apply approximations and code transformations (like optimization and AD)
def jacobi_method (A : ℝ^(n,n)) (x₀ y : ℝ^n) : ℝ^n := Id.run do
  let iD := inv_diag A
  let U := upper_triangular A
  let L := lower_triangular A
  let mut x := x₀
  for i in [0:max_steps] do
    x := iD (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x
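The same imperative style works in plain Lean 4 without any library. The sketch below is an assumption-laden toy, not SciLean code: it hard-codes a 2×2 system with `Float` entries, and the names `jacobi2`, `maxSteps` and `eps` are invented here.

```lean
/-- Jacobi iteration for the 2×2 system
      a11 * x1 + a12 * x2 = y1
      a21 * x1 + a22 * x2 = y2 -/
def jacobi2 (a11 a12 a21 a22 y1 y2 : Float)
    (maxSteps : Nat := 100) (eps : Float := 1e-10) : Float × Float := Id.run do
  let mut x1 := 0.0
  let mut x2 := 0.0
  for _ in [0:maxSteps] do
    -- x := D⁻¹ (y - (L + U) x), written out component by component
    let x1' := (y1 - a12 * x2) / a11
    let x2' := (y2 - a21 * x1) / a22
    x1 := x1'
    x2 := x2'
    -- stop once the residual ‖A x - y‖ is small
    let r1 := a11 * x1 + a12 * x2 - y1
    let r2 := a21 * x1 + a22 * x2 - y2
    if Float.sqrt (r1 * r1 + r2 * r2) < eps then
      break
  return (x1, x2)

-- The exact solution of 4 x1 + x2 = 1, 2 x1 + 3 x2 = 2 is (0.1, 0.6).
#eval jacobi2 4 1 2 3 1 2
```

`Id.run do` is what lets a pure function use `let mut`, `for` and `break`.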
macro " ∂ " x:ident "," b:term : term => `(deriv fun $x => $b)
∂ x, f x  =  deriv fun x => f x
∫ x ∈ Ω, ‖∇ u x‖²  =  integral Ω fun x => norm2 (gradient u x)
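A binder-style macro of the same shape can be tested in core Lean 4. This is only a sketch: `applyTwice` and the `twice` notation are hypothetical stand-ins for `deriv` and `∂`, chosen so that the example can be evaluated.

```lean
/-- Apply `f` two times. -/
def applyTwice (f : Nat → Nat) (x : Nat) : Nat := f (f x)

-- `twice x, body` expands to `applyTwice fun x => body`,
-- in the same way that `∂ x, body` expands to `deriv fun x => body`.
macro "twice " x:ident ", " b:term : term => `(applyTwice fun $x => $b)

-- (1 + 3) + 3 = 7
example : (twice x, x + 3) 1 = 7 := rfl
```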
noncomputable
def deriv {X Y : Type} [Vec X] [Vec Y] (f : X → Y) (x dx : X) : Y :=
  if h : ∃ dy : Y,
      Tendsto (fun h : ℝ => (1/h) • (f (x + h • dx) - f x)) (nhds 0) (nhds dy)
  then Classical.choose h
  else 0

noncomputable
def adjoint {X Y : Type} [Hilbert X] [Hilbert Y] (f : X → Y) : Y → X :=
  if h : ∃ f' : Y → X,
      ∀ x y, IsLin f
             ∧
             ⟪f x, y⟫ = ⟪x, f' y⟫
  then Classical.choose h
  else 0

noncomputable
def odeSolve {X : Type} [Vec X] (f : ℝ → X → X) (t₀ : ℝ) (x₀ : X) (t : ℝ) : X :=
  if h : ∃ x : ℝ → X,
      ∀ t, ⅆ x t = f t (x t)
           ∧
           x t₀ = x₀
  then Classical.choose h t
  else 0
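The pattern used in these definitions (pick a witness with `Classical.choose` if one exists, otherwise return a default) can be reproduced in core Lean 4. The following sketch uses a made-up function `pick` on natural numbers. It is meant to show that such a definition cannot be executed, yet theorems about it can still be proved.

```lean
-- Some natural number satisfying `p` if one exists, otherwise `0`.
open Classical in
noncomputable def pick (p : Nat → Prop) : Nat :=
  if h : ∃ n, p n then Classical.choose h else 0

-- Whenever a witness exists, the chosen number satisfies `p`.
open Classical in
theorem pick_spec (p : Nat → Prop) (h : ∃ n, p n) : p (pick p) := by
  unfold pick
  rw [dif_pos h]
  exact Classical.choose_spec h
```

Rewrite rules for `deriv`, `adjoint` and `odeSolve` play the role of `pick_spec`: they are how a noncomputable specification is turned into something executable.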
Example: Harmonic Oscillator
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ∥p∥² + k/2 * ∥x∥²

approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                             -∇ (x':=x), H m k x' p))
by
  -- Unfold Hamiltonian and compute gradients
  unfold H
  symdiff
  -- Apply RK4 method
  rw [odeSolve_fixed_dt runge_kutta4_step]
  approx_limit steps; simp; intro steps';

Goal at the start of the proof:
  m k : ℝ
  steps : ℕ
  ⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                                     -∇ (x':=x), H m k x' p)))

Goal after unfold H and symdiff:
  m k : ℝ
  steps : ℕ
  ⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))

Goal after the RK4 rewrite and approx_limit:
  m k : ℝ
  steps steps' : ℕ
  ⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
      fun t x => (1 / m * x.snd, -(k * x.fst)))

Usage:
  (x, p) := solver m k substeps t (x, p) Δt
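For comparison, this is roughly what the final goal describes when written by hand in plain Lean 4 with `Float`. It is a sketch under assumptions, not the code SciLean generates: the helper names `vadd`, `vsmul`, `rhs`, `rk4Step` and `solve` are invented here, and the state is fixed to a pair (x, p).

```lean
def vadd (a b : Float × Float) : Float × Float := (a.1 + b.1, a.2 + b.2)
def vsmul (c : Float) (a : Float × Float) : Float × Float := (c * a.1, c * a.2)

/-- Right-hand side after symbolic differentiation: (1/m * p, -(k * x)). -/
def rhs (m k : Float) (s : Float × Float) : Float × Float :=
  (1 / m * s.2, -(k * s.1))

/-- One classical fourth-order Runge-Kutta step for an autonomous system. -/
def rk4Step (f : Float × Float → Float × Float) (dt : Float)
    (s : Float × Float) : Float × Float :=
  let k1 := f s
  let k2 := f (vadd s (vsmul (dt / 2) k1))
  let k3 := f (vadd s (vsmul (dt / 2) k2))
  let k4 := f (vadd s (vsmul dt k3))
  vadd s (vsmul (dt / 6) (vadd (vadd k1 (vsmul 2 k2)) (vadd (vsmul 2 k3) k4)))

/-- Fixed time step integration from time 0 to time t. -/
def solve (m k : Float) (steps : Nat) (s₀ : Float × Float) (t : Float) :
    Float × Float := Id.run do
  let dt := t / steps.toFloat
  let mut s := s₀
  for _ in [0:steps] do
    s := rk4Step (rhs m k) dt s
  return s

-- With m = k = 1 and start (1, 0) the exact solution is (cos t, -sin t),
-- so at t = π the result should be close to (-1, 0).
#eval solve 1 1 100 (1, 0) 3.141592653589793
```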
Two Main Operations of AD: Differential and Adjoint
Differential \(\partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)\)
Adjoint \(\dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X )\)
[Diagram: \(f\) maps \(x:X\) to \(y:Y\); \(\partial f\) maps \(x:X,\ dx:X\) to \(dy:Y\); \(f^\dagger\) maps \(y:Y\) back to \(x:X\).]
Other operators
- derivative \(f'\) for \(f : \mathbb{R} \rightarrow X\)
- \(f'(x) = \texttt{∂ f x 1}\)
- gradient \(\nabla f\) for \(f : X \rightarrow \mathbb{R}\)
- \(\nabla f(x) = \texttt{(∂ f x)† 1}\)
- adjoint differential \(\partial^\dagger\)
- \(\texttt{∂† f x dy = (∂ f x)† dy}\)
- tangent map \(\mathcal{T}\) - forward mode AD
- \(\mathcal{T}\texttt{ f x dx = (f x, ∂ f x dx)}\)
- reverse differential \(\mathcal{R}\) - reverse mode AD
- \(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
- complexification \(f^c\)
- \(f^c : \mathbb{C} \rightarrow \mathbb{C}\) for \(f : \mathbb{R} \rightarrow \mathbb{R}\)
- inverse \(f^{-1}\)
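As a worked example of how these operators relate, take \(f(x) = \|x\|^2\) on a Hilbert space \(X\). The function is chosen here for illustration and does not appear in the talk.

```latex
\begin{aligned}
f(x) &= \|x\|^2 = \langle x, x \rangle \\
\partial f \, x \, dx &= 2 \langle x, dx \rangle \\
(\partial f \, x)^\dagger \, s &= 2 s \, x \qquad (s \in \mathbb{R}),
  \quad \text{since } \langle 2\langle x, dx\rangle, s \rangle_{\mathbb{R}} = \langle dx, 2 s \, x \rangle \\
\nabla f(x) &= (\partial f \, x)^\dagger \, 1 = 2x \\
\mathcal{T} f \, x \, dx &= \left(\|x\|^2,\ 2 \langle x, dx \rangle\right) \\
\mathcal{R} f \, x &= \left(\|x\|^2,\ s \mapsto 2 s \, x\right)
\end{aligned}
```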
Forward Mode AD and Functoriality
[Diagram: \(g\) followed by \(f\) maps \(x:X\) to \(z:Z\), and \(\partial (f\circ g)\) maps \(x:X,\ dx:X\) to \(dz:Z\). How is it assembled (?) from \(\partial g\), which maps \(x:X,\ dx:X\) to \(dy:Y\), and \(\partial f\), which maps \(y:Y,\ dy:Y\) to \(dz:Z\)?]
𝒯 (λ x => f (g x))
=
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)
[Diagram: \(\mathcal{T} (f\circ g)\) maps \(x:X,\ dx:X\) to \(z:Z,\ dz:Z\) by chaining \(\mathcal{T} g\), which maps \(x,\ dx\) to \(y:Y,\ dy:Y\), with \(\mathcal{T} f\), which maps \(y,\ dy\) to \(z,\ dz\).]
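The composition rule for 𝒯 can be checked numerically in plain Lean 4. In this sketch the tangent maps of `sin` and of squaring are written by hand (an assumption; SciLean derives them symbolically), and `FwdFn`, `tSin`, `tSq` and `tComp` are names invented for the example.

```lean
/-- A tangent map takes a point and a direction to (value, directional derivative). -/
abbrev FwdFn := Float → Float → Float × Float

def tSq : FwdFn := fun x dx => (x * x, 2 * x * dx)
def tSin : FwdFn := fun x dx => (Float.sin x, Float.cos x * dx)

/-- 𝒯 (f ∘ g) x dx: push (x, dx) through g, then the result through f. -/
def tComp (tf tg : FwdFn) : FwdFn := fun x dx =>
  let (y, dy) := tg x dx
  tf y dy

-- sin (x²) at x = 1 in direction 1: value sin 1, derivative 2 * cos 1
#eval tComp tSin tSq 1.0 1.0
```

Because each tangent map returns both the value and the derivative, composition needs no extra bookkeeping. This is the functoriality the slide title refers to.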
Reverse Mode AD and Functoriality
[Diagram: \(g\) followed by \(f\) maps \(x:X\) to \(z:Z\), and \(\partial^\dagger (f\circ g)\) maps \(x:X,\ dz:Z\) to \(dx:X\). How is it assembled (?) from \(\partial^\dagger g\), which maps \(x:X,\ dy:Y\) to \(dx:X\), and \(\partial^\dagger f\), which maps \(y:Y,\ dz:Z\) to \(dy:Y\)?]
[Diagram: \(\mathcal{R} (f\circ g)\) maps \(x:X\) to \(z:Z\) together with \(\partial^\dagger (f\circ g) \, x : Z \rightarrow X \).]
ℛ (λ x => f (g x))
=
λ x =>
  let (y,dg') := ℛ g x
  let (z,df') := ℛ f y
  (z, λ dz => dg' (df' dz))
[Diagram: \(\mathcal{R} g\) maps \(x:X\) to \(y:Y\) together with \(\partial^\dagger g \, x : Y \rightarrow X \); \(\mathcal{R} f\) maps \(y:Y\) to \(z:Z\) together with \(\partial^\dagger f \, y : Z \rightarrow Y \).]
\(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
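The reverse-mode composition rule can be exercised the same way. Again this is a hand-written sketch in plain Lean 4 with `Float`. The names `RevFn`, `rSq`, `rSin`, `rComp` and `demo` are hypothetical, and the pullbacks are supplied manually rather than derived.

```lean
/-- A reverse-mode function returns the value and a pullback dz ↦ dx. -/
abbrev RevFn := Float → Float × (Float → Float)

def rSq : RevFn := fun x => (x * x, fun dy => 2 * x * dy)
def rSin : RevFn := fun x => (Float.sin x, fun dy => Float.cos x * dy)

/-- ℛ (f ∘ g): run g, then f, and compose the pullbacks in reverse order. -/
def rComp (rf rg : RevFn) : RevFn := fun x =>
  let (y, dg') := rg x
  let (z, df') := rf y
  (z, fun dz => dg' (df' dz))

/-- Value and derivative of sin (x²) at x = 1, namely (sin 1, 2 * cos 1). -/
def demo : Float × Float :=
  let (z, pullback) := rComp rSin rSq 1.0
  (z, pullback 1.0)

#eval demo
```

The forward pass runs `g` and then `f`, while the returned closure applies the pullback of `f` first and that of `g` second.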
Approach to AD in Lean 4
Derivative \(\partial\)
instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : IsSmooth (λ x => f (g x)) := ...

@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ∂ x, f (g x)
    =
    λ x dx => ∂ f (g x) (∂ g x dx) := ...

Adjoint \(\dagger\)
instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : IsLin (λ x => f (g x)) := ...

@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : (λ x => f (g x))†
    =
    λ z => g† (f† z) := ...
Forward mode \( \mathcal{T} \)
noncomputable
def 𝒯 (f : X → Y) (x dx : X) : Y×Y := (f x, ∂ f x dx)

@[simp]
theorem fd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : 𝒯 (λ x => f (g x))
    =
    λ x dx =>
      let (y,dy) := 𝒯 g x dx
      𝒯 f y dy
  := ...

Reverse mode \(\mathcal{R}\)
noncomputable
def ℛ (f : X → Y) (x : X) : Y×(Y→X) := (f x, (∂ f x)†)

@[simp]
theorem rd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ℛ (λ x => f (g x))
    =
    λ x =>
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))
  := ...
Difficulty with traditional automatic differentiation
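Approximation, then differentiation: \(e^x \approx \sum_{n=0}^{N} \frac{x^n}{n!}\), and \(\frac{\partial}{\partial x} \sum_{n=0}^N \frac{x^n}{n!} = \sum_{n=0}^{N-1} \frac{x^n}{n!}\)
Differentiation, then approximation: \(\frac{\partial e^x}{\partial x} = e^x \approx \sum_{n=0}^{N} \frac{x^n}{n!}\)
The two orders give different results: differentiating the truncated series loses its last term.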
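The mismatch is easy to observe numerically. The sketch below, in plain Lean 4 with `Float`, evaluates the truncated series with N and with N-1 terms and compares both with the library exponential. The function name `expSeries` and the choice N = 5, x = 1 are arbitrary.

```lean
/-- Truncated exponential series: the sum of x^n / n! for n = 0, ..., N. -/
def expSeries (N : Nat) (x : Float) : Float := Id.run do
  let mut term := 1.0   -- x^n / n!, starting with n = 0
  let mut sum := 0.0
  for n in [0:N+1] do
    sum := sum + term
    term := term * x / (n + 1).toFloat
  return sum

-- Error of "differentiate, then approximate" (N = 5 terms kept) versus
-- error of "approximate, then differentiate" (one term lost, N = 4).
#eval (Float.exp 1.0 - expSeries 5 1.0, Float.exp 1.0 - expSeries 4 1.0)
```

The second error is the larger one, because differentiating the already truncated series drops its highest-order term.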
Derivative of a linear solve in direction \(y'\): \(\frac{\partial A^{-1}y}{\partial y} \ y' = A^{-1} y'\), where \(A = D + L + U\)

Approximation of \(A^{-1} y\) (Jacobi iteration):
fun y =>
  let mut x := 0
  while (true)
    x := D⁻¹ (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x

Approximation, then differentiation (traditional AD applied to the loop above):
fun y =>
  let mut x := 0
  let mut x' := 0
  while (true)
    x := D⁻¹ (y - (L + U) x)
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x - y‖ < ε then
      break
  return x'

Differentiation, then approximation (Jacobi iteration applied to \( A^{-1} y'\)):
fun y =>
  let mut x' := 0
  while (true)
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x' - y'‖ < ε then
      break
  return x'
Example: Ballistic motion
For \(T\in \mathbb{R}, x_T \in \mathbb{R}^2\) find \(v_0\) such that \(x(T) = x_T\)
This can be posed as a minimization problem
\(v_0 = \operatorname{argmin}_{\bar v_0} \|x(T) - x_T\|^2 \qquad \text{where } \dot x(0) = \bar v_0\)
Problem:
To compute \(\frac{\partial x(T)}{\partial v_0}\) we have to solve the adjoint problem
def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)

approx aimToTarget (T : ℝ) (target : ℝ×ℝ) (init_v : ℝ×ℝ) (optimizationRate : ℝ) :=
  let shoot := λ (v : ℝ×ℝ) =>
    odeSolve (t₀ := 0) (x₀ := (0,v)) (t := T)
      (f := λ (t : ℝ) (x,v) => ballisticMotion x v) |>.fst
  shoot⁻¹ target
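One plausible way to approximate the inverse is gradient descent on the squared distance to the target. The parameter name `optimizationRate` suggests this reading, but the slides do not spell out the method, so the derivation below is an assumption; \(\eta\) stands for the optimization rate.

```latex
\begin{aligned}
L(\bar v_0) &= \|x(T) - x_T\|^2, \qquad \dot x(0) = \bar v_0 \\
\nabla L(\bar v_0) &= 2 \left(\frac{\partial x(T)}{\partial \bar v_0}\right)^{\dagger} \bigl(x(T) - x_T\bigr) \\
\bar v_0^{(k+1)} &= \bar v_0^{(k)} - \eta \, \nabla L\bigl(\bar v_0^{(k)}\bigr)
\end{aligned}
```

The adjoint of \(\frac{\partial x(T)}{\partial \bar v_0}\) in the second line is where the adjoint problem for the ODE enters.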
Summary
- SciLean - experimental library for scientific computing in Lean 4
  https://github.com/lecopivo/SciLean
- decoupling of what (specification) from how (implementation)
- better code readability/documentation
- better composability and faster development
- reduction of bugs
- interactive "computer algebra system" working on source code
- aim is to help users with the mathematics
- provide formal guarantees when possible
- current focus is on automatic and symbolic differentiation
- imperative code - for loops, mutable variables etc.
- variational calculus and differential geometry
- formally verified
- probabilistic programming at some point in the future