Introduction
Currently, ThreadSafeVarInfo creates an array of length Threads.nthreads() to store the logp values accumulated in each thread (see DynamicPPL.jl/src/threadsafe.jl, lines 11 to 13 in cdeb657).
It then adds each new logp to logps[Threads.threadid()] (see DynamicPPL.jl/src/threadsafe.jl, lines 24 to 31 in cdeb657).
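To make the failure mode concrete, here is a minimal, simplified sketch of this per-thread buffer pattern. The names PerThreadLogp, acclogp! and getlogp are hypothetical, and this is not the actual DynamicPPL code; it only mirrors the behaviour described above.

```julia
# Hypothetical, simplified illustration of a thread-ID-indexed buffer.
# This is NOT DynamicPPL's implementation.
struct PerThreadLogp
    # One slot per thread, sized by nthreads(). This sizing is what breaks
    # when threadid() can return a value larger than nthreads().
    logps::Vector{Float64}
end

PerThreadLogp() = PerThreadLogp(zeros(Float64, Threads.nthreads()))

# Each call adds to the slot belonging to the current thread ID.
function acclogp!(acc::PerThreadLogp, logp::Real)
    acc.logps[Threads.threadid()] += logp
    return acc
end

# The total log probability is the sum over all per-thread slots.
getlogp(acc::PerThreadLogp) = sum(acc.logps)

acc = PerThreadLogp()
acclogp!(acc, -1.5)   # called from the main task
println(getlogp(acc))
```

If acclogp! is called on a thread whose ID exceeds Threads.nthreads(), the indexing into logps throws a BoundsError.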
Although ThreadSafeVarInfo has been changed a bit by the accumulators PR (#885), the thread ID indexing behaviour described above still remains.
Now, this has worked fine up to and including Julia 1.11. In Julia 1.12, however, it breaks, because Threads.threadid() can return a value larger than Threads.nthreads(). This shows up in the CI of #921 (link to failing run) and is demonstrated more clearly here:
Julia 1.12, 1 thread
julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 8 virtual cores)
julia> Threads.nthreads()
1
julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2
julia> Threads.maxthreadid()
2
Julia 1.12, 4 threads
julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 8 virtual cores)
julia> Threads.nthreads()
4
julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2
5
3
4
julia> Threads.maxthreadid()
8
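A quick way to check whether a given Julia installation is affected is to record the thread IDs observed inside Threads.@threads and compare them with Threads.nthreads() and Threads.maxthreadid(). This is a small diagnostic sketch; the exact IDs it prints depend on the Julia version and thread configuration.

```julia
# Record which thread ID each iteration of a Threads.@threads loop runs on.
n = Threads.nthreads()
ids = zeros(Int, n)
Threads.@threads for i in 1:n
    # Each iteration writes only to its own index i, so there is no data race.
    ids[i] = Threads.threadid()
end

println("nthreads()    = ", n)
println("maxthreadid() = ", Threads.maxthreadid())
println("thread IDs    = ", sort(ids))
# If this is true, a buffer of length nthreads() indexed by threadid() can overflow.
println("any ID > nthreads(): ", any(>(n), ids))
```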
Possible solutions
1. Use maxthreadid() instead of nthreads()
This would be the quickest and hackiest fix. It is not ideal, but it is not really any worse than the current situation, and it could tide us over while we figure out a proper solution.
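A minimal sketch of this fix, using the same kind of hypothetical per-thread buffer (MaxIdLogp is an illustrative name, not DynamicPPL's API). The only change is sizing the buffer with Threads.maxthreadid(). It assumes no threads are added after the buffer is created, and it still relies on threadid() not changing during a single update, which holds here only because the update contains no yield point.

```julia
# Hypothetical sketch of fix 1: size the buffer by maxthreadid() instead of nthreads().
struct MaxIdLogp
    logps::Vector{Float64}
end

# Assumption: no new threads appear after this point, so every thread ID
# observed later is <= the maxthreadid() value read here.
MaxIdLogp() = MaxIdLogp(zeros(Float64, Threads.maxthreadid()))

function acclogp!(acc::MaxIdLogp, logp::Real)
    # Read-modify-write on this thread's slot. There is no yield point in
    # between, so the task cannot migrate to another thread mid-update.
    acc.logps[Threads.threadid()] += logp
    return acc
end

getlogp(acc::MaxIdLogp) = sum(acc.logps)

acc = MaxIdLogp()
Threads.@threads for i in 1:100
    acclogp!(acc, -1.0)
end
println(getlogp(acc))
```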
(Actually, there is an even hackier fix: in acclogp, we could index into the vector with threadid() - 1 instead of threadid(). I assume we don't want to go there.)
2. Rewrite ThreadSafeVarInfo to use a lock
This is probably the best solution, but it is a lot of work. I don't think the effort is worth it unless it also allowed us to extend the 'thread safety' to assume-statements (and not just observe-statements).
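For comparison, a lock-protected accumulator is straightforward in isolation; the real work in option 2 would be integrating it into ThreadSafeVarInfo, which this sketch does not attempt. LockedLogp is a hypothetical name. Because it never consults threadid(), thread numbering and task migration do not matter; the cost is that all updates are serialised through the lock.

```julia
# Hypothetical lock-protected log-probability accumulator (not DynamicPPL's API).
mutable struct LockedLogp
    logp::Float64
    lk::ReentrantLock
end

LockedLogp() = LockedLogp(0.0, ReentrantLock())

function acclogp!(acc::LockedLogp, logp::Real)
    # Only one task at a time can execute the update inside the lock.
    lock(acc.lk) do
        acc.logp += logp
    end
    return acc
end

getlogp(acc::LockedLogp) = acc.logp

acc = LockedLogp()
Threads.@threads for i in 1:100
    acclogp!(acc, -1.0)
end
println(getlogp(acc))
```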
3. Disallow tilde-statements inside Threads.@threads
Right now, we allow observe-statements to happen inside Threads.@threads (but not assume-statements). Observe-statements can, of course, be replaced with calls to @addlogprob!. For example, the following model, which has an observe-statement inside Threads.@threads, breaks on Julia 1.12 (with any number of threads):
julia> @model function f(x)
           a ~ Normal()
           Threads.@threads for i in eachindex(x)
               x[i] ~ Normal(a)
           end
       end
f (generic function with 2 methods)
julia> model = f(Float64.(1:10))
Model{typeof(f), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(f, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())
julia> vi = VarInfo(model)
ERROR: TaskFailedException
[...]
The following model, however, is equivalent (and its use of Threads.@spawn is "officially correct"; see https://julialang.org/blog/2023/07/PSA-dont-use-threadid/):
julia> @model function g(x)
           a ~ Normal()
           logps = map(x) do xi
               Threads.@spawn logpdf(Normal(a), xi)
           end
           @addlogprob! sum(fetch.(logps))
       end
g (generic function with 2 methods)
julia> model = g(Float64.(1:10))
Model{typeof(g), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(g, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())
julia> vi = VarInfo(model)
VarInfo{@NamedTuple{a::DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}}, Float64}((a = DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}(Dict(a => 1), [a], UnitRange{Int64}[1:1], [-0.011060756850626022], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0)], [0], Dict{String, BitVector}("del" => [0], "trans" => [0])),), Base.RefValue{Float64}(-203.2173383639174), Base.RefValue{Int64}(0))
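The same task-based idea works in plain Julia, outside of a model. The sketch below computes the log-likelihood of x under Normal(a, 1) by summing chunks of the data in separate tasks and combining the results, without ever looking at thread IDs. normlogpdf and total_logp are hypothetical helper names; the density is written out by hand so that the example does not need Distributions.jl, and spawning one task per chunk (rather than one per element, as in g above) is a choice made here to reduce task overhead.

```julia
# Log-density of Normal(mu, 1) at x, written out by hand.
normlogpdf(mu, x) = -0.5 * (x - mu)^2 - 0.5 * log(2π)

# Split x into roughly ntasks chunks, sum each chunk in its own task,
# then combine the partial sums. No thread IDs are involved.
function total_logp(a, x; ntasks::Int = Threads.nthreads())
    chunksize = max(1, cld(length(x), ntasks))
    tasks = map(Iterators.partition(x, chunksize)) do chunk
        Threads.@spawn sum(xi -> normlogpdf(a, xi), chunk)
    end
    return sum(fetch.(tasks); init = 0.0)
end

x = Float64.(1:10)
println(total_logp(0.3, x))
```

The result should match the serial sum(normlogpdf.(0.3, x)) up to floating-point rounding, since only the order of summation differs.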
And of course, people can use whatever threading library they like (e.g. FLoops.jl) too, as long as there are no tilde-statements in the parallelised code.
Note that disallowing multithreaded tilde-statements would also mean that ThreadSafeVarInfo could be removed entirely.