Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.17.1"
version = "0.17.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
16 changes: 13 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ number of `sampler`.
"""
(model::Model)(args...) = first(evaluate!!(model, args...))

"""
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)

Return `true` if evaluation of a model using `context` and `varinfo` should
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
"""
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
return Threads.nthreads() > 1
end

"""
evaluate!!(model::Model[, rng, varinfo, sampler, context])

Expand All @@ -388,10 +398,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu
number of `sampler`.
"""
function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
if Threads.nthreads() == 1
return evaluate_threadunsafe!!(model, varinfo, context)
return if use_threadsafe_eval(context, varinfo)
evaluate_threadsafe!!(model, varinfo, context)
else
return evaluate_threadsafe!!(model, varinfo, context)
evaluate_threadunsafe!!(model, varinfo, context)
end
end

Expand Down