Composable Bayesian Modeling with Soss.jl

July 2021

• Soss is a probabilistic programming language

• A Soss model is a directed acyclic graph (DAG)

• A node can be arbitrarily complex

• Each node is stored as an expression

• Code for inference primitives is generated on demand

# A Markov chain: the first state is drawn from Normal(0, 1) (the start
# distribution); every later state is drawn from Normal(prev, 1), where
# `prev` is the state before it (the transition).
mc = Chain(Normal(0,1)) do prev
    Normal(prev, 1)
end

$\overbrace{\texttt{Normal(0,1)}}^{\text{start}} \qquad \overbrace{\texttt{do x Normal(x,1) end}}^{\text{transition}}$

julia> r = rand(mc);

julia> collect(take(r,5))

5-element Vector{Any}:

0.020204148899821542

-0.814158867990756

-0.326553673512969

-0.5415414098608508

0.30487000460135294

# Wrap the Markov chain in a model: the single model variable `mc` is one
# realization of the chain, and it is what sampling the model returns.
m = @model begin
    mc ~ Chain(Normal(0,1)) do prev
        Normal(prev, 1)
    end
    return mc
end

julia> r = rand(m());

julia> collect(take(r,5))

5-element Vector{Any}:

0.020204148899821542

-0.814158867990756

-0.326553673512969

-0.5415414098608508

0.30487000460135294

# Same chain as before, except the start distribution is now an instance
# of the `mc_init` model rather than a bare Normal.
m = @model begin
    mc ~ Chain(mc_init()) do prev
        Normal(prev, 1)
    end
    return mc
end

# Start-state model: a single Normal(0, 1) draw, returned as a plain value
# so it can serve directly as the chain's first state.
mc_init = @model begin
    x ~ Normal(0, 1)
    return x
end

# Both pieces of the chain are models now: `mc_init()` supplies the first
# state and `mc_step` produces each next state from the current one, which
# it receives through its `s` argument.
m = @model begin
    mc ~ Chain(mc_init()) do state
        mc_step(s = state)
    end
    return mc
end

# Start-state model: draw `x` from Normal(0, 1) and return it wrapped in a
# one-field NamedTuple `(x = x,)`, the state shape `mc_step` reads via `s.x`.
mc_init = @model begin
    x ~ Normal(0, 1)
    return (; x)
end

# Transition model: given the previous state `s` (accessed as `s.x`), draw
# the next `x` from Normal(s.x, 1) and return it in the same NamedTuple
# shape `(x = x,)`.
mc_step = @model s begin
    x ~ Normal(s.x, 1)
    return (; x)
end

julia> r = rand(m());

julia> take(r,5) |> collect

5-element Vector{Any}:

(x = -1.8176087228666433,)

(x = -0.9138066071200438,)

(x = -1.8247694123151978,)

(x = -3.1868339797219116,)

(x = -1.8893821736742664,)

# Chain with unknown parameters: `Δμ` and `σ` get priors, are bundled into
# the NamedTuple `p`, and are passed to every `mc_step` call together with
# the current state.
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (; Δμ, σ)
    mc ~ Chain(mc_init()) do state
        mc_step(p = p, s = state)
    end
    return mc
end

# Start-state model: draw `x` from Normal(0, 1) and return it as the
# one-field NamedTuple `(x = x,)`.
mc_init = @model begin
    x ~ Normal(0, 1)
    return (; x)
end

# Transition model: shift the previous state's `x` by `p.Δμ`, draw the next
# `x` from a Normal with scale `p.σ` around it, and return it as `(x = x,)`.
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (; x)
end

julia> r = rand(m());

julia> take(r,5) |> collect

5-element Vector{Any}:

(x = -1.8176087228666433,)

(x = -0.9138066071200438,)

(x = -1.8247694123151978,)

(x = -3.1868339797219116,)

(x = -1.8893821736742664,)

# Chain with unknown parameters `Δμ` and `σ`. They are packed into `p`, and
# each transition receives `p` plus the current state. Conditioning this model
# on an observed `mc` lets the posterior over `Δμ` and `σ` be sampled.
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (; Δμ, σ)
    mc ~ Chain(mc_init()) do state
        mc_step(p = p, s = state)
    end
    return mc
end

# Start-state model: draw `x` from Normal(0, 1) and return it as the
# one-field NamedTuple `(x = x,)`.
mc_init = @model begin
    x ~ Normal(0, 1)
    return (; x)
end

# Transition model: the next `x` is drawn from Normal(s.x + p.Δμ, p.σ),
# i.e. the previous value plus drift `Δμ`, with scale `σ`.
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (; x)
end

julia> obs = take(rand(m()), 10) |> collect;

julia> post = m() | (mc = obs,);

julia> using SampleChainsDynamicHMC

julia> sample(post, dynamichmc())

4000-element MultiChain with 4 chains and schema (σ = Float64, Δμ = Float64)

(σ = 0.99±0.24, Δμ = 1.46±0.33)

# Hidden-chain model: the parameterized chain `mc` is latent, and `y` holds
# one Poisson count per chain state. Each count has rate exp(state.x), so
# `x` acts as a log-rate. The model returns the count sequence `y`.
m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (; Δμ, σ)
    mc ~ Chain(mc_init()) do state
        mc_step(p = p, s = state)
    end
    y ~ For(mc) do state
        Poisson(exp(state.x))
    end
    return y
end

# Start-state model: draw `x` from Normal(0, 1) and return it as the
# one-field NamedTuple `(x = x,)`.
mc_init = @model begin
    x ~ Normal(0, 1)
    return (; x)
end

# Transition model: the next latent `x` is drawn from
# Normal(s.x + p.Δμ, p.σ) and returned as `(x = x,)`.
mc_step = @model p,s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (; x)
end

julia> y = predict(m(), (σ=1, Δμ=0.5));

julia> take(y, 10) |> collect

10-element Vector{Int64}:

3

9

31

97

972

1129

1722

2788

1185

998