Composable Bayesian Modeling with Soss.jl
Chad Scherrer
July 2021
What is Soss.jl?
- Soss is a probabilistic programming language (PPL) for Julia
- A Soss model is a directed acyclic graph (DAG)
- A node can be arbitrarily complex
- Each node is stored as a Julia expression
- Code for inference primitives is generated on demand
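Here is a rough illustration of the first bullets. It is a toy sketch, not Soss's actual internals. A model is held as named nodes whose right-hand sides are quoted Julia expressions. The DAG's edges come from which node names each expression mentions, and a topological sort gives a valid evaluation order. `symbols_in`, `toposort`, `nodes`, and `parents` are hypothetical names introduced here.

```julia
# Toy sketch (NOT Soss's internal representation): a model as named nodes,
# each holding the quoted right-hand side of a `~` statement.

# Collect every Symbol mentioned anywhere inside an expression.
symbols_in(x) = Symbol[]                  # literals such as 0 or 1.0
symbols_in(s::Symbol) = [s]
symbols_in(ex::Expr) = reduce(vcat, map(symbols_in, ex.args); init = Symbol[])

nodes = [
    :μ => :(Normal(0, 1)),
    :σ => :(HalfNormal()),
    :x => :(Normal(μ, σ)),
]
nodenames = first.(nodes)

# Edges of the DAG: a node's parents are the other nodes its expression mentions.
parents = Dict(n => intersect(symbols_in(ex), nodenames) for (n, ex) in nodes)

# Depth-first topological sort (assumes the graph really is acyclic).
function toposort(parents::Dict{Symbol,Vector{Symbol}})
    order = Symbol[]
    visited = Set{Symbol}()
    function visit(n)
        n in visited && return
        push!(visited, n)
        foreach(visit, parents[n])        # emit parents before the node itself
        push!(order, n)
    end
    foreach(visit, sort(collect(keys(parents))))
    return order
end

order = toposort(parents)   # x is placed after both μ and σ
```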
Introduction: Markov Chains
using Soss
using Base.Iterators: take
mc = Chain(Normal(0,1)) do x   # start distribution: Normal(0,1)
    Normal(x,1)                # transition: next state ~ Normal(x,1)
end
julia> r = rand(mc);
julia> collect(take(r,5))
5-element Vector{Any}:
0.020204148899821542
-0.814158867990756
-0.326553673512969
-0.5415414098608508
0.30487000460135294
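`Chain(start) do x ... end` pairs a start distribution with a transition. `rand` returns a lazy, unbounded path, and `take` truncates it. The same behavior can be sketched without packages by implementing Julia's iteration interface, assuming distributions are plain functions of an RNG. `MarkovChain`, `Realization`, and `realize` are hypothetical names, not the Soss/MeasureTheory `Chain` API.

```julia
using Random
using Base.Iterators: take

# Hypothetical package-free stand-in for `Chain`: a start sampler plus a
# transition sampler, with distributions represented as plain functions.
struct MarkovChain{I,S}
    init::I   # rng -> first state
    step::S   # (rng, state) -> next state
end

# One realization: the chain paired with the RNG that drives it.
struct Realization{C<:MarkovChain,R<:AbstractRNG}
    chain::C
    rng::R
end

realize(mc::MarkovChain, rng::AbstractRNG = Random.default_rng()) = Realization(mc, rng)

# The path is lazy and unbounded; states are drawn only when iterated.
Base.IteratorSize(::Type{<:Realization}) = Base.IsInfinite()
Base.eltype(::Type{<:Realization}) = Any

function Base.iterate(r::Realization)
    x = r.chain.init(r.rng)        # draw the start state
    return (x, x)                  # emit it; it is also the iteration state
end

function Base.iterate(r::Realization, x)
    x′ = r.chain.step(r.rng, x)    # draw the next state given the current one
    return (x′, x′)
end

# Start: Normal(0,1). Transition: Normal(x,1), i.e. a Gaussian random walk.
mc = MarkovChain(rng -> randn(rng), (rng, x) -> x + randn(rng))

collect(take(realize(mc, MersenneTwister(1)), 5))   # 5-element Vector{Any}
```

Passing an explicitly seeded RNG makes a realization reproducible. Omitting it falls back to Julia's default global RNG.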
Using Chain in a Model
m = @model begin
    mc ~ Chain(Normal(0,1)) do x
        Normal(x,1)
    end
    return mc
end
julia> r = rand(m());
julia> collect(take(r,5))
5-element Vector{Any}:
0.020204148899821542
-0.814158867990756
-0.326553673512969
-0.5415414098608508
0.30487000460135294
Factoring Out the Initialization
m = @model begin
    mc ~ Chain(mc_init()) do x
        Normal(x,1)
    end
    return mc
end
mc_init = @model begin
    x ~ Normal(0,1)
    return x
end
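Moving the start distribution into its own model works because the chain only needs something it can draw from. Below is a hypothetical package-free sketch of that idea using multiple dispatch. A primitive `Gaussian` and a `Submodel` wrapping an arbitrary generative function both implement `draw`, so `walk` accepts either without modification. All names here are illustrative, not Soss API.

```julia
using Random

# Hypothetical sampler interface: anything with a `draw(rng, d)` method.
struct Gaussian
    μ::Float64
    σ::Float64
end
draw(rng::AbstractRNG, d::Gaussian) = d.μ + d.σ * randn(rng)

# A submodel wraps an arbitrary generative function of the RNG, so a whole
# program can stand wherever a primitive distribution is accepted.
struct Submodel{F}
    f::F
end
draw(rng::AbstractRNG, m::Submodel) = m.f(rng)

# Analogue of `mc_init`: x ~ Normal(0,1); return x
init_submodel = Submodel(rng -> draw(rng, Gaussian(0, 1)))

# Random walk whose start can be any drawable object; the transition is fixed.
function walk(rng::AbstractRNG, init, n::Integer)
    xs = Vector{Float64}(undef, n)
    xs[1] = draw(rng, init)
    for t in 2:n
        xs[t] = draw(rng, Gaussian(xs[t-1], 1))
    end
    return xs
end

walk(MersenneTwister(7), Gaussian(0, 1), 5)    # primitive start
walk(MersenneTwister(7), init_submodel, 5)     # submodel start
```

Seeding both calls identically produces the same path, because `init_submodel` performs exactly the draw the primitive would.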
Factoring Out the Markov Transition
m = @model begin
    mc ~ Chain(mc_init()) do s
        mc_step(s=s)
    end
    return mc
end
mc_init = @model begin
    x ~ Normal(0,1)
    return (x=x,)
end
mc_step = @model s begin
    x ~ Normal(s.x, 1)
    return (x=x,)
end
julia> r = rand(m());
julia> take(r,5) |> collect
5-element Vector{Any}:
(x = -1.8176087228666433,)
(x = -0.9138066071200438,)
(x = -1.8247694123151978,)
(x = -3.1868339797219116,)
(x = -1.8893821736742664,)
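Returning a named tuple such as `(x=x,)` instead of a bare number lets the state gain fields later while the code that drives the chain stays the same. The plain-Julia sketch below illustrates this, assuming functions of an RNG in place of Soss models. `run_chain`, `pv_init`, and `pv_step` are hypothetical names.

```julia
using Random

# Hypothetical plain-Julia analogues of `mc_init` and `mc_step`: each takes an
# RNG (and the previous state) and returns a NamedTuple state.
mc_init(rng) = (x = randn(rng),)
mc_step(rng, s) = (x = s.x + randn(rng),)

# Generic driver: works for any init/step pair that agree on the state's fields.
function run_chain(rng::AbstractRNG, init, step, n::Integer)
    s = init(rng)
    states = [s]
    for _ in 2:n
        s = step(rng, s)
        push!(states, s)
    end
    return states
end

run_chain(MersenneTwister(3), mc_init, mc_step, 5)   # 5 states of the form (x = ...,)

# Richer state, same driver: a position plus a velocity.
pv_init(rng) = (x = randn(rng), v = 0.0)
function pv_step(rng, s)
    v = s.v + 0.1 * randn(rng)    # velocity takes a small random step
    return (x = s.x + v, v = v)   # position moves by the new velocity
end

run_chain(MersenneTwister(3), pv_init, pv_step, 5)   # 5 states of the form (x = ..., v = ...)
```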
Adding Some Parameters
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (;Δμ, σ)
    mc ~ Chain(mc_init()) do s
        mc_step(p=p, s=s)
    end
    return mc
end
mc_init = @model begin
    x ~ Normal(0,1)
    return (x=x,)
end
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (x=x,)
end
julia> r = rand(m());
julia> take(r,5) |> collect
5-element Vector{Any}:
(x = -1.8176087228666433,)
(x = -0.9138066071200438,)
(x = -1.8247694123151978,)
(x = -3.1868339797219116,)
(x = -1.8893821736742664,)
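In this model `Δμ` and `σ` are drawn once per realization and then shared by every transition, because `p` is passed into each `mc_step` call. The package-free sketch below follows that structure. It assumes `HalfNormal()` is the absolute value of a standard normal draw. `simulate` and `outer_model` are hypothetical names.

```julia
using Random

halfnormal(rng) = abs(randn(rng))    # |Z| with Z ~ Normal(0,1), i.e. HalfNormal()

mc_init(rng) = (x = randn(rng),)
# x ~ Normal(s.x + p.Δμ, p.σ)
mc_step(rng, p, s) = (x = s.x + p.Δμ + p.σ * randn(rng),)

# Run the chain for n steps with fixed parameters p.
function simulate(rng::AbstractRNG, p, n::Integer)
    s = mc_init(rng)
    states = [s]
    for _ in 2:n
        s = mc_step(rng, p, s)
        push!(states, s)
    end
    return states
end

# Analogue of the outer model: p is drawn once and shared by every transition.
function outer_model(rng::AbstractRNG, n::Integer)
    Δμ = randn(rng)
    σ = halfnormal(rng)
    p = (; Δμ, σ)                    # same NamedTuple shorthand as the slide
    return p, simulate(rng, p, n)
end

p, states = outer_model(MersenneTwister(5), 5)
```

Given a fixed `p`, the increments `x[t] - x[t-1]` are independent `Normal(p.Δμ, p.σ)` draws. That is the structure the inference step exploits.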
Inference
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (;Δμ, σ)
    mc ~ Chain(mc_init()) do s
        mc_step(p=p, s=s)
    end
    return mc
end
mc_init = @model begin
    x ~ Normal(0,1)
    return (x=x,)
end
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (x=x,)
end
julia> obs = take(rand(m()), 10) |> collect;
julia> post = m() | (mc = obs,);
julia> using SampleChainsDynamicHMC
julia> sample(post, dynamichmc())
4000-element MultiChain with 4 chains and schema (σ = Float64, Δμ = Float64)
(σ = 0.99±0.24, Δμ = 1.46±0.33)
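Conditioning on `mc = obs` leaves `Δμ` and `σ` as the only unknowns. Their posterior is proportional to the priors times one Normal factor per transition. The sketch below writes that log density out by hand and samples it with random-walk Metropolis on `(Δμ, log σ)`. Metropolis is a deliberately simpler substitute for the HMC sampler used on the slide, not what SampleChainsDynamicHMC does. `logpost` and `metropolis` are hypothetical names, and the tuning values (`iters`, `scale`) are arbitrary choices.

```julia
using Random

# Log density of Normal(μ, σ) evaluated at x.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Unnormalized log posterior of (Δμ, σ) given an observed path xs, following
# the model: Δμ ~ Normal(), σ ~ HalfNormal(), x₁ ~ Normal(0,1),
# xₜ ~ Normal(xₜ₋₁ + Δμ, σ).
function logpost(Δμ, σ, xs)
    σ > 0 || return -Inf
    lp = normal_logpdf(Δμ, 0, 1)
    lp += log(2) + normal_logpdf(σ, 0, 1)      # HalfNormal density on σ > 0
    lp += normal_logpdf(xs[1], 0, 1)           # constant in (Δμ, σ); kept for clarity
    for t in 2:length(xs)
        lp += normal_logpdf(xs[t], xs[t-1] + Δμ, σ)
    end
    return lp
end

# Random-walk Metropolis on θ = (Δμ, η) with σ = exp(η), so proposals never
# leave the support of σ. Adding η is the log-Jacobian of that change of variables.
function metropolis(rng::AbstractRNG, xs; iters = 20_000, scale = 0.1)
    target(q) = logpost(q[1], exp(q[2]), xs) + q[2]
    θ = [0.0, 0.0]
    lp = target(θ)
    draws = Matrix{Float64}(undef, iters, 2)   # columns: Δμ, σ
    for i in 1:iters
        θprop = θ .+ scale .* randn(rng, 2)
        lpprop = target(θprop)
        if log(rand(rng)) < lpprop - lp        # Metropolis accept/reject
            θ, lp = θprop, lpprop
        end
        draws[i, 1] = θ[1]
        draws[i, 2] = exp(θ[2])
    end
    return draws
end
```

With the slide's `obs`, the input would be `xs = [s.x for s in obs]`. Discard an initial stretch of draws as burn-in before summarizing.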
Hidden Markov Chains
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (;Δμ, σ)
    mc ~ Chain(mc_init()) do s
        mc_step(p=p, s=s)
    end
    y ~ For(mc) do s
        Poisson(exp(s.x))
    end
    return y
end
mc_init = @model begin
    x ~ Normal(0,1)
    return (x=x,)
end
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (x=x,)
end
julia> y = predict(m(), (σ=1, Δμ=0.5));
julia> take(y, 10) |> collect
10-element Vector{Int64}:
3
9
31
97
972
1129
1722
2788
1185
998
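Here the chain itself is latent and only the counts `y` are observed, so a sampler must explore `(Δμ, σ)` and the whole latent path jointly. Below is a package-free sketch of the joint log density such a sampler would need, assuming `y[t] ~ Poisson(exp(x[t]))` as on the slide. The function names are hypothetical, and this is not the code Soss generates.

```julia
# Log density of Normal(μ, σ) at x.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# log Poisson(y; λ) with λ = exp(x), written directly in terms of x = log λ.
function poisson_logpmf_lograte(y::Integer, x::Real)
    y < 0 && return -Inf
    logfact = y <= 1 ? 0.0 : sum(log, 2:y)      # log(y!)
    return y * x - exp(x) - logfact
end

# Joint log density of parameters, latent path xs, and observed counts ys:
# Δμ ~ Normal(), σ ~ HalfNormal(), x₁ ~ Normal(0,1),
# xₜ ~ Normal(xₜ₋₁ + Δμ, σ), yₜ ~ Poisson(exp(xₜ)).
function logjoint(Δμ, σ, xs, ys)
    σ > 0 || return -Inf
    length(xs) == length(ys) || throw(DimensionMismatch("xs and ys differ in length"))
    lp = normal_logpdf(Δμ, 0, 1) + log(2) + normal_logpdf(σ, 0, 1)
    lp += normal_logpdf(xs[1], 0, 1)
    for t in 2:length(xs)
        lp += normal_logpdf(xs[t], xs[t-1] + Δμ, σ)   # latent Markov transitions
    end
    for (x, y) in zip(xs, ys)
        lp += poisson_logpmf_lograte(y, x)           # observation terms
    end
    return lp
end
```

Parameterizing the Poisson by its log-rate avoids computing `log(exp(x))`. This helps when, as in the simulated counts above, the rate grows into the thousands.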