Make accumulation of log probabilities thread-safe #89
Conversation
I think the constraint is less strict for HMC and MH sampling since there is no real distinction between data and parameters. Maybe we can assert on sampler type in the implementation to avoid misuses. |
|
The main problems are due to the implementation of … On another note, maybe one should rather use a Channel. I'm not sure, though, whether Channels are well supported by all Julia versions >= 1.0. |
Codecov Report
@@ Coverage Diff @@
## master #89 +/- ##
==========================================
+ Coverage 77.30% 77.49% +0.18%
==========================================
Files 13 13
Lines 846 853 +7
==========================================
+ Hits 654 661 +7
Misses 192 192
Continue to review full report at Codecov.
|
|
I think we can consider having a |
|
Hmm, I'm not sure, that seems like quite a bit of an overhead? I'm already not very excited about the fact that one would have an array of logps with this PR. |
Yes, which is why we should probably only have it when necessary. Normal HMC runs don't need it; this PR's approach is good enough. We only need it when "filling" the VarInfo initially. This means we will need different treatments for different samplers, which is not too appealing to me. I am just thinking out loud here :) We can perhaps keep a flag for whether … |
|
Let's keep brainstorming though even after this PR goes in. |
|
So, just to elaborate on my proposal above: I am thinking we keep one "read" VarInfo and one "write" VarInfo for each thread. If all the variables are in the read VarInfo, then no … |
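To make the read/write idea a bit more concrete, here is a minimal sketch with plain Dicts standing in for VarInfo; ReadWriteStore, merge_writes! and the indexing methods are made-up names for illustration only, not anything proposed in this PR:

# Sketch: one shared "read" store plus one private "write" store per thread.
struct ReadWriteStore{K,V}
    read::Dict{K,V}            # shared, only read during the threaded section
    write::Vector{Dict{K,V}}   # one private store per thread
end

ReadWriteStore(read::Dict{K,V}) where {K,V} =
    ReadWriteStore(read, [Dict{K,V}() for _ in 1:Threads.nthreads()])

# Reads check the shared store first, then the calling thread's private one.
Base.getindex(s::ReadWriteStore, k) =
    haskey(s.read, k) ? s.read[k] : s.write[Threads.threadid()][k]

# Writes always go to the thread-local store, so no locking is needed.
Base.setindex!(s::ReadWriteStore, v, k) = (s.write[Threads.threadid()][k] = v)

# After the threaded section, the per-thread stores are merged back into the shared one.
merge_writes!(s::ReadWriteStore) = foldl(merge!, s.write; init = s.read)

During the threaded part of a model run each thread would then only ever write to its own store, and merging happens once afterwards.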
|
I'm leaning towards using Channels right now, but I haven't had time to think about an implementation yet. In principle that should work without replicating anything; one would just have an asynchronous process that takes care of … |
I don't think this would work. We need to know the range of indices in the |
|
As I said, it's not completely clear yet how that would work and maybe one would need another structure or push the samples to a dictionary first. It's just wild guesses and speculations so far 😄 |
|
No worries 😄 |
|
For the purposes of this PR, I really prefer that |
|
Hmm, I'm a bit surprised about that. I would have thought that not having to resize the logp of every possible VarInfo implementation (and hence also not having to implement resize functionality, decide how to deal with vectorized logps, etc.) would be simpler and make it easier to plug in other implementations. At least at first glance it seemed quite unintuitive to (possibly) change the logp size of the VarInfo at every model call and, e.g., write the sum of the existing values to the first entry of this vector, sum over the whole array every time you call getlogp, … |
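For comparison, the access pattern being criticized here would look roughly like the following toy sketch of the "vectorized logp" alternative; ToyVarInfo and the function bodies are assumptions made for illustration, not the actual VarInfo interface:

# Toy sketch: the VarInfo keeps one logp slot per thread instead of a single number.
struct ToyVarInfo
    logp::Vector{Float64}
end

# Reading the log probability has to sum over all per-thread slots every time.
getlogp(vi::ToyVarInfo) = sum(vi.logp)

# Before a threaded model call, the existing value is collapsed into the first
# entry and the vector is resized to one slot per thread.
function resetlogp_for_threads!(vi::ToyVarInfo)
    total = sum(vi.logp)
    resize!(vi.logp, Threads.nthreads())
    fill!(vi.logp, 0.0)
    vi.logp[1] = total
    return vi
end

# Each thread then accumulates into its own slot, which keeps the writes race-free.
acclogp!(vi::ToyVarInfo, lp) = (vi.logp[Threads.threadid()] += lp; vi)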
|
We probably want to keep |
So what I am trying to do there is to accumulate all the individual
I don't agree with this. |
Note that push! is not thread-safe:

julia> function f(x)
           Threads.@threads for i in 1:100
               push!(x, i)
           end
           x
       end
f (generic function with 1 method)
julia> f(Int[])
62-element Array{Int64,1}:
51
52
53
54
55
56
57
58
59
60
61
62
63
64
0
0
0
...

To support your use-case (and any other possible alternative of …) one could, e.g., use a Channel:

julia> function f(x)
           channel = Channel{Int}()
           @sync begin
               @async begin
                   while true
                       n = take!(channel)
                       n < 0 && break
                       push!(x, n)
                   end
               end
               @async begin
                   Threads.@threads for i in 1:100
                       put!(channel, i)
                   end
                   put!(channel, -1)
               end
           end
           return x
       end
f (generic function with 1 method)
julia> f(Int[])
100-element Array{Int64,1}:
76
77
78
26
1
2
3
51
79
80
52
53
27
81
54
4
...
julia> length(unique(ans))
100 |
|
I am aware that |
|
Btw here is the entire implementation of |
|
This PR is helpful for the COVID model. Shall we merge this PR as-is for now, and warn the user that this feature should be used with caution (e.g. only for cases where model dimensionality is frozen)? We can probably revisit this design in a separate PR. |
|
I am fine with the PR but let's take |
|
Or I can do it in another PR. |
|
Thanks @devmotion! |
    sampler::$(DynamicPPL.AbstractSampler),
    context::$(DynamicPPL.AbstractContext),
)
    logps = $(DynamicPPL.initlogps)(varinfo)
I believe this will create a type instability. I just noticed it now. Perhaps another reason to have logps inside VarInfo and resize it accordingly.
Let's please fix this in another PR before releasing.
That's due to our use of StaticArrays here and JuliaLang/julia#34902. I guess we should just use arrays; compared to all the other allocations and the model evaluation, that should not matter at all.
I agree 👍
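To illustrate the "just use arrays" suggestion, initlogps could simply allocate a plain per-thread vector; this is only a sketch of the idea (assuming getlogp returns the current scalar log probability), not the actual DynamicPPL implementation:

# Sketch: a plain Vector avoids the StaticArrays-related type instability
# (JuliaLang/julia#34902) at the cost of one small allocation per model call,
# which should be negligible next to the rest of the model evaluation.
function initlogps(varinfo)
    T = typeof(getlogp(varinfo))   # keep the element type of the existing logp
    return zeros(T, Threads.nthreads())
end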
|
Sorry I've been quite busy today, just saw your messages now.
Please don't do this. While I agree that the ad-hoc solution in this PR is not the best and final design, I am very much convinced that the threading part and the current PPL part of VarInfo should be kept separate as much as possible (IMO the current implementation is already quite difficult to understand, and one should really try not to complicate it further if not needed). Hence I've already played around with different alternatives, e.g. also some based on Channels (but unfortunately it seems that they are problematic with task copying, so maybe that would require some other changes as well). Instead of adding …
The advantages would be:
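One possible shape of such a separation, just to make the direction concrete (the wrapper type and its methods below are hypothetical and not taken from this PR): a thin wrapper holds an unmodified VarInfo together with per-thread logp accumulators, so all threading concerns stay outside VarInfo itself.

# Sketch: keep VarInfo untouched and put the thread-safety layer in a wrapper.
struct ThreadSafeAccumulator{V}
    varinfo::V                 # the ordinary, unmodified VarInfo
    logps::Vector{Float64}     # one accumulator per thread
end

ThreadSafeAccumulator(vi) = ThreadSafeAccumulator(vi, zeros(Threads.nthreads()))

# Accumulation goes into the calling thread's slot only.
acclogp!(w::ThreadSafeAccumulator, lp) = (w.logps[Threads.threadid()] += lp; w)

# The total log probability is the wrapped VarInfo's logp plus all per-thread parts
# (assuming DynamicPPL's getlogp for the wrapped VarInfo).
getlogp(w::ThreadSafeAccumulator) = getlogp(w.varinfo) + sum(w.logps)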
|
|
I like the idea of |
|
I already have an implementation. |
|
The only problem is the ridiculous number of methods one has to define. |
|
No, they are not sufficient. Of course, I started with them but one has to add more. |
|
True, you have to change some signatures from VarInfo to AbstractVarInfo. |
|
Please take a look at the |
|
That's all not sufficient; Turing has to be updated as well (see the WIP PR). |
|
Sure, I will review it tomorrow. |
This PR includes the proper implementation of the idea discussed in #79 (comment).
It's not clear from the discussion that we've had so far that we actually want to implement thread-safe accumulation of log probabilities in this way (BTW, notice that this will only allow using multiple threads when looping over observations, since many more parts than just the accumulation of log probabilities are not thread-safe when sampling variables). The advantage of the approach in this PR is that it does not modify the implementation of VarInfo and hence does not break any code that operates on VarInfo. However, accessing the current log probability inside the model would be a bit more cumbersome: it could be accessed by sum(_logps) (or sum(_logps) + getlogp(_varinfo) if the log probability of _varinfo was not zero before running the model), or by _logps[Threads.threadid()] for each thread. The alternative approach discussed in #79 would be using a vector for _varinfo.logp instead. The advantage would be that in the model the current log probability could be accessed by getlogp(_varinfo) if it is defined as getlogp(_varinfo) = sum(_varinfo.logp). However, in that case every call of getlogp would recompute the sum, and the size of _varinfo.logp would depend on the number of threads, which would be problematic when saving and loading VarInfo objects in different Julia sessions.

BTW, with this PR it is still possible to use lines such as … in the model, without having to deal with _logps.
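Stripped of all the DynamicPPL machinery, the accumulation scheme described above boils down to something like the following sketch (the data and the Normal likelihood are made up for illustration; only the per-thread _logps-style array and the Threads.threadid() indexing come from the description):

using Distributions

# Per-thread accumulators, playing the role of the _logps array described above.
logps = zeros(Threads.nthreads())
data = randn(1_000)

# Each observation adds its log density to the slot of the thread handling it,
# so no two threads ever write to the same entry.
Threads.@threads for i in eachindex(data)
    logps[Threads.threadid()] += logpdf(Normal(0, 1), data[i])
end

# The total log probability is recovered by summing over all threads
# (plus getlogp(_varinfo) if it was nonzero before the model ran).
total_logp = sum(logps)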