ThreadSafeVarInfo and threadid #924

Open
@penelopeysm

Introduction

Currently, ThreadSafeVarInfo creates an array of length Threads.nthreads() to store logp values accumulated in each thread:

function ThreadSafeVarInfo(vi::AbstractVarInfo)
    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
end

It then adds to logps[Threads.threadid()]:

function acclogp!!(vi::ThreadSafeVarInfo, logp)
    vi.logps[Threads.threadid()] += logp
    return vi
end
function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp)
    vi.logps[Threads.threadid()][] += logp
    return vi
end

Although ThreadSafeVarInfo has been changed a bit by the accumulators PR (#885), the thread ID indexing behaviour described above still remains.

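This worked fine up to and including Julia 1.11. On Julia 1.12 it breaks, because Threads.threadid() can return a value larger than Threads.nthreads(), so the index into logps goes out of bounds. Threads.nthreads() counts only the threads in the default pool, whereas the versioninfo() output below lists an interactive thread in addition to the default ones. The failure is seen in CI of #921 (link to failing run) and demonstrated more clearly here: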

Julia 1.12, 1 thread

julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 8 virtual cores)

julia> Threads.nthreads()
1

julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2

julia> Threads.maxthreadid()
2

Julia 1.12, 4 threads

julia> versioninfo()
Julia Version 1.12.0-beta3
Commit faca79b503a (2025-05-12 06:47 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m1)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 8 virtual cores)

julia> Threads.nthreads()
4

julia> Threads.@threads for i in 1:Threads.nthreads(); println(Threads.threadid()); end
2
5
3
4

julia> Threads.maxthreadid()
8
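The outputs above can be reproduced without DynamicPPL. The following is a minimal sketch, assuming that the interactive thread pool is what shifts the IDs of the default-pool threads; the variable names are illustrative only, and `:static` scheduling is used so that each iteration stays on one thread.

```julia
# Compare the sizes of the thread pools with the largest thread ID in use.
n_default = Threads.nthreads(:default)        # same as Threads.nthreads()
n_interactive = Threads.nthreads(:interactive)
@show n_default n_interactive Threads.maxthreadid()

# Record the thread ID seen by each @threads iteration.
# With :static scheduling, iterations are pinned to threads.
ids = zeros(Int, Threads.nthreads())
Threads.@threads :static for i in eachindex(ids)
    ids[i] = Threads.threadid()
end
@show ids

# A buffer of length nthreads() is only safe if every observed ID fits in it.
@show all(id -> id <= Threads.nthreads(), ids)
```

If the last line prints `false`, any vector of length `Threads.nthreads()` indexed by `Threads.threadid()` will throw a `BoundsError` inside the threaded loop.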

Possible solutions

1. Use maxthreadid() instead of nthreads()

This would be the quickest and hackiest fix. It is not ideal, but it is no worse than the current situation, and it could tide us over while we figure out a proper solution.

(There is an even hackier fix: in acclogp!!, we could index into the vector with threadid() - 1 instead of threadid(). I assume we don't want to go there.)
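As a rough illustration of option 1, the buffer can be sized by `Threads.maxthreadid()` so that every value of `Threads.threadid()` is a valid index. This is a sketch only: `PerThreadLogp`, `acc_per_thread!` and `total_logp` are hypothetical stand-ins, not DynamicPPL types or functions, and the sketch assumes that no new threads are added after the buffer is created.

```julia
# Hypothetical stand-in for the per-thread logp storage (not DynamicPPL's type).
struct PerThreadLogp{T}
    logps::Vector{Base.RefValue{T}}
end

# Size the buffer by maxthreadid() rather than nthreads(), so that any
# thread ID that exists at construction time is a valid index.
PerThreadLogp(::Type{T}) where {T} =
    PerThreadLogp{T}([Ref(zero(T)) for _ in 1:Threads.maxthreadid()])

# Accumulate into the slot belonging to the current thread.
function acc_per_thread!(acc::PerThreadLogp, logp)
    acc.logps[Threads.threadid()][] += logp
    return acc
end

# Combine the per-thread partial sums.
total_logp(acc::PerThreadLogp) = sum(r -> r[], acc.logps)

acc = PerThreadLogp(Float64)
# :static pins each iteration to one thread, so threadid() is stable within it.
Threads.@threads :static for i in 1:100
    acc_per_thread!(acc, 1.0)
end
@show total_logp(acc)
```

This still relies on `threadid()` as an index, so it inherits the caveats in the "don't use threadid" blog post linked further down; it only removes the out-of-bounds error.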

2. Rewrite ThreadSafeVarInfo to use a lock

Probably the best option, but a lot of work. I don't think that amount of work is worth it unless it lets us extend the 'thread safety' to assume-statements (and not just observe-statements).
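For comparison, a lock-based accumulator for logp alone might look like the sketch below. `LockedLogp` and `acc_locked!` are hypothetical names, not part of DynamicPPL, and this covers only the logp accumulation for observe-statements; it does not address what assume-statements would need.

```julia
# Hypothetical lock-based accumulator: a single shared total guarded by a lock.
mutable struct LockedLogp{T}
    logp::T
    lk::ReentrantLock
end
LockedLogp(x) = LockedLogp(x, ReentrantLock())

function acc_locked!(acc::LockedLogp, logp)
    # Only one task at a time may update the shared total,
    # so no thread ID is needed to pick a slot.
    lock(acc.lk) do
        acc.logp += logp
    end
    return acc
end

acc = LockedLogp(0.0)
# @sync waits for all spawned tasks before continuing.
@sync for i in 1:100
    Threads.@spawn acc_locked!(acc, 1.0)
end
@show acc.logp
```

The trade-off is contention: every accumulation takes the lock, whereas the per-thread buffer only combines results at the end.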

3. Disallow tilde-statements inside Threads.@threads

Right now, we allow observe-statements to happen inside Threads.@threads (but not assume-statements). Observe-statements can, of course, be replaced with calls to @addlogprob!. For example, the following model breaks on Julia 1.12 (with any number of threads):

julia> @model function f(x)
           a ~ Normal()
           Threads.@threads for i in eachindex(x)
               x[i] ~ Normal(a)
           end
       end
f (generic function with 2 methods)

julia> model = f(Float64.(1:10))
Model{typeof(f), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(f, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model)
ERROR: TaskFailedException
[...]

The following model, however, is equivalent (and the use of Threads.@spawn is "officially correct", see https://julialang.org/blog/2023/07/PSA-dont-use-threadid/):

julia> @model function g(x)
           a ~ Normal()
           logps = map(x) do xi
               Threads.@spawn logpdf(Normal(a), xi)
           end
           @addlogprob! sum(fetch.(logps))
       end
g (generic function with 2 methods)

julia> model = g(Float64.(1:10))
Model{typeof(g), (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(g, (x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model)
VarInfo{@NamedTuple{a::DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}}, Float64}((a = DynamicPPL.Metadata{Dict{VarName{:a, typeof(identity)}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:a, typeof(identity)}}, Vector{Float64}}(Dict(a => 1), [a], UnitRange{Int64}[1:1], [-0.011060756850626022], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0)], [0], Dict{String, BitVector}("del" => [0], "trans" => [0])),), Base.RefValue{Float64}(-203.2173383639174), Base.RefValue{Int64}(0))

And of course, people can use whatever threading library they like (e.g. FLoops.jl) too, as long as there are no tilde-statements in the parallelised code.
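Spawning one task per observation, as in g above, may be more overhead than needed for cheap log-densities. A possible variation is to spawn one task per chunk. The sketch below avoids any dependency by using a hand-written standard-deviation-1 normal log-density; `normal_logpdf` and `chunked_loglik` are hypothetical helpers, not Distributions.jl or DynamicPPL functions, and the sketch assumes a non-empty x.

```julia
# Hand-written log-density of Normal(mu, 1) at x, so the sketch is self-contained.
normal_logpdf(mu, x) = -0.5 * (log(2π) + (x - mu)^2)

# Sum the log-density over x using one task per chunk rather than one per element.
function chunked_loglik(mu, x; ntasks=Threads.nthreads())
    chunksize = max(1, cld(length(x), ntasks))
    tasks = map(Iterators.partition(x, chunksize)) do chunk
        Threads.@spawn sum(xi -> normal_logpdf(mu, xi), chunk)
    end
    # Each task returns its own partial sum; no shared state is mutated.
    return sum(fetch, tasks)
end

x = Float64.(1:10)
@show chunked_loglik(0.0, x)
```

Inside a model, the result would be passed to @addlogprob! in the same way as in g, for example `@addlogprob! chunked_loglik(a, x)`.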

Note that disallowing multithreaded tilde-statements would also mean that ThreadSafeVarInfo could be removed entirely.
