Composable Bayesian Modeling with Soss.jl

Chad Scherrer

July 2021

  • Soss is a probabilistic programming language
     
  • A Soss model is a DAG
     
  • A node can be arbitrarily complex
     
  • Each node is stored as an expression
     
  • Code for inference primitives is generated on demand

# A chain built from two pieces: the start distribution `Normal(0, 1)` and a
# transition function taking the current state `x` to the distribution of the
# next state, `Normal(x, 1)`. Passing the function as the first positional
# argument is exactly what the `do`-block form desugars to.
mc = Chain(x -> Normal(x, 1), Normal(0, 1))

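$$\overbrace{\texttt{Normal(0,1)}}^{\text{start}} \qquad \overbrace{x \mapsto \texttt{Normal}(x,1)}^{\text{transition}}$$

`Normal(0,1)` is the start distribution of the chain `mc`; the function taking the current state `x` to `Normal(x,1)` is the transition.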

julia> r = rand(mc);

 

julia> collect(take(r,5))

5-element Vector{Any}:

0.020204148899821542

-0.814158867990756

-0.326553673512969

-0.5415414098608508

0.30487000460135294

# Soss model whose only random variable `mc` is a chain with start
# distribution `Normal(0,1)` and transition `x -> Normal(x,1)`.
# The model returns the chain draw `mc`.
m = @model begin

    mc ~ Chain(Normal(0,1)) do x Normal(x,1) end

    return mc

end

julia> r = rand(m());

 

julia> collect(take(r,5))

5-element Vector{Any}:

0.020204148899821542

-0.814158867990756

-0.326553673512969

-0.5415414098608508

0.30487000460135294

# Soss model whose random variable `mc` is a chain. Its start distribution is
# the Soss model `mc_init()`, not a plain measure, and its transition is
# `x -> Normal(x,1)`. The model returns the chain draw `mc`.
m = @model begin

    mc ~ Chain(mc_init()) do x Normal(x,1) end

    return mc

end

# Start-state model: draws `x ~ Normal(0,1)` and returns the bare value `x`.
mc_init = @model begin

    x ~ Normal(0,1)

    return x

end

# Soss model whose chain `mc` starts from the model `mc_init()`. Each step
# runs the model `mc_step` on the previous state `s`, so both the start and
# the transition are themselves Soss models.
m = @model begin

    mc ~ Chain(mc_init()) do s mc_step(s=s) end

    return mc

end

# Start-state model: draws `x ~ Normal(0,1)` and returns it wrapped in a
# one-field named tuple `(x=x,)`, so the state can be read as `s.x`.
mc_init = @model begin

    x ~ Normal(0,1)

    # named tuple (not a bare number) so transitions can access `s.x`
    return (x=x,)

end

# Transition model with argument `s`, the previous state (a named tuple with
# field `x`). It draws the next `x` from `Normal(s.x, 1)` and returns it in
# the same `(x=...,)` shape, so the output can be fed back in as `s`.
mc_step = @model s begin

    x ~ Normal(s.x, 1)

    return (x=x,)

end

julia> r = rand(m());

 

julia> take(r,5) |> collect

5-element Vector{Any}:

(x = -1.8176087228666433,)

(x = -0.9138066071200438,)

(x = -1.8247694123151978,)

(x = -3.1868339797219116,)

(x = -1.8893821736742664,)

# Parameterized chain model. The drift `Δμ` (drawn from `Normal()`) and the
# step scale `σ` (drawn from `HalfNormal()`) are random. They are bundled into
# the named tuple `p` and passed to every transition `mc_step(p=p, s=s)`.
m = @model begin

    Δμ ~ Normal()

    σ ~ HalfNormal()

    # `(;Δμ, σ)` builds the named tuple `(Δμ = Δμ, σ = σ)`
    p = (;Δμ, σ)

    mc ~ Chain(mc_init()) do s mc_step(p=p,s=s) end

    return mc

end

# Start-state model: draws `x ~ Normal(0,1)` and returns it as the named
# tuple `(x=x,)`, matching the `s.x` access in the transition model.
mc_init = @model begin

    x ~ Normal(0,1)

    return (x=x,)

end

# Transition model with arguments `p` (named tuple with fields `Δμ`, `σ`) and
# `s` (previous state with field `x`). The next `x` is drawn from a `Normal`
# centred at `s.x + p.Δμ` with scale `p.σ`, i.e. a random walk with drift.
mc_step = @model p,s begin

    x ~ Normal(s.x + p.Δμ, p.σ)

    return (x=x,)

end

julia> r = rand(m());

 

julia> take(r,5) |> collect

5-element Vector{Any}:

(x = -1.8176087228666433,)

(x = -0.9138066071200438,)

(x = -1.8247694123151978,)

(x = -3.1868339797219116,)

(x = -1.8893821736742664,)

# Parameterized chain model. The random drift `Δμ` and scale `σ` are bundled
# into `p` and passed to each transition. Conditioning this model on an
# observed `mc` (e.g. `m() | (mc = obs,)`) leaves `Δμ` and `σ` to be inferred.
m = @model begin

    Δμ ~ Normal()

    σ ~ HalfNormal()

    p = (;Δμ, σ)

    mc ~ Chain(mc_init()) do s mc_step(p=p,s=s) end

    return mc

end

# Start-state model: draws `x ~ Normal(0,1)` and returns it as `(x=x,)`.
mc_init = @model begin

    x ~ Normal(0,1)

    return (x=x,)

end

# Transition model: given parameters `p` and previous state `s`, draws the
# next `x` from `Normal(s.x + p.Δμ, p.σ)` and returns it as `(x=x,)`.
mc_step = @model p,s begin

    x ~ Normal(s.x + p.Δμ, p.σ)

    return (x=x,)

end

julia> obs = take(rand(m()), 10) |> collect;

 

julia> post = m() | (mc = obs,);

 

julia> using SampleChainsDynamicHMC

 

julia> sample(post, dynamichmc())

4000-element MultiChain with 4 chains and schema (σ = Float64, Δμ = Float64)

(σ = 0.99±0.24, Δμ = 1.46±0.33)

# Chain-driven count model. The chain `mc` evolves as in the parameterized
# model (random drift `Δμ`, scale `σ`). For each state `s` of `mc` a Poisson
# count is drawn with rate `exp(s.x)`, i.e. `s.x` acts as a log-rate.
# The model returns only the counts `y`; the chain itself is not returned.
m = @model begin

    Δμ ~ Normal()

    σ ~ HalfNormal()

    p = (;Δμ, σ)

    mc ~ Chain(mc_init()) do s mc_step(p=p,s=s) end

    y ~ For(mc) do s

        # `exp` keeps the Poisson rate positive for any real `s.x`
        Poisson(exp(s.x))

    end

    return y

end

# Start-state model: draws `x ~ Normal(0,1)` and returns it as `(x=x,)`.
mc_init = @model begin

    x ~ Normal(0,1)

    return (x=x,)

end

# Transition model: given parameters `p` (`Δμ`, `σ`) and previous state `s`,
# draws the next `x` from `Normal(s.x + p.Δμ, p.σ)` and returns it as `(x=x,)`.
mc_step = @model p,s begin

    x ~ Normal(s.x + p.Δμ, p.σ)

    return (x=x,)

end

julia> y = predict(m(), (σ=1, Δμ=0.5));

 

julia> take(y, 10) |> collect

10-element Vector{Int64}:

3

9

31

97

972

1129

1722

2788

1185

998

Thank You!