Tomáš Skřivan
Carnegie Mellon University
Hoskinson Center for Formal Mathematics
The Road to Differentiable and Probabilistic Programming in Fundamental Physics
27.06.2023
Overview
Functional programming notation primer
Function application
f(x,y) = f x y
f(g(x),y) = f (g x) y
Type ascription
x : X
f : X → Y → Z
X → Y → Z = X → (Y → Z) ≠ (X → Y) → Z
X → Y → Z ≅ X×Y → Z
Lambda function
λ x => x + sin x
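The notation above is ordinary Lean 4 syntax. A small sketch that runs in plain Lean 4 without extra libraries (the names `add3`, `addTo1`, `add3'` and `sq` are made up for illustration):

```lean
-- Curried two-argument function: `add3 : Nat → Nat → Nat`, i.e. `Nat → (Nat → Nat)`
def add3 (x y : Nat) : Nat := x + 3 * y

-- Application is juxtaposition: `add3 1 2` parses as `(add3 1) 2`
#eval add3 1 2                -- 7

-- Partial application yields a function `Nat → Nat`
def addTo1 : Nat → Nat := add3 1

-- The isomorphism X → Y → Z ≅ X×Y → Z, via `Function.uncurry`
def add3' : Nat × Nat → Nat := Function.uncurry add3

-- Lambda function and nested application `f (g x) y`
def sq : Nat → Nat := fun x => x * x
#eval add3 (sq 2) 1           -- 7

-- Type ascription on a lambda over floating-point numbers
#eval (fun (x : Float) => x + Float.sin x) 0.0
```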
What is an interactive theorem prover?
Interactive theorem prover:
inductive nat where
  | zero : nat
  | succ (n : nat) : nat
open nat                               -- lets `zero` and `succ` be used unqualified
def add (n m : nat) : nat :=
  match m with
  | zero    => n                       -- n + 0 = n
  | succ m' => succ (add n m')         -- n + (m + 1) = (n + m) + 1
instance : Add nat := ⟨add⟩            -- `m + n` now means `add m n`
theorem add_comm (m n : nat) : m + n = n + m := by sorry  -- proof omitted
Example: Peano numbers
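The proof of `add_comm` is omitted on the slide; such proofs go by induction. As a sketch of what that looks like, here are two helper lemmas of the kind the full proof needs, on a self-contained copy of the definitions (wrapped in a hypothetical `Peano` namespace). The complete commutativity proof typically also needs a `succ_add` lemma, not shown.

```lean
namespace Peano

inductive nat where
  | zero : nat
  | succ (n : nat) : nat

open nat

def add (n m : nat) : nat :=
  match m with
  | zero    => n
  | succ m' => succ (add n m')

-- n + 0 = n holds by definition: it is the first match arm
theorem add_zero (n : nat) : add n zero = n := rfl

-- 0 + n = n is not definitional and needs induction on n
theorem zero_add (n : nat) : add zero n = n := by
  induction n with
  | zero => rfl
  | succ n ih =>
    -- goal `add zero (succ n) = succ n` unfolds to `succ (add zero n) = succ n`
    exact congrArg succ ih

end Peano
```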
Lean 4 as a programming language for scientific computing
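def jacobi_method (A : ℝ^(n,n)) (x₀ y : ℝ^n) : ℝ^n := Id.run do
  let iD := inv_diag A              -- D⁻¹, inverse of the diagonal of A
  let U := upper_triangular A       -- strictly upper triangular part of A
  let L := lower_triangular A       -- strictly lower triangular part of A
  let mut x := x₀
  for _ in [0:max_steps] do         -- `max_steps` and `ε` are assumed to be defined elsewhere
    x := iD (y - (L + U) x)         -- Jacobi update x ← D⁻¹ (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x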
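The code above relies on SciLean types and helpers (`ℝ^n`, `inv_diag`, ...). A rough plain-Lean analogue over `Float`, with matrices stored as arrays of rows, performs the same iteration; `matVec`, `residualNorm` and `jacobi` are hypothetical names, not SciLean API:

```lean
-- Dense matrix-vector product; `A` is an array of rows
def matVec (A : Array (Array Float)) (x : Array Float) : Array Float :=
  A.map fun row => (List.range row.size).foldl (fun s j => s + row[j]! * x[j]!) 0.0

-- Euclidean norm of the residual A x - y
def residualNorm (A : Array (Array Float)) (x y : Array Float) : Float :=
  let r := matVec A x
  Float.sqrt <| (List.range y.size).foldl
    (fun s i => s + (r[i]! - y[i]!) * (r[i]! - y[i]!)) 0.0

-- Jacobi iteration: x_i ← (y_i - Σ_{j≠i} A_ij x_j) / A_ii, i.e. x ← D⁻¹ (y - (L + U) x)
def jacobi (A : Array (Array Float)) (x₀ y : Array Float)
    (maxSteps : Nat) (ε : Float) : Array Float := Id.run do
  let n := y.size
  let mut x := x₀
  for _ in [0:maxSteps] do
    let xOld := x   -- every component is updated from the previous iterate only
    x := Array.ofFn fun (i : Fin n) =>
      let off := (List.range n).foldl
        (fun s j => if j == i.val then s else s + A[i.val]![j]! * xOld[j]!) 0.0
      (y[i.val]! - off) / A[i.val]![i.val]!
    if residualNorm A x y < ε then
      break
  return x
```

For example, `jacobi #[#[4, 1], #[2, 5]] #[0, 0] #[1, 2] 100 1e-10` approaches the exact solution (1/6, 1/3); convergence is guaranteed here because the matrix is strictly diagonally dominant.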
macro " ∂ " x:ident "," b:term : term => `(deriv fun $x => $b)
deriv fun x => f x  =  ∂ x, f x
integral Ω fun x => norm2 (gradient u x)  =  ∫ x ∈ Ω, ‖∇ u x‖²
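To see the `∂` macro in action without Mathlib, one can pair it with a purely numerical stand-in for `deriv`. The central finite difference below is an assumption made for illustration only; it is not how `deriv` is defined in these slides, where it is a noncomputable limit.

```lean
-- Hypothetical numerical stand-in for `deriv` on Float → Float:
-- central finite difference with step h = 1e-6
def deriv (f : Float → Float) (x : Float) : Float :=
  let h : Float := 1e-6
  (f (x + h) - f (x - h)) / (2 * h)

-- The notation from the slide: `∂ x, b` expands to `deriv fun x => b`
macro " ∂ " x:ident "," b:term : term => `(deriv fun $x => $b)

-- d/dx (x * x) at x = 3 is 6; this prints a value close to 6
#eval (∂ x, x * x) 3.0
```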
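noncomputable
def deriv {X Y : Type} [Vec X] [Vec Y] (f : X → Y) (x dx : X) : Y :=
  -- limit of (f (x + t • dx) - f x) / t as t → 0 with t ≠ 0; the punctured
  -- neighbourhood is needed because at t = 0 the quotient is 0 (1/0 = 0 in Lean)
  if h : ∃ dy : Y,
      Tendsto (fun t : ℝ => (1/t) • (f (x + t • dx) - f x)) (nhdsWithin 0 {0}ᶜ) (nhds dy)
  then Classical.choose h
  else 0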
noncomputable
def adjoint {X Y : Type} [Hilbert X] [Hilbert Y] (f : X → Y) : Y → X :=
  -- for linear f, f† is the unique f' with ⟪f x, y⟫ = ⟪x, f' y⟫ for all x, y
  if h : ∃ f' : Y → X,
      IsLin f ∧ ∀ x y, ⟪f x, y⟫ = ⟪x, f' y⟫
  then Classical.choose h
  else 0
noncomputable
def odeSolve {X : Type} [Vec X] (f : ℝ → X → X) (t₀ : ℝ) (x₀ : X) (t : ℝ) : X :=
  -- value at time t of a curve x with ⅆ x s = f s (x s) for all s and x t₀ = x₀
  if h : ∃ x : ℝ → X,
      (∀ s, ⅆ x s = f s (x s)) ∧ x t₀ = x₀
  then Classical.choose h t
  else 0
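The same specification-first pattern can be tried in plain Lean 4: define a value as "some witness, if one exists" via `Classical.choose`, and recover its defining property with `Classical.choose_spec`. A toy analogue (the names `someSqrt` and `someSqrt_spec` are hypothetical):

```lean
-- Some m with m * m = n if such an m exists, otherwise 0.
-- `open Classical in` makes the `if` on an undecidable proposition elaborate.
open Classical in
noncomputable def someSqrt (n : Nat) : Nat :=
  if h : ∃ m : Nat, m * m = n then Classical.choose h else 0

-- Whenever a witness exists, the chosen value satisfies the specification
theorem someSqrt_spec (n : Nat) (h : ∃ m : Nat, m * m = n) :
    someSqrt n * someSqrt n = n := by
  unfold someSqrt
  rw [dif_pos h]
  exact Classical.choose_spec h
```

As with `deriv`, `adjoint` and `odeSolve`, nothing here can be evaluated; the definition only fixes what the value means, and a computable implementation has to be supplied separately.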
Example: Harmonic Oscillator
def H (m k : ℝ) (x p : ℝ) := (1/(2*m)) * ‖p‖² + k/2 * ‖x‖²
approx solver (m k : ℝ) (steps : Nat)
  := odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                             -∇ (x':=x), H m k x' p))
  by
    -- Unfold Hamiltonian and compute gradients
    unfold H
    symdiff
    -- Apply RK4 method
    rw [odeSolve_fixed_dt runge_kutta4_step]
    approx_limit steps; simp; intro steps'
Goal before the tactic block:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve (λ t (x,p) => ( ∇ (p':=p), H m k x p',
                                   -∇ (x':=x), H m k x' p)))
Goal after unfold H; symdiff:
m k : ℝ
steps : ℕ
⊢ Approx (odeSolve fun t x => (1 / m * x.snd, -(k * x.fst)))
Goal after rw [odeSolve_fixed_dt runge_kutta4_step] and approx_limit steps; simp; intro steps':
m k : ℝ
steps steps' : ℕ
⊢ Approx (odeSolve_fixed_dt_impl steps' runge_kutta4_step
           fun t x => (1 / m * x.snd, -(k * x.fst)))
Calling the solver inside a simulation loop:
(x, p) := solver m k substeps t (x, p) Δt
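For comparison, here is a hand-written plain-Lean version of what the derivation arrives at: the right-hand side (p/m, -k x) integrated with fixed-step classical Runge-Kutta 4 over `Float`. The names `oscRhs`, `rk4Step` and `solveOsc` are hypothetical; SciLean's `runge_kutta4_step` and `odeSolve_fixed_dt_impl` may well have different signatures.

```lean
-- Right-hand side obtained above: ẋ = p/m, ṗ = -k x, state s = (x, p)
def oscRhs (m k : Float) (s : Float × Float) : Float × Float :=
  (s.2 / m, -(k * s.1))

-- One classical RK4 step of size Δt for an autonomous ODE on Float × Float
def rk4Step (f : Float × Float → Float × Float) (Δt : Float) (s : Float × Float) :
    Float × Float :=
  let add (a b : Float × Float) := (a.1 + b.1, a.2 + b.2)
  let scale (c : Float) (a : Float × Float) := (c * a.1, c * a.2)
  let k1 := f s
  let k2 := f (add s (scale (Δt / 2) k1))
  let k3 := f (add s (scale (Δt / 2) k2))
  let k4 := f (add s (scale Δt k3))
  add s (scale (Δt / 6) (add (add k1 (scale 2 k2)) (add (scale 2 k3) k4)))

-- Integrate from time 0 to T using `steps` equal RK4 steps
def solveOsc (m k : Float) (steps : Nat) (T : Float) (s₀ : Float × Float) :
    Float × Float := Id.run do
  let Δt := T / steps.toFloat
  let mut s := s₀
  for _ in [0:steps] do
    s := rk4Step (oscRhs m k) Δt s
  return s
```

With m = k = 1 and initial state (1, 0) the exact solution is x(t) = cos t, p(t) = -sin t, which gives a quick accuracy check of the integrator.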
Two Main Operations of AD: Differential and Adjoint
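Differential ∂ : (X → Y) → (X → X → Y)
Adjoint † : (X → Y) → (Y → X)
[Diagram: f maps x : X to y : Y. The differential ∂f takes a point x : X and a direction dx : X to dy : Y. The adjoint f† runs in the opposite direction, mapping y : Y to x : X.]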
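A concrete instance with both operations worked out by hand, for the hypothetical example f(x₁, x₂) = x₁ x₂ + sin x₁ with `Float` standing in for ℝ. Here ∂f x is linear in dx, and its adjoint is characterized by ⟪∂f x dx, dy⟫ = ⟪dx, (∂f x)† dy⟫:

```lean
-- f : ℝ×ℝ → ℝ, f (x₁, x₂) = x₁ * x₂ + sin x₁
def prodSin (x : Float × Float) : Float := x.1 * x.2 + Float.sin x.1

-- ∂f x dx = (x₂ + cos x₁) * dx₁ + x₁ * dx₂   (linear in dx)
def prodSinDeriv (x dx : Float × Float) : Float :=
  (x.2 + Float.cos x.1) * dx.1 + x.1 * dx.2

-- (∂f x)† dy = ((x₂ + cos x₁) * dy, x₁ * dy) : ℝ×ℝ
def prodSinDerivAdj (x : Float × Float) (dy : Float) : Float × Float :=
  ((x.2 + Float.cos x.1) * dy, x.1 * dy)

-- Standard inner product on ℝ×ℝ
def inner2 (a b : Float × Float) : Float := a.1 * b.1 + a.2 * b.2
```

Checking the adjoint identity, and ∂f against a finite difference, at a few points is a cheap sanity test for hand-derived rules.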
Other operators
Forward Mode AD and Functoriality
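[Diagram: g maps x : X to y : Y and f maps y : Y to z : Z. ∂(f∘g) takes x : X and dx : X to dz : Z; the question is how to build it from ∂g, which takes (x, dx) to dy : Y, and ∂f, which takes (y, dy) to dz : Z.]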
𝒯 (λ x => f (g x))
  =
λ x dx =>
  let (y,dy) := 𝒯 g x dx
  let (z,dz) := 𝒯 f y dy
  (z,dz)
[Diagram: 𝒯(f∘g) takes (x, dx) to (z, dz); it factors as 𝒯g, taking (x, dx) to (y, dy), followed by 𝒯f, taking (y, dy) to (z, dz).]
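The forward-mode composition rule can be tried on scalar functions in plain Lean. In this sketch (hypothetical names, `Float` for ℝ) a "tangent map" is a function sending (x, dx) to (f x, ∂f x dx), and composition threads the intermediate pair (y, dy) through exactly as in the rule above:

```lean
-- A tangent map sends (x, dx) to (f x, ∂f x dx); here on Float → Float
abbrev TangentMap := Float → Float → Float × Float

-- 𝒯 sin and 𝒯 (x ↦ x²), written out by hand
def tSin : TangentMap := fun x dx => (Float.sin x, Float.cos x * dx)
def tSq  : TangentMap := fun x dx => (x * x, 2 * x * dx)

-- Forward-mode chain rule: 𝒯 (f ∘ g) x dx = let (y, dy) := 𝒯 g x dx; 𝒯 f y dy
def tComp (tf tg : TangentMap) : TangentMap := fun x dx =>
  let (y, dy) := tg x dx
  tf y dy
```

For instance, `tComp tSin tSq 1.5 1.0` evaluates to (sin 2.25, 3 · cos 2.25), the value and derivative of sin(x²) at x = 1.5. Note that g and f are each evaluated once, in the same order as the original program.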
Reverse Mode AD and Functoriality
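[Diagram: g maps x : X to y : Y and f maps y : Y to z : Z. ∂†(f∘g) takes x : X and dz : Z to dx : X; the question is how to build it from ∂†g, which takes x and dy : Y to dx : X, and ∂†f, which takes y and dz : Z to dy : Y.]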
[Diagram: ℛ(f∘g) takes x : X to z : Z together with ∂†(f∘g) x : Z → X.]
ℛ (λ x => f (g x))
  =
λ x =>
  let (y,dg') := ℛ g x
  let (z,df') := ℛ f y
  (z, λ dz => dg' (df' dz))
[Diagram: ℛg takes x : X to y : Y together with ∂†g x : Y → X; ℛf takes y : Y to z : Z together with ∂†f y : Z → Y.]
ℛ f x = (f x, λ dy => (∂ f x)† dy)
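The reverse-mode rule can be sketched the same way for scalar functions (hypothetical names, `Float` for ℝ). In one dimension the adjoint of "multiply by a" is again "multiply by a", so each pullback looks like the derivative; what changes is that the pullbacks are composed in the opposite order to the forward evaluation:

```lean
-- A reverse map sends x to (f x, pullback), the pullback being dy ↦ (∂f x)† dy
abbrev RevMap := Float → Float × (Float → Float)

def rSin : RevMap := fun x => (Float.sin x, fun dy => Float.cos x * dy)
def rSq  : RevMap := fun x => (x * x, fun dy => 2 * x * dy)

-- Reverse-mode chain rule: evaluate g then f, compose the pullbacks backwards
def rComp (rf rg : RevMap) : RevMap := fun x =>
  let (y, dg') := rg x
  let (z, df') := rf y
  (z, fun dz => dg' (df' dz))
```

Here `(rComp rSin rSq 1.5).2 1.0` gives 3 · cos 2.25, matching the forward-mode derivative of sin(x²) at 1.5.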
Approach to AD in Lean 4
Derivative ∂
instance comp_is_smooth (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : IsSmooth (λ x => f (g x)) := ...
@[simp]
theorem chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : (∂ x, f (g x))
    =
    λ x dx => ∂ f (g x) (∂ g x dx) := ...
Adjoint †
instance comp_is_lin (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : IsLin (λ x => f (g x)) := ...
@[simp]
theorem adj_of_comp (f : Y → Z) (g : X → Y) [IsLin f] [IsLin g]
  : (λ x => f (g x))†
    =
    λ z => g† (f† z) := ...
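The `adj_of_comp` rule can be checked numerically for linear maps ℝ² → ℝ², where (with the standard inner product) the adjoint of a matrix is its transpose. A minimal sketch; `Mat2`, `dot2` and `compAdjoint` are hypothetical names:

```lean
-- 2×2 matrix [[a, b], [c, d]] acting on Float × Float
structure Mat2 where
  a : Float
  b : Float
  c : Float
  d : Float

def Mat2.apply (M : Mat2) (v : Float × Float) : Float × Float :=
  (M.a * v.1 + M.b * v.2, M.c * v.1 + M.d * v.2)

-- For the standard inner product the adjoint of a matrix is its transpose
def Mat2.adjoint (M : Mat2) : Mat2 := ⟨M.a, M.c, M.b, M.d⟩

def dot2 (u v : Float × Float) : Float := u.1 * v.1 + u.2 * v.2

-- adj_of_comp: (λ x => F (G x))† = λ z => G† (F† z)
def compAdjoint (F G : Mat2) (z : Float × Float) : Float × Float :=
  G.adjoint.apply (F.adjoint.apply z)
```

The defining identity ⟪F (G x), z⟫ = ⟪x, (F ∘ G)† z⟫ should then hold up to rounding for any x and z.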
Forward mode 𝒯
noncomputable
def 𝒯 (f : X → Y) (x dx : X) : Y×Y := (f x, ∂ f x dx)
@[simp]
theorem fd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : 𝒯 (λ x => f (g x))
    =
    λ x dx =>
      let (y,dy) := 𝒯 g x dx
      𝒯 f y dy
  := ...
Reverse mode ℛ
noncomputable
def ℛ (f : X → Y) (x : X) : Y×(Y→X) := (f x, (∂ f x)†)
@[simp]
theorem rd_chain_rule (f : Y → Z) (g : X → Y) [IsSmooth f] [IsSmooth g]
  : ℛ (λ x => f (g x))
    =
    λ x =>
      let (y,dg') := ℛ g x
      let (z,df') := ℛ f y
      (z, λ dz => dg' (df' dz))
  := ...
Difficulty with traditional automatic differentiation
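Approximate, then differentiate (traditional AD):
$\frac{\partial}{\partial x} e^x \;\approx\; \frac{\partial}{\partial x} \sum_{n=0}^{N} \frac{x^n}{n!} \;=\; \sum_{n=0}^{N-1} \frac{x^n}{n!}$
Differentiate, then approximate:
$\frac{\partial}{\partial x} e^x \;=\; e^x \;\approx\; \sum_{n=0}^{N} \frac{x^n}{n!}$
Differentiating the truncated series drops its highest-order term, so the first route approximates the derivative with one Taylor term fewer than the second.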
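The effect is easy to reproduce numerically. In the sketch below (hypothetical names, `Float` for ℝ), `expApprox N x` is the truncated series and `expApproxDeriv N x` is its exact derivative, which is just the series truncated one term earlier:

```lean
-- Truncated Taylor series of exp: Σ_{n=0}^{N} xⁿ / n!
def expApprox (N : Nat) (x : Float) : Float := Id.run do
  let mut term : Float := 1.0   -- current term xⁿ / n!, starting at n = 0
  let mut sum : Float := 1.0
  for n in [1:N+1] do
    term := term * x / n.toFloat
    sum := sum + term
  return sum

-- Exact derivative of the truncated series: Σ_{n=0}^{N-1} xⁿ / n!
-- (this is what differentiating the approximation produces)
def expApproxDeriv (N : Nat) (x : Float) : Float :=
  if N = 0 then 0.0 else expApprox (N - 1) x
```

For N = 2 and x = 2 the truncated series gives 1 + 2 + 2 = 5, while differentiating it gives 1 + 2 = 3; differentiating first and then approximating keeps the value 5, which is closer to e² ≈ 7.39.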
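Jacobi method for $A x = y$, where $A = D + L + U$ (diagonal, strictly lower and strictly upper triangular parts).
Approximation of $A^{-1} y$:
fun y => Id.run do
  let mut x := 0
  while true do
    x := D⁻¹ (y - (L + U) x)
    if ‖A x - y‖ < ε then
      break
  return x
Approximate, then differentiate in direction y′:
fun y y' => Id.run do
  let mut x := 0
  let mut x' := 0
  while true do
    x := D⁻¹ (y - (L + U) x)
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x - y‖ < ε then       -- the stopping test checks only x, never x'
      break
  return x'
Differentiate, then approximate: $\frac{\partial}{\partial y}\left(A^{-1} y\right) y' = A^{-1} y'$, solved directly:
fun y' => Id.run do
  let mut x' := 0
  while true do
    x' := D⁻¹ (y' - (L + U) x')
    if ‖A x' - y'‖ < ε then
      break
  return x'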
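The failure mode of the differentiated loop shows up on a tiny example. The sketch below is plain Lean over `Float`; all names are hypothetical, and a bounded `for` loop replaces `while true`. It takes y = 0: the primal iterate is exact after one sweep, so the combined loop stops and returns D⁻¹ y′ instead of A⁻¹ y′, whereas solving for x′ directly converges.

```lean
-- Dense matrices as arrays of rows
def matVec (A : Array (Array Float)) (x : Array Float) : Array Float :=
  A.map fun row => (List.range row.size).foldl (fun s j => s + row[j]! * x[j]!) 0.0

-- ‖A x - y‖
def residual (A : Array (Array Float)) (x y : Array Float) : Float :=
  let r := matVec A x
  Float.sqrt <| (List.range y.size).foldl
    (fun s i => s + (r[i]! - y[i]!) * (r[i]! - y[i]!)) 0.0

-- One Jacobi sweep x ↦ D⁻¹ (y - (L + U) x)
def jacobiStep (A : Array (Array Float)) (y x : Array Float) : Array Float :=
  Array.ofFn fun (i : Fin y.size) =>
    let off := (List.range y.size).foldl
      (fun s j => if j == i.val then s else s + A[i.val]![j]! * x[j]!) 0.0
    (y[i.val]! - off) / A[i.val]![i.val]!

-- "Approximate, then differentiate": iterate x and x' together,
-- stopping when the primal residual ‖A x - y‖ is small
def jacobiTangent (A : Array (Array Float)) (y y' : Array Float)
    (ε : Float) (maxIter : Nat) : Array Float := Id.run do
  let mut x  := Array.ofFn fun (_ : Fin y.size) => 0.0
  let mut x' := Array.ofFn fun (_ : Fin y.size) => 0.0
  for _ in [0:maxIter] do
    x  := jacobiStep A y x
    x' := jacobiStep A y' x'
    if residual A x y < ε then
      break
  return x'

-- "Differentiate, then approximate": solve A x' = y' with its own stopping test
def jacobiSolve (A : Array (Array Float)) (y : Array Float)
    (ε : Float) (maxIter : Nat) : Array Float := Id.run do
  let mut x := Array.ofFn fun (_ : Fin y.size) => 0.0
  for _ in [0:maxIter] do
    x := jacobiStep A y x
    if residual A x y < ε then
      break
  return x
```

With A = [[4, 1], [2, 5]], y = 0 and y′ = (1, 2), `jacobiTangent` returns (0.25, 0.4) = D⁻¹ y′, while `jacobiSolve` approaches A⁻¹ y′ = (1/6, 1/3).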
Example: Ballistic motion
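For $T \in \mathbb{R}$ and $x_T \in \mathbb{R}^2$, find $v_0$ such that $x(T) = x_T$.
This can be computed as a minimization problem:
$v_0 = \operatorname{argmin}_{\bar v_0} \|x(T) - x_T\|^2 \quad \text{where } \dot x(0) = \bar v_0$
Problem:
To compute $\frac{\partial x(T)}{\partial v_0}$ we have to solve an adjoint problem.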
-- ẋ = v, v̇ = g - (5 + ‖v‖) • v; the gravity vector `g` is assumed to be defined elsewhere
def ballisticMotion (x v : ℝ×ℝ) := (v, g - (5 + ‖v‖) • v)
approx aimToTarget (T : ℝ) (target : ℝ×ℝ) (init_v : ℝ×ℝ) (optimizationRate : ℝ) :=
  let shoot := λ (v : ℝ×ℝ) =>
    odeSolve (t₀ := 0) (x₀ := (0,v)) (t := T)
      (f := λ (t : ℝ) (x,v) => ballisticMotion x v) |>.fst
  shoot⁻¹ target
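A rough plain-Lean stand-in for `aimToTarget`, under loudly stated assumptions: gravity g = (0, -9.81), explicit Euler in place of `odeSolve`, and central finite differences in place of the adjoint problem for the gradient. It is only a baseline for comparison, not the approach the slides are after; all names are hypothetical.

```lean
abbrev V2 := Float × Float

def add2 (a b : V2) : V2 := (a.1 + b.1, a.2 + b.2)
def smul (c : Float) (a : V2) : V2 := (c * a.1, c * a.2)
def norm2 (a : V2) : Float := Float.sqrt (a.1 * a.1 + a.2 * a.2)

-- ẋ = v, v̇ = g - (5 + ‖v‖) v, with an assumed gravity vector g = (0, -9.81)
def ballisticMotion (_x v : V2) : V2 × V2 :=
  let g : V2 := (0.0, -9.81)
  (v, add2 g (smul (-(5 + norm2 v)) v))

-- x(T) for x(0) = 0 and ẋ(0) = v₀, using `steps` explicit Euler steps
def shoot (T : Float) (steps : Nat) (v₀ : V2) : V2 := Id.run do
  let dt := T / steps.toFloat
  let mut x : V2 := (0.0, 0.0)
  let mut v := v₀
  for _ in [0:steps] do
    let (dx, dv) := ballisticMotion x v
    x := add2 x (smul dt dx)
    v := add2 v (smul dt dv)
  return x

-- ‖x(T) - x_T‖²
def loss (T : Float) (target v₀ : V2) : Float :=
  let d := add2 (shoot T 200 v₀) (smul (-1) target)
  d.1 * d.1 + d.2 * d.2

-- Central finite-difference gradient of `loss` with respect to v₀
def lossGrad (T : Float) (target v₀ : V2) : V2 :=
  let h : Float := 1e-5
  ((loss T target (v₀.1 + h, v₀.2) - loss T target (v₀.1 - h, v₀.2)) / (2 * h),
   (loss T target (v₀.1, v₀.2 + h) - loss T target (v₀.1, v₀.2 - h)) / (2 * h))

-- Gradient descent; a step is accepted only if the loss decreases,
-- otherwise the step size is halved (at most 30 times)
def aimToTarget (T : Float) (target initV : V2) (rate : Float) (iters : Nat) : V2 :=
  Id.run do
    let mut v := initV
    for _ in [0:iters] do
      let gr := lossGrad T target v
      let mut η := rate
      let mut accepted := false
      for _ in [0:30] do
        let cand := add2 v (smul (-η) gr)
        if loss T target cand < loss T target v then
          v := cand
          accepted := true
          break
        η := η / 2
      if !accepted then
        break
    return v
```

Each gradient here costs four full ODE solves; an adjoint (reverse-mode) computation would instead obtain the whole gradient from one forward and one backward solve, independently of the dimension of v₀.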
Summary