Tomáš Skřivan
Carnegie Mellon University
Why Scientific Computing in Lean?
SciLean
library for scientific computing in Lean 4
Motivation: Physics Simulation in Computer Graphics
work done at
Motivation
Sources of frustration:
This video is from the SIGGRAPH 2022 Technical Paper: ‘Loki: a unified multiphysics simulation framework for production’.
Scientific Computing and Lean
What can Lean offer?
Example: Harmonic Oscillator
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ∥p∥² + k/2 * ∥x∥²
approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                             -∇ (x':=x), H m k x' p))
  by
    -- Unfold Hamiltonian and compute gradients
    unfold H
    symdiff; symdiff
    -- Apply RK4 method
    rw [odeSolve_fixed_dt runge_kutta4_step]
    approx_limit steps; simp; intro steps'
Goal at the start of the proof:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                                  -∇ (x':=x), H m k x' p)))
Goal after unfold H; symdiff; symdiff:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))
Goal after rw [odeSolve_fixed_dt runge_kutta4_step] and approx_limit steps; simp; intro steps':
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
    fun t x => (1 / m * x.snd, -(k * x.fst)))
(x, p) := solver m k substeps t (x, p) Δt
\(x(t) = \texttt{odeSolve f t₀ x₀ t} \quad \Leftrightarrow \quad \dot x(t) = f(t,x(t)),\ x(t_0) = x_0\)
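For illustration only, here is a sketch in plain Lean 4 (core Float, no SciLean or Mathlib) of what the proof above reduces to: a fixed-step classical Runge-Kutta 4 integrator applied to the right-hand side ẋ = p/m, ṗ = -k x that symdiff produces. All names (PhaseState, oscillator, addS, smulS, rk4Step, odeSolveFixed) are hypothetical; they are not SciLean's odeSolve_fixed_dt or runge_kutta4_step.

```lean
/-- Phase-space state (x, p) of a 1D oscillator (hypothetical helper type). -/
abbrev PhaseState := Float × Float

/-- Right-hand side obtained after `symdiff`: ẋ = p/m, ṗ = -k x. -/
def oscillator (m k : Float) (_t : Float) (s : PhaseState) : PhaseState :=
  (1 / m * s.2, -(k * s.1))

-- Componentwise addition and scaling of states.
def addS (a b : PhaseState) : PhaseState := (a.1 + b.1, a.2 + b.2)
def smulS (c : Float) (a : PhaseState) : PhaseState := (c * a.1, c * a.2)

/-- One classical RK4 step of size `dt`: s + dt/6 (k1 + 2 k2 + 2 k3 + k4). -/
def rk4Step (f : Float → PhaseState → PhaseState) (t dt : Float) (s : PhaseState) :
    PhaseState :=
  let k1 := f t s
  let k2 := f (t + dt / 2) (addS s (smulS (dt / 2) k1))
  let k3 := f (t + dt / 2) (addS s (smulS (dt / 2) k2))
  let k4 := f (t + dt) (addS s (smulS dt k3))
  addS s (smulS (dt / 6) (addS (addS k1 (smulS 2 k2)) (addS (smulS 2 k3) k4)))

/-- Integrate from `t₀` to `t` using `steps` equal RK4 steps. -/
def odeSolveFixed (f : Float → PhaseState → PhaseState) (steps : Nat)
    (t₀ : Float) (s₀ : PhaseState) (t : Float) : PhaseState := Id.run do
  let dt := (t - t₀) / steps.toFloat
  let mut s := s₀
  for i in [0:steps] do
    s := rk4Step f (t₀ + i.toFloat * dt) dt s
  return s

#eval odeSolveFixed (oscillator 1 1) 100 0 (1, 0) 3.14159
```

With m = k = 1, x(0) = 1 and p(0) = 0, the exact solution is x(t) = cos t, p(t) = -sin t, so the state printed at t ≈ π should be close to (-1, 0).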
Two Main Operations of AD: Differential and Adjoint
Differential \(\partial : (X \rightarrow Y) \rightarrow (X \rightarrow X \rightarrow Y)\)
Adjoint (of a linear map) \(\dagger : (X \rightarrow Y) \rightarrow (Y \rightarrow X)\)
[Diagram: \(\partial\) turns \(f : X \rightarrow Y\) into \(\partial f\), which takes \(x : X\) and \(dx : X\) and returns \(dy : Y\); \(\dagger\) turns \(f : X \rightarrow Y\) into \(f^\dagger : Y \rightarrow X\).]
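A small worked example using standard calculus and linear algebra (the functions are chosen here for illustration, not taken from the talk). The adjoint is characterized by \(\langle f\,x, y\rangle = \langle x, f^\dagger y\rangle\), and reverse mode combines the two operations as \((\partial f\,x)^\dagger\):

```latex
% Nonlinear f(x) = x^2 on X = Y = \mathbb{R}:
\partial f \, x \, dx = 2x \, dx, \qquad (\partial f \, x)^\dagger \, dy = 2x \, dy
% Linear f(x) = A x with A \in \mathbb{R}^{m \times n} and the standard inner product:
\partial f \, x \, dx = A \, dx, \qquad
f^\dagger(y) = A^\top y
\quad \text{because} \quad
\langle A x, y \rangle = \langle x, A^\top y \rangle \ \text{for all } x, y
```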
Simplifier and Typeclass approach
Differential \(\partial\)
Adjoint \(\dagger\)
instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
: IsSmooth (λ x => f (g x)) := ...
@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
: ∂ (λ x => f (g x))
=
λ x dx => ∂ f (g x) (∂ g x dx) := ...
instance comp_has_adjoint (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
: IsLin (λ x => f (g x)) := ...
@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
: (λ x => f (g x))†
=
λ z => g† (f† z) := ...
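Why adj_of_comp holds: apply the defining property of the adjoint, \(\langle f\,y, z\rangle = \langle y, f^\dagger z\rangle\), twice (a standard linear-algebra argument, not taken from the slides). Because both rules are tagged @[simp], the simplifier roughly rewrites a composite \(\partial\) or \(\dagger\) term left to right into terms about the pieces, while typeclass resolution discharges the IsSmooth / IsLin side conditions through the instances.

```latex
% For linear f : Y \to Z, g : X \to Y and all x : X, z : Z:
\langle f(g\,x), z \rangle
  = \langle g\,x, f^\dagger z \rangle
  = \langle x, g^\dagger (f^\dagger z) \rangle
\quad \Longrightarrow \quad
(f \circ g)^\dagger = g^\dagger \circ f^\dagger
```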
Other operators
Forward Mode AD and Functoriality
[Diagram: how (“?”) to build \(\partial (f\circ g)\), which takes \(x : X\) and \(dx : X\) and returns \(dz : Z\), from \(\partial g\) (\(x, dx \mapsto dy\)) and \(\partial f\) (\(y, dy \mapsto dz\)), given that \(\partial f\) also needs the value \(y : Y\).]
𝒯 (λ x => f (g x))
=
λ x dx =>
let (y,dy) := 𝒯 g x dx
let (z,dz) := 𝒯 f y dy
(z,dz)
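A minimal executable sketch of this composition rule for scalar functions, in plain Lean 4 with Float. It assumes 𝒯 f x dx = (f x, ∂ f x dx), the forward-mode counterpart of the ℛ definition used for reverse mode; TangentMap, tangentComp, tangentSq and tangentSin are hypothetical names, not SciLean API.

```lean
/-- Tangent map 𝒯 f of a scalar function: x, dx ↦ (f x, ∂ f x dx). -/
abbrev TangentMap := Float → Float → Float × Float

/-- Composition rule for 𝒯: push (x, dx) through 𝒯 g, then through 𝒯 f. -/
def tangentComp (Tf Tg : TangentMap) : TangentMap :=
  fun x dx =>
    let (y, dy) := Tg x dx
    let (z, dz) := Tf y dy
    (z, dz)

-- Hand-written tangent maps for two primitives.
def tangentSq : TangentMap := fun x dx => (x * x, 2 * x * dx)
def tangentSin : TangentMap := fun x dx => (Float.sin x, Float.cos x * dx)

-- 𝒯 (λ x => sin (x * x)) evaluated at x = 1.5 with dx = 1.
#eval tangentComp tangentSin tangentSq 1.5 1.0
```

The result pairs sin 2.25 with 2 · 1.5 · cos 2.25, the derivative of sin(x²) at x = 1.5. Value and derivative travel together, which is what makes 𝒯 compose without recomputing g x.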
[Diagram: \(\mathcal{T} (f\circ g)\) takes \(x : X\) and \(dx : X\) and returns \(z : Z\) and \(dz : Z\); it is the composite of \(\mathcal{T} g\) (\(x, dx \mapsto y, dy\)) and \(\mathcal{T} f\) (\(y, dy \mapsto z, dz\)).]
Reverse Mode AD and Functoriality
[Diagram: how (“?”) to build \(\partial^\dagger (f\circ g)\), which takes \(x : X\) and \(dz : Z\) and returns \(dx : X\), from \(\partial^\dagger f\) (\(y, dz \mapsto dy\)) and \(\partial^\dagger g\) (\(x, dy \mapsto dx\)), given that \(\partial^\dagger f\) needs \(y = g\,x\) before the backward pass can start.]
[Diagram: \(\mathcal{R} (f\circ g)\) takes \(x : X\) and returns \(z : Z\) together with \(\partial^\dagger (f\circ g) \, x : Z \rightarrow X\).]
ℛ (λ x => f (g x))
=
λ x =>
let (y,dg') := ℛ g x
let (z,df') := ℛ f y
(z, λ dz => dg' (df' dz))
[Diagram: \(\mathcal{R} g\) maps \(x : X\) to \(y : Y\) and \(\partial^\dagger g \, x : Y \rightarrow X\); \(\mathcal{R} f\) maps \(y : Y\) to \(z : Z\) and \(\partial^\dagger f \, y : Z \rightarrow Y\).]
\(\mathcal{R}\texttt{ f x = (f x, λ dy => (∂ f x)† dy)}\)
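A minimal executable sketch of the ℛ composition rule for scalar functions, in plain Lean 4 with Float. ReverseMap, reverseComp, reverseSq, reverseSin and reverseDemo are hypothetical names, not SciLean API. For scalars, the adjoint of the linear map dx ↦ c · dx is dy ↦ c · dy, so each hand-written pullback multiplies by the local derivative.

```lean
/-- ℛ f of a scalar function: x ↦ (f x, pullback dy ↦ (∂ f x)† dy). -/
abbrev ReverseMap := Float → Float × (Float → Float)

/-- Composition rule: forward pass through g then f, pullbacks chained in reverse. -/
def reverseComp (Rf Rg : ReverseMap) : ReverseMap :=
  fun x =>
    let (y, dg') := Rg x
    let (z, df') := Rf y
    (z, fun dz => dg' (df' dz))

-- Hand-written reverse maps for two primitives.
def reverseSq : ReverseMap := fun x => (x * x, fun dy => 2 * x * dy)
def reverseSin : ReverseMap := fun x => (Float.sin x, fun dy => Float.cos x * dy)

-- ℛ (λ x => sin (x * x)) at x = 1.5, with the pullback applied to dz = 1.
def reverseDemo : Float × Float :=
  let (z, pullback) := reverseComp reverseSin reverseSq 1.5
  (z, pullback 1.0)

#eval reverseDemo
```

reverseDemo returns sin 2.25 together with 2 · 1.5 · cos 2.25. The forward values y and z are computed once, and the returned closure captures them for the backward pass.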
Example: Ballistic motion
For \(T\in \mathbb{R}\) and \(x_T \in \mathbb{R}^2\), find \(v_0\) such that \(x(T) = x_T\).
Any \(v_0\) that hits the target also minimizes the squared miss distance, so we solve the necessary condition
\(v_0 = \operatorname*{argmin}_{v_0^*} \|x(T) - x_T\|^2 \qquad \text{where } x(0) = 0,\ \dot x(0) = v_0^*\)
Problem:
To compute \(\frac{\partial x(T)}{\partial v_0}\), we have to solve the adjoint problem.
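Why the adjoint appears: by the chain rule, the gradient of the squared miss distance applies the adjoint of \(\frac{\partial x(T)}{\partial v_0}\) to the residual (a standard identity, written here in the notation of the AD slides). In reverse mode this is the pullback of \(\mathcal{R}\) applied to \(v_0 \mapsto x(T)\), evaluated at \(2\,(x(T) - x_T)\).

```latex
% Objective L(v_0) = \|x(T) - x_T\|^2, where x(T) depends on v_0 through the ODE
\nabla_{v_0} L = 2 \left( \frac{\partial x(T)}{\partial v_0} \right)^{\dagger} \big( x(T) - x_T \big)
```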
-- g : gravitational acceleration, not defined on this slide
def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)

approx aimToTarget (init_v : ℝ×ℝ) (optimizationRate : ℝ) :=
  λ (T : ℝ) (target : ℝ×ℝ) =>
    let shoot := λ (t : ℝ) (v : ℝ×ℝ) =>
      odeSolve (t₀ := 0) (x₀ := ((0:ℝ×ℝ),v)) (t := t)
               (f := λ (t : ℝ) (x,v) => ballisticMotion x v)
    (λ v => (shoot T v).1)⁻¹ target
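A self-contained sketch in plain Lean 4 (Float, no SciLean) of the same pipeline: explicit Euler integration of the ballistic ODE from x(0) = 0, v(0) = v₀, followed by gradient descent on ‖x(T) - x_T‖². Several parts are assumptions: gravity is taken to be (0, -9.81) because g is not defined on the slide; the gradient uses central finite differences as a simple stand-in for the adjoint computation described above; and the step count, rate, iteration count and target in demoV are arbitrary. All names are hypothetical, and this is not how SciLean implements aimToTarget.

```lean
abbrev V2 := Float × Float

def V2.add (a b : V2) : V2 := (a.1 + b.1, a.2 + b.2)
def V2.sub (a b : V2) : V2 := (a.1 - b.1, a.2 - b.2)
def V2.smul (c : Float) (a : V2) : V2 := (c * a.1, c * a.2)
def V2.norm (a : V2) : Float := Float.sqrt (a.1 * a.1 + a.2 * a.2)

-- Assumption: gravity points down with magnitude 9.81 (the slide leaves g undefined).
def gravity : V2 := (0, -9.81)

-- Mirrors ballisticMotion: (ẋ, v̇) = (v, g - (5 + ‖v‖) • v).
def ballisticRhs (_x v : V2) : V2 × V2 :=
  (v, V2.sub gravity (V2.smul (5 + V2.norm v) v))

-- Explicit Euler from x(0) = 0, v(0) = v₀ up to time T; returns the position x(T).
def shoot (T : Float) (v₀ : V2) (steps : Nat := 1000) : V2 := Id.run do
  let dt := T / steps.toFloat
  let mut x : V2 := (0, 0)
  let mut v : V2 := v₀
  for _ in [0:steps] do
    let (dx, dv) := ballisticRhs x v
    x := V2.add x (V2.smul dt dx)
    v := V2.add v (V2.smul dt dv)
  return x

-- Squared miss distance ‖x(T) - target‖².
def loss (T : Float) (target v : V2) : Float :=
  let e := V2.sub (shoot T v) target
  e.1 * e.1 + e.2 * e.2

-- Central finite differences: a simple stand-in for the adjoint (reverse-mode) gradient.
def lossGrad (T : Float) (target v : V2) (h : Float := 1e-4) : V2 :=
  let l := loss T target
  ((l (v.1 + h, v.2) - l (v.1 - h, v.2)) / (2 * h),
   (l (v.1, v.2 + h) - l (v.1, v.2 - h)) / (2 * h))

-- Fixed-rate gradient descent on the initial velocity, starting from initV.
def aimToTarget (initV : V2) (rate : Float) (iters : Nat) (T : Float) (target : V2) : V2 :=
  Id.run do
    let mut v := initV
    for _ in [0:iters] do
      v := V2.sub v (V2.smul rate (lossGrad T target v))
    return v

-- Illustrative parameters only.
def demoV : V2 := aimToTarget (1, 1) 0.5 200 1.0 (0.5, 0)

#eval (demoV, shoot 1.0 demoV)
```

Because the descent uses a fixed rate, convergence is not guaranteed; the rate and iteration count may need tuning for a given T and target.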
SciLean summary