Description
TL;DR: It would be very useful if models that allocate an array of parameters and then sample them could be made compatible with AD engines that don't support mutation. Barring that, such models should at least work with these AD engines when the model is conditioned on that array of parameters.
Suppose we want a simple linear regression model from which we can compute pointwise log-likelihoods. This is how we can write the model with the data as an input:
```julia
using Turing

@model function regress(x, y)
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    σ ~ truncated(Normal(); lower=0)
    μ = muladd(x, β, α)
    for i in eachindex(μ, y)
        y[i] ~ Normal(μ[i], σ)
    end
end
```
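For reference, a minimal sketch of how the pointwise log-likelihoods of this model might be obtained. It assumes a Turing version that re-exports `pointwise_loglikelihoods` from DynamicPPL, uses the default AD backend, and the synthetic data and short chain length are purely illustrative.

```julia
using Turing

@model function regress(x, y)
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    σ ~ truncated(Normal(); lower=0)
    μ = muladd(x, β, α)
    for i in eachindex(μ, y)
        y[i] ~ Normal(μ[i], σ)
    end
end

# Illustrative synthetic data (one predictor column).
xs = collect(0:0.1:π)
ys = sin.(xs) .+ 2 .+ randn.() .* 0.1
model = regress(reshape(xs, :, 1), ys)

# Default AD backend; a short chain is enough to illustrate the output.
chain = sample(model, NUTS(), 200; progress=false)

# Maps each observation name ("y[1]", "y[2]", ...) to its log-likelihood
# values, one per posterior draw.
lls = pointwise_loglikelihoods(model, chain)
```

Each `y[i]` gets its own entry, which is what per-observation diagnostics need.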
But suppose we instead want the model to define the joint distribution including `y`, allocating `y` internally:
```julia
@model function regress2(x)
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    σ ~ truncated(Normal(); lower=0)
    μ = muladd(x, β, α)
    y = similar(vec(μ), Base.promote_eltype(μ, σ))
    for i in eachindex(μ, y)
        y[i] ~ Normal(μ[i], σ)
    end
end
```
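Now we can sample from this joint distribution and condition on `y`, but with an AD engine that doesn't support mutation, such as Zygote, neither works at all. The usually recommended solution of using `arraydist` or `filldist` here is not possible if one eventually wants pointwise log-likelihoods, because `y` would then be a single multivariate random variable whose log-likelihood cannot be split into one term per observation.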
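To make the limitation concrete, here is a hedged sketch of the `arraydist` alternative. The model name `regress_arraydist` and the data are hypothetical, and `y` is passed as an argument rather than conditioned with `|` to keep the example small. Nothing is written into an array, so no mutation is needed, but `pointwise_loglikelihoods` should then report a single entry for `y` instead of one per observation.

```julia
using Turing

# Hypothetical `arraydist` variant: `y` is one multivariate random variable,
# so the model never writes into an array.
@model function regress_arraydist(x, y)
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    σ ~ truncated(Normal(); lower=0)
    μ = muladd(x, β, α)
    y ~ arraydist(Normal.(μ, σ))
end

# Illustrative synthetic data, mirroring the example in this issue.
xs = collect(0:0.1:π)
ys = sin.(xs) .+ 2 .+ randn.() .* 0.1
model = regress_arraydist(reshape(xs, :, 1), ys)
chain = sample(model, NUTS(), 200; progress=false)

# The likelihood of all observations is lumped under the single key "y".
lls = pointwise_loglikelihoods(model, chain)
```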
Here's a minimal failing example:
```julia
julia> using Turing, Zygote

julia> Turing.setadbackend(:zygote);

julia> @model function regress2(x)
           α ~ Normal()
           β ~ filldist(Normal(), size(x, 2))
           σ ~ truncated(Normal(); lower=0)
           μ = muladd(x, β, α)
           y = similar(vec(μ), Base.promote_eltype(μ, σ))
           for i in eachindex(μ, y)
               y[i] ~ Normal(μ[i], σ)
           end
       end;

julia> x = 0:0.1:π;

julia> y = sin.(x) .+ 2 .+ randn.() .* 0.1;

julia> model = regress2(Matrix(reshape(x, :, 1)));

julia> model_cond = model | (; y=y);

julia> sample(model, NUTS(), 500); # unsurprising failure; model allocates y
ERROR: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
...

julia> sample(model_cond, NUTS(), 500); # surprising failure; y is allocated but unused since given
ERROR: Mutating arrays is not supported -- called setindex!(::Vector{Float64}, _...)
...
```
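One possible workaround, sketched here under assumptions, is to merge the two models: make `y` an argument that defaults to `missing` and allocate it only in that case. The model name `regress3` is hypothetical, the chain below uses the default AD backend, and whether the data-conditioned variant really avoids the `setindex!` error once Zygote is enabled (as in the session above) is untested. The idea is that when `y` is an ordinary argument holding data, the `y[i] ~ ...` statements should only observe it and never write into it.

```julia
using Turing

# Hypothetical workaround: `y` defaults to `missing`, and a fresh array is
# allocated (and then filled by sampling) only when no data is supplied.
@model function regress3(x, y=missing)
    α ~ Normal()
    β ~ filldist(Normal(), size(x, 2))
    σ ~ truncated(Normal(); lower=0)
    μ = muladd(x, β, α)
    if y === missing
        y = similar(vec(μ), Base.promote_eltype(μ, σ))
    end
    for i in eachindex(μ, y)
        y[i] ~ Normal(μ[i], σ)
    end
    return y
end

# Illustrative synthetic data.
xs = collect(0:0.1:π)
ys = sin.(xs) .+ 2 .+ randn.() .* 0.1
X = reshape(xs, :, 1)

# Joint model: `y` is sampled, so the array is allocated and mutated.
model_joint = regress3(X)
# Model with data: `y` is an argument and is only observed.
model_data = regress3(X, ys)

chain = sample(model_data, NUTS(), 200; progress=false)
```

Calling `model_joint()` draws a new `y` from the joint prior, while `model_data` is the conditioned model one would pass to an AD backend that forbids mutation.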