Soss: Lightweight Probabilistic Programming in Julia

Chad Scherrer

Senior Data Scientist, Metis

Managed Uncertainty

Rational Decisions

Bayesian Analysis

Probabilistic Programming

  • Physical systems
  • Hypothesis testing
  • Modeling as simulation
  • Medicine
  • Finance
  • Insurance

Risk

Custom models

A disconnect between the "user language" and "developer language"

[Diagram: the two-language problem — users write Python, developers write C]

Deep Learning Framework

  • Harder for beginner users
  • Barrier to entry for developers
  • Limits extensibility


  • Models begin with "Tilde code" based on math notation
  • A Model is a "lossless" data structure
  • Combinators transform code before compilation (see the sketch below)
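A minimal sketch of this workflow, previewing the @model and sourceLogdensity calls used later in this deck (the tiny model here is illustrative only):

using Soss

# Tilde code in, Model value out; the source is stored as-is, nothing is lost
m = @model begin
    x ~ Normal(0, 1)
end

# A combinator consumes the stored code and generates new code from it
# (here, a log-density function) before anything is compiled
sourceLogdensity(m)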

Tilde code → Model

\begin{aligned} \mu &\sim \text{Normal}(0,5)\\ \sigma &\sim \text{Cauchy}_+(0,3) \\ x_j &\sim \text{Normal}(\mu,\sigma) \end{aligned}

Advantages

  • Transparent
  • Composable
  • Backend-agnostic
  • No intrinsic overhead




Theory

P(\mu,\sigma|x)\propto P(\mu,\sigma)P(x|\mu,\sigma)

\begin{aligned} \mu &\sim \text{Normal}(0,5)\\ \sigma &\sim \text{Cauchy}_+(0,3) \\ x_j &\sim \text{Normal}(\mu,\sigma) \end{aligned}

Soss

m = @model N begin
    μ ~ Normal(0, 5)
    σ ~ HalfCauchy(3)
    x ~ Normal(μ, σ) |> iid(N)
end

julia> sourceLogdensity(m)
:(function ##logdensity#1168(pars)
      @unpack (μ, σ, x, N) = pars
      ℓ = 0.0
      ℓ += logpdf(Normal(0, 5), μ)
      ℓ += logpdf(HalfCauchy(3), σ)
      ℓ += logpdf(Normal(μ, σ) |> iid(N), x)
      return ℓ
  end)
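The returned expression can be eval'ed into an ordinary function; a hedged usage sketch (the argument values are arbitrary):

ℓ = eval(sourceLogdensity(m))

# pars is a named tuple carrying parameters, data, and model arguments
ℓ((μ = 0.3, σ = 1.2, x = randn(10), N = 10))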
gaussianMixture = @model begin
    N ~ Poisson(100)
    K ~ Poisson(2.5)
    p ~ Dirichlet(K, 1.0)
    μ ~ Normal(0,1.5) |> iid(K)
    σ ~ HalfNormal(1)
    θ = [(m,σ) for m in μ]
    x ~ MixtureModel(Normal, θ, p) |> iid(N)
end
julia> m = gaussianMixture(N=100,K=2)
@model begin
    p ~ Dirichlet(2, 1.0)
    μ ~ Normal(0, 1.5) |> iid(2)
    σ ~ HalfNormal(1)
    θ = [(m,σ) for m in μ]
    x ~ MixtureModel(Normal, θ, p) |> iid(100)
end
julia> rand(m) |> pairs
pairs(::NamedTuple) with 4 entries:
  :p => 0.9652
  :μ => [-1.20, 3.121]
  :σ => 1.320
  :x => [-1.780, -2.983, -1.629, ...]
m = @model begin
    p ~ Uniform()
    μ ~ Normal(0, 1.5) |> iid(2)
    σ ~ HalfNormal(1)
    θ = [(m, σ) for m in μ]
    x ~ MixtureModel(Normal,θ,[p,1-p]) |> iid(100)
end

Forward: fix the parameters, simulate data.

julia> m_fwd = m(:p,:μ,:σ)
@model (p, μ, σ) begin
    θ = [(m, σ) for m = μ]
    x ~ MixtureModel(Normal,θ,[p,1-p]) |> iid(100)
end

rand(m_fwd; groundTruth...)

Inverse: fix the data, infer the parameters.

julia> m_inv = m(:x)
@model x begin
    p ~ Uniform()
    μ ~ Normal(0, 1.5) |> iid(2)
    σ ~ HalfNormal(1)
    θ = [(m, σ) for m = μ]
    x ~ MixtureModel(Normal,θ,[p,1-p]) |> iid(100)
end

post = nuts(m_inv; groundTruth...)
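Putting the two directions together, a sketch of the simulate-then-recover loop the slides imply (the groundTruth values here are arbitrary, and passing the simulated x to nuts is an assumption about the call signature):

groundTruth = (p = 0.2, μ = [-1.0, 1.0], σ = 0.5)

# simulate data from known parameters, then try to recover them
fake = rand(m_fwd; groundTruth...)
post = nuts(m_inv; x = fake.x)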
[Diagram: posterior predictive checks. Observed data x gives posterior draws \theta_1, \theta_2, \ldots, \theta_n; each draw generates a replicated dataset x^\text{rep}_1, x^\text{rep}_2, \ldots, x^\text{rep}_n for comparison with x]
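A sketch of what the diagram describes (assuming post is a collection of named tuples of posterior draws, as in the hypothetical call above):

# one replicated dataset per posterior draw of (p, μ, σ)
xrep = [rand(m_fwd; θ...).x for θ in post]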
julia> lda
@model (α, η, K, V, N) begin
    M = length(N)
    β ~ Dirichlet(repeat([η],V)) |> iid(K)
    θ ~ Dirichlet(repeat([α],K)) |> iid(M)
    z ~ For(1:M) do m
            Categorical(θ[m]) |> iid(N[m])
        end
    w ~ For(1:M) do m
            For(1:N[m]) do n
                Categorical(β[(z[m])[n]])
            end
        end
end
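A hedged usage sketch; the hyperparameter values are arbitrary, chosen only to make the call concrete:

# three documents of 40, 55, and 38 words over a 1000-word vocabulary
docs = rand(lda(α = 0.1, η = 0.01, K = 5, V = 1000, N = [40, 55, 38]))
docs.w      # sampled word indices, one vector per document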
julia> dependencies(lda)
               [] => :α
               [] => :η
               [] => :K
               [] => :V
               [] => :N
             [:N] => :M
     [:η, :V, :K] => :β
     [:α, :K, :M] => :θ
     [:M, :θ, :N] => :z
 [:M, :N, :β, :z] => :w

Blei, Ng, & Jordan (2003). Latent Dirichlet Allocation. JMLR, 3(4–5), 993–1022.

m = @model y begin
    μ ~ Normal(0, 1)
    σ ~ HalfCauchy(1)
    ε ~ TDist(5)
    y ~ Normal(μ + ε, σ)
end

symlogpdf(m)


\sum_{j=1}^{N} \left( - \frac{μ^{2}}{2} - \log{\left(σ \right)} - \log{\left(σ^{2} + 1 \right)} - 3 \log{\left(\frac{ε_j^{2}}{5} + 1 \right)} - 2 \log{\left(\pi \right)} - \frac{\log{\left(5 \right)}}{2} - \log{\left(2 \right)} - \log{\left(\operatorname{B}\left(\frac{1}{2}, \frac{5}{2}\right) \right)} - \frac{\left(y_j - μ - ε_j\right)^{2}}{2 σ^{2}} \right)


Grouping terms: everything that does not depend on j is summed once and scaled by N, so only two data-dependent sums remain:

\begin{aligned}
&-N\left[2\log\left(\pi\right)+\frac{\log\left(5\right)}{2}+\log\left(2\right)+\log\left(\operatorname{B}\left(\frac{1}{2},\frac{5}{2}\right)\right)\right] \\
&-N\left[\frac{μ^{2}}{2}+\log\left(σ\right)+\log\left(σ^{2}+1\right)\right] \\
&-\frac{1}{2σ^{2}} \sum_{j=1}^{N}\left(y_j-μ-ε_j\right)^{2} \\
&-3\sum_{j=1}^{N}\log\left(\frac{ε_j^{2}}{5}+1\right)
\end{aligned}
p = @model begin
    x ~ Normal()
end

q = @model σ begin
    x ~ HalfNormal(σ)
end
julia> sourceRand(q)
:(function ##rand#726(args...; kwargs...)
      @unpack (σ,) = kwargs
      x = rand(HalfNormal(σ))
      (x = x,)
  end)


julia> sourceImportanceLogWeights(p,q)
:(function ##logimportance#725(pars)
      @unpack (x, σ) = pars
      ℓ = 0.0
      ℓ += logpdf(Normal(), x)
      ℓ -= logpdf(HalfNormal(σ), x)
      return ℓ
  end)
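A sketch of how the generated pieces compose; eval-ing the returned expressions gives callable functions (the σ value is arbitrary):

logw = eval(sourceImportanceLogWeights(p, q))
draw = eval(sourceRand(q))

s = draw(; σ = 2.0)                 # sample x from the proposal q
ℓ = logw(merge(s, (σ = 2.0,)))      # log p(x) - log q(x | σ)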
p = @model x begin
    μ ~ Normal()
    x ~ Normal(μ,1) |> iid(20)
end

q = @model (m, s) begin
    μ ~ Normal(m,s)
end
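For this pair the log importance weight works out to (standard importance-sampling algebra, not Soss output):

\log w(\mu) = \log p(\mu) + \sum_{j=1}^{20} \log p(x_j \mid \mu) - \log q(\mu \mid m, s)

with p(\mu) = \text{Normal}(0,1), p(x_j \mid \mu) = \text{Normal}(\mu,1), and q(\mu \mid m, s) = \text{Normal}(m,s).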

Logo by James Fairbanks, from the Julia Discourse site

JuliaCon 2019

July 21-27

Baltimore, MD

Thank You!

2019-05-01 ODSC East

By Chad Scherrer