
Jags-style samplers  #905

@itsdfish

Description


I am opening this feature request after a discussion on Slack regarding the performance of PG. For continuous parameters in particular, particles tend to get stuck. It's not clear to me to what extent this may happen for discrete parameters. Here is an example:

using Turing, Random, StatsPlots

@model model(y) = begin
    μ ~ Normal(0, 10)
    σ ~ Truncated(Cauchy(0, 1), 0, Inf)
    for j in 1:length(y)
        y[j] ~ Normal(μ, σ)
    end
end

Random.seed!(3431)
y = rand(Normal(0, 1), 50)
chain = sample(model(y), PG(40, 4000))
chain = chain[2001:end, :, :]  # discard the first 2000 iterations as burn-in
println(chain)
plot(chain)

[fig4: plot(chain) output for the PG run, showing trace and density plots of μ and σ]

This required about 2.5 minutes to run on my system. Increasing the number of particles to 80 did not help much.
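A quick way to see the stuck-particle behaviour is to count how few distinct values the chain actually visits (a rough diagnostic; this assumes the returned chain supports MCMCChains-style indexing by symbol, so adjust to your Turing version):

# Rough stuck-particle diagnostic (assumes MCMCChains-style symbol indexing).
# If particles get stuck, the sampler keeps copying the same values forward,
# so the number of distinct draws is far below the number of iterations.
μ_draws = vec(chain[:μ])
println("distinct values of μ: ", length(unique(μ_draws)), " of ", length(μ_draws), " draws")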

As a basis for comparison, here is the same model coded in Jags:

ENV["JAGS_HOME"] = "usr/bin/jags" #your path here
using Jags, StatsPlots, Random, Distributions
#cd(@__DIR__)
ProjDir = pwd()
Random.seed!(3431)

y = rand(Normal(0,1),50)

Model = "
model {
      for (i in 1:length(y)) {
            y[i] ~ dnorm(mu,sigma);
      }
      mu  ~ dnorm(0, 1/sqrt(10));
      sigma  ~ dt(0,1,1) T(0, );
  }
"

monitors = Dict(
  "mu" => true,
  "sigma" => true,
)

jagsmodel = Jagsmodel(
  name="Gaussian",
  model=Model,
  monitor=monitors,
  ncommands=4, nchains=1,
  #deviance=true, dic=true, popt=true,
  pdir=ProjDir,
)

println("\nJagsmodel that will be used:")
jagsmodel |> display

data = Dict{String, Any}(
  "y" => y,
)

inits = [
  Dict("mu" => 0.0,"sigma" => 1.0,
  ".RNG.name" => "base::Mersenne-Twister")
]

println("Input observed data dictionary:")
data |> display
println("\nInput initial values dictionary:")
inits |> display
println()
#######################################################################################
#                                 Estimate Parameters
#######################################################################################
sim = jags(jagsmodel, data, inits, ProjDir)
sim = sim[5001:end, :, :]  # discard burn-in
plot(sim)

[jags: plot(sim) output for the Jags run, showing trace and density plots of mu and sigma]

This required about 0.267 seconds on my machine; next to the roughly 150 seconds for the PG run above, that is nearly a 600-fold speed-up.
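For anyone reproducing the comparison, the numbers above are wall-clock times. A rough way to measure them, assuming both scripts have been run in the same session and each call has been run once already so Julia's compilation time is excluded, is:

# Rough wall-clock timing of the two runs (Base.@elapsed returns seconds).
# model/y come from the Turing script above; jagsmodel/data/inits from this one.
t_pg   = @elapsed sample(model(y), PG(40, 4000))
t_jags = @elapsed jags(jagsmodel, data, inits, ProjDir)
println("PG: ", round(t_pg, digits = 1), " s; Jags: ", round(t_jags, digits = 3),
        " s; ratio ≈ ", round(t_pg / t_jags, digits = 0))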

Here is a second example we found to perform poorly:

using Distributions
using Turing

n = 500
p = 20
X = rand(Float64, (n, p))
beta = [2.0^(-i) for i in 0:(p - 1)]
alpha = 0
sigma = 0.7
eps = rand(Normal(0, sigma), n)
y = alpha .+ X * beta + eps;

@model model(X, y) = begin
    n, p = size(X)

    alpha ~ Normal(0, 1)
    sigma ~ Truncated(Cauchy(0, 1), 0, Inf)
    sigma_beta ~ Truncated(Cauchy(0, 1), 0, Inf)
    pind ~ Beta(2, 8)

    beta = tzeros(Float64, p)
    betaT = tzeros(Float64, p)
    ind = tzeros(Int, p)

    for j in 1:p
        ind[j] ~ Bernoulli(pind)         # discrete inclusion indicator
        betaT[j] ~ Normal(0, sigma_beta) # random effect
        beta[j] = ind[j] * betaT[j]
    end

    mu = tzeros(Float64, n)

    for i in 1:n
        mu[i] = alpha + X[i, :]' * beta
        y[i] ~ Normal(mu[i], sigma)
    end
end

steps = 4000
chain = sample(model(X,y),PG(40,steps))
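For reference, the discrete and continuous blocks of this model can already be split across samplers with Turing's Gibbs combinator. This is only a sketch of that workaround, not the conjugate/slice updating a Jags-style sampler would do; the step size, leapfrog steps, and particle count are illustrative guesses, and the constructor syntax follows recent Turing releases (older versions, like the PG(particles, iterations) form above, differ):

# Illustrative Gibbs composition (not tuned): particle Gibbs only for the
# discrete inclusion indicators, HMC for all continuous parameters.
# Syntax follows recent Turing releases and may need adapting to older versions.
gibbs_sampler = Gibbs(
    PG(20, :ind),
    HMC(0.05, 10, :alpha, :sigma, :sigma_beta, :pind, :betaT),
)
chain = sample(model(X, y), gibbs_sampler, 2000)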

I think this would be a very useful addition. By adding Jags-style samplers, we could get the speed of Jags without its severe limitations, and it would give Turing a capability that Stan struggles with: efficient sampling of models with discrete parameters.
