Conversation

@devmotion
Member

This PR includes the proper implementation of the idea discussed in #79 (comment).

It's not clear from the discussion we've had so far that we actually want to implement thread-safe accumulation of log probabilities in this way (BTW notice that this will only allow using multiple threads when looping over observations, since many more parts than just the accumulation of log probabilities are not thread-safe when sampling variables). The advantage of the approach in this PR is that it does not modify the implementation of VarInfo and hence does not break any code that operates on VarInfo. However, accessing the current log probability inside the model would be a bit more cumbersome: it could be accessed by sum(_logps) (or sum(_logps) + getlogp(_varinfo) if the log probability of _varinfo was not zero before running the model), or by _logps[Threads.threadid()] for each thread.

The alternative approach discussed in #79 would be to use a vector for _varinfo.logp instead. The advantage would be that in the model the current log probability could be accessed by getlogp(_varinfo) if it is defined as getlogp(_varinfo) = sum(_varinfo.logp). However, in that case every call of getlogp would recompute the sum, and the size of _varinfo.logp would depend on the number of threads, which would be problematic when saving and loading VarInfo objects in different Julia sessions.
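
To make the first access pattern concrete, a hypothetical model could look as follows (this is only an illustrative sketch; it assumes the _logps vector and _varinfo variable that this PR's macro introduces into the model body):

using DynamicPPL, Distributions

@model demo(x) = begin
    m ~ Normal(0, 1)
    # observe statements may run on multiple threads; each thread
    # accumulates its contribution in its own entry of `_logps`
    Threads.@threads for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
    # current accumulated log probability inside the model body
    total_logp = sum(_logps) + getlogp(_varinfo)
end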

BTW with this PR it is still possible to use lines such as

_varinfo.logp = -Inf
return

in the model, without having to deal with _logps.

@yebai
Member

yebai commented Apr 29, 2020

(BTW notice that this will only allow using multiple threads when looping over observations, since many more parts than just the accumulation of log probabilities are not thread-safe when sampling variables)

I think the constraint is less strict for HMC and MH sampling since there is no real distinction between data and parameters. Maybe we can assert on the sampler type in the implementation to avoid misuse.

@devmotion
Member Author

devmotion commented Apr 29, 2020

The main problems are due to the implementation of VarInfo and the fact that adding samples to a VarInfo object is not thread-safe (at least a first simple example, basically just providing no x in the new test, crashed). One problem seems to be, e.g., that BitArray is not thread-safe (JuliaLang/julia#33750).

On another note, maybe one should rather use Channels instead of duplicating all containers, in particular when thinking about extending the approach in this PR to other parts of VarInfo. Then one could maybe push all samples (maybe in a struct with distribution and log probability) to one single channel that updates one single VarInfo object asynchronously. Maybe that would also make it easier to experiment with alternatives to VarInfo at some point.

I'm not sure though if Channels are well supported by all Julia versions >= 1.0?

@codecov

codecov bot commented Apr 29, 2020

Codecov Report

Merging #89 into master will increase coverage by 0.18%.
The diff coverage is 92.85%.

@@            Coverage Diff             @@
##           master      #89      +/-   ##
==========================================
+ Coverage   77.30%   77.49%   +0.18%     
==========================================
  Files          13       13              
  Lines         846      853       +7     
==========================================
+ Hits          654      661       +7     
  Misses        192      192              
Impacted Files                   Coverage   Δ
src/DynamicPPL.jl                100.00%    <ø> (ø)
src/context_implementations.jl   56.02%     <83.33%> (ø)
src/compiler.jl                  88.33%     <100.00%> (+0.50%) ⬆️
src/utils.jl                     55.10%     <100.00%> (+1.91%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mohamed82008
Contributor

mohamed82008 commented Apr 29, 2020

I think we can consider having a VarInfo for each thread and combining them at the end. This is a generalization of the approach in this PR and it should work nicely in all cases, I believe, unless I am wrong of course :) So even variable sampling can be thread-safe. For HMC, this is not needed.

@devmotion
Member Author

Hmm, I'm not sure, that seems like quite a bit of an overhead? I'm already not very excited about the fact that one would have an array of logps with this PR.

@mohamed82008
Contributor

Hmm, I'm not sure, that seems like quite a bit of an overhead?

Yes, which is why we should probably only have it when necessary. Normal HMC runs don't need it; this PR's approach is good enough. We only need it when "filling" the VarInfo initially. This means we will need different treatments for different samplers, which is not too appealing to me. I am just thinking out loud here :) We can perhaps keep a flag for whether push! was called or not and only merge when needed. We will still need all the VarInfos to be around, but the merging overhead can be skipped in most cases. In cases where push! is called, pushing and merging need to happen in any implementation, so I think the overhead is acceptable.

@mohamed82008
Contributor

Let's keep brainstorming though even after this PR goes in.

@mohamed82008
Contributor

So just to elaborate on my proposal above: I am thinking we keep one "read" VarInfo and one "write" VarInfo for each thread. If all the variables are in the read VarInfo, then no push! is called and there is no time overhead. If push! needs to be called, then each thread pushes to its own write VarInfo and they all get merged into the read VarInfo at the end.
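
A rough sketch of the control flow of that idea (all helper names below are hypothetical; none of them exist in DynamicPPL, they only illustrate the proposal):

function evaluate_with_write_varinfos(model, read_vi)
    # one "write" VarInfo per thread, so push! never races
    write_vis = [empty_varinfo_like(read_vi) for _ in 1:Threads.nthreads()]
    Threads.@threads for i in 1:num_observations(model)
        # each thread pushes new variables only to its own write VarInfo
        evaluate_observation!(model, read_vi, write_vis[Threads.threadid()], i)
    end
    # merge the per-thread write VarInfos back into the read VarInfo at the end
    for wvi in write_vis
        merge_varinfo!(read_vi, wvi)
    end
    return read_vi
end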

@devmotion
Member Author

I'm leaning towards using Channels right now, but I haven't had time to think about an implementation yet. In principle that should work without replicating anything, one would just have an asynchronous process that takes care of push!ing to the VarInfo (or some other structure) and updating the accumulated log probability, and all the threads would just push to the channel.

@mohamed82008
Contributor

an asynchronous process that takes care of push!ing to the VarInfo (or some other structure)

I don't think this would work. We need to know the range of indices in the vals vector corresponding to a VarName. If we let Julia decide which variable goes in first, this will mess with the ranges field of VarInfo.

@devmotion
Member Author

As I said, it's not completely clear yet how that would work and maybe one would need another structure or push the samples to a dictionary first. It's just wild guesses and speculations so far 😄

@mohamed82008
Contributor

No worries 😄

@mohamed82008
Contributor

For the purposes of this PR, I really prefer that _logps stay inside VarInfo, resized to the proper size at the beginning. This will make the approach more extensible in the future by changing VarInfo and not the model macro. More specifically, I have https://github.com/mohamed82008/TuringSparseDiff.jl in mind when writing this comment. This PR will complicate the LazyVarInfo implementation in the mt/sparsefd branch of DPPL even for a single thread.

@devmotion
Member Author

Hmm, I'm a bit surprised about that. I would have thought that not having to resize the logp of every possible VarInfo implementation (and hence not having to implement resize functionality, decide how to deal with vectorized logps, etc.) would be simpler and make it easier to plug in other implementations. At least at a first glance it seemed quite unintuitive to change the logp size of the VarInfo at every model call (possibly), to write, e.g., the sum of existing values to the first entry of this vector, to sum over the whole array every time you call getlogp (if you don't have another cache or a resizing step after the model call), to decide where to write if you just call acclogp!, etc. With the approach in this PR (even though it surely is not optimal) no VarInfo implementation has to deal with how the logp values are accumulated; it only has to implement acclogp! to update its logp value in the end. Can you explain what problems this causes for the lazy version you mentioned?

@yebai
Member

yebai commented Apr 29, 2020

We probably want to keep VarInfo independent of threading if possible. Some advanced interface for expert users to write threading-enabled code would be useful, but supporting threading in a generic way might be overkill and would substantially increase the complexity of DynamicPPL. Besides, even if we can make VarInfo thread-safe, other Julia libraries, e.g. ReverseDiff, might not be thread-safe anyway.

@mohamed82008
Contributor

Can you explain what problems this causes for the lazy version you mentioned?

So what I am trying to do there is to accumulate all the individual logp values in a vector to expose more structure in the Jacobian of the logp vector function. SparsityDetection then detects the sparsity structure of the Jacobian. This can enable differentiating all the logp values wrt many parameters in a single forward pass in ForwardDiff, because the parameters don't interact at all (except when summing the logp values at the end). Calling acclogp! in every ~ is therefore needed for this to work because I don't add, I push.
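
As a hypothetical illustration of the difference between adding and pushing (the types and functions below are made up for illustration; the actual LazyVarInfo implementation is linked further down):

# Scalar accumulator: adding collapses all terms into a single number.
mutable struct ScalarLogp
    logp::Float64
end
acclogp!(v::ScalarLogp, logp) = (v.logp += logp; v)

# Lazy accumulator: pushing keeps one entry per ~ statement, which exposes
# the per-term structure needed for sparsity detection of the Jacobian.
struct LazyLogp
    logps::Vector{Float64}
end
acclogp!(v::LazyLogp, logp) = (push!(v.logps, logp); v)
getlogp(v::LazyLogp) = sum(v.logps)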

We probably want to keep VarInfo independent of threading if possible. Some advanced interface for expert users to write threading-enabled code would be useful, but supporting threading in a generic way might be overkill and would substantially increase the complexity of DynamicPPL. Besides, even if we can make VarInfo thread-safe, other Julia libraries, e.g. ReverseDiff, might not be thread-safe anyway.

I don't agree with this. VarInfo is the main data structure we use in DPPL. Making it thread-safe will have a direct impact on sampling performance when using ForwardDiff. This was observed in the Covid model, but Tor had to work around the lack of thread safety in VarInfo. Even ReverseDiff can be made thread-safe in principle; it just wasn't done before. And personally I would rather fiddle with VarInfo and dispatch (normal Julia code) than with the model macro (awkward Julia code), so I think we should keep the macro as simple as possible and do the rest with dispatch and normal functions.

@devmotion
Member Author

Calling acclogp! in every ~ is therefore needed for this to work because I don't add, I push.

Note that push!ing is not thread-safe (and probably never will be; see, e.g., https://discourse.julialang.org/t/can-dicts-be-threadsafe/27172/6?u=devmotion). A simple example:

julia> function f(x)
           Threads.@threads for i in 1:100
               push!(x, i)
           end
           x
       end
f (generic function with 1 method)
                                                               
julia> f(Int[])
62-element Array{Int64,1}:
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
  0
  0
  0
...

To support your use-case (and any other possible alternative of VarInfo) we should maybe really just switch to a Channel-based implementation. It seems that would be the easiest way to ensure that everything is thread-safe since only one process would call acclogp! (and maybe push! new samples in an extension of this approach). The implementation of VarInfo (or its alternatives) would never have to care about what acclogp! is actually doing (and hence the implementation of VarInfo does not specialize on threading at all), but it would be guaranteed that the model evaluation happens in a threadsafe way. In the example above that would be, e.g.:

julia> function f(x)
           channel = Channel{Int}()
           @sync begin
               @async begin
                   while true
                       n = take!(channel)
                       n < 0 && break
                       
                       push!(x, n)
                   end
               end
               
               @async begin
                   Threads.@threads for i in 1:100
                       put!(channel, i)
                   end
                   put!(channel, -1)
               end
           end
           
           return x
       end             
                                                               
f (generic function with 1 method)

julia> f(Int[])
100-element Array{Int64,1}:
  76
  77
  78
  26
   1
   2
   3
  51
  79
  80
  52
  53
  27
  81
  54
   4
...

julia> length(unique(ans))
100

@mohamed82008
Contributor

mohamed82008 commented Apr 30, 2020

I am aware that push! is not thread-safe. And I am not trying to make LazyVarInfo thread-safe either. Even if I use a vector for each thread inside LazyVarInfo and let each thread only push to its corresponding vector, SparseDiff won't work unless the tasks are statically assigned to the threads in the multi-threaded parts of the model. This is because I need to guarantee that the order of logp values will not change between runs for the Jacobian sparsity pattern to remain valid. So I am not trying to make TuringSparseDiff multi-threaded. But with this PR, even a single-threaded implementation becomes impossible, because the logp values get added to _logps and acclogp! is only called at the end, which is the only part I can control from outside the model macro using dispatch.

@mohamed82008
Contributor

Btw here is the entire implementation of LazyVarInfo https://github.com/TuringLang/DynamicPPL.jl/blob/mt/sparsefd/src/lazyvarinfo.jl.

@yebai
Member

yebai commented May 1, 2020

This PR is helpful for the COVID model. Shall we merge this PR as-is for now, and warn the user that this feature should be used with caution (e.g. only for cases where model dimensionality is frozen)? We can probably revisit this design in a separate PR.

@mohamed82008
Contributor

I am fine with the PR but let's move _logps inside VarInfo. It's not that hard and it makes life easier in the future.

@mohamed82008
Contributor

Or I can do it in another PR.

@mohamed82008 merged commit 180458e into master May 1, 2020
The delete-merged-branch bot deleted the threads branch May 1, 2020 12:58
@mohamed82008
Contributor

Thanks @devmotion!

sampler::$(DynamicPPL.AbstractSampler),
context::$(DynamicPPL.AbstractContext),
)
logps = $(DynamicPPL.initlogps)(varinfo)
Contributor

I believe this will create a type instability. I just noticed it now. Perhaps another reason to have logps inside VarInfo and resize it accordingly.

Contributor

Let's please fix this in another PR before releasing.

Member Author

That's due to our use of StaticArrays here and JuliaLang/julia#34902. I guess we should just use arrays; compared to all the other allocations and the model evaluation, that should not matter at all.
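
To sketch where the instability comes from (the two functions below are only illustrative stand-ins, not the actual initlogps helper):

using StaticArrays

# Type-unstable: the length of the SVector depends on Threads.nthreads(),
# a runtime value, so the concrete return type cannot be inferred.
initlogps_static() = zero(SVector{Threads.nthreads(),Float64})

# Type-stable alternative: always returns a Vector{Float64}.
initlogps_plain() = zeros(Threads.nthreads())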

Contributor

I agree 👍

@devmotion
Member Author

Sorry I've been quite busy today, just saw your messages now.

I am fine with the PR but let's move _logps inside VarInfo. It's not that hard and it makes life easier in the future.

Please don't do this. While I agree that the ad-hoc solution in this PR is not the best and final design, I am very much convinced that the threading part and the current PPL part of VarInfo should be kept separate as much as possible (IMO the current implementation is already quite difficult to understand, and one should really try not to complicate it further if not needed). Hence I've already played around with different alternatives, also, e.g., based on Channels (but unfortunately it seems that they are problematic with task copying, so maybe that would require some other changes as well).

Instead of adding logps to VarInfo I suggest the following:

  • Make a ThreadSafeVarInfo that just wraps an AbstractVarInfo object and contains a logps array for now (one could switch to Channels, include samples, etc. later if needed); a minimal sketch is given below
  • Forward all required methods for AbstractVarInfo to the underlying AbstractVarInfo object, but implement
    getlogp(v::ThreadSafeVarInfo) = sum(v.logps) + getlogp(v.varinfo)
    function acclogp!(v::ThreadSafeVarInfo, val)
        v.logps[Threads.threadid()] += val
    end
  • Revert the changes in this PR and just change the implementation of
    function (model::Model)(vi, spl, context)
        # some stuff in the current implementation that I don't remember...

        Threads.nthreads() == 1 && return model.f(model, vi, spl, context)

        wrapper = ThreadSafeVarInfo(vi)
        result = model.f(model, wrapper, spl, context)
        acclogp!(vi, sum(wrapper.logps))
        return result
    end

The advantages would be:

  • The threading implementation is clearly separated from all other logic; it only enters in the definition of (model::Model)(...)
  • We don't create any strange arguments or internal variables just for threading
  • It is compatible with LazyVarInfo (and any other AbstractVarInfo): if multi-threading is not supported, you just redefine (model::Model)(vi::MyVarInfo, ...) without dealing with any other internals (BTW, if one could use a Channel, then LazyVarInfo wouldn't have to specialize here: as long as the user does not use multi-threading in the model definition, our implementation would guarantee that the logps appear deterministically in the same order in the Channel and hence the logps array in LazyVarInfo would just be constructed in the same way every time).
  • At any time during model evaluation (outside of Threads.@threads blocks, of course), getlogp(_varinfo) would actually return the correct accumulated value of logp at the current point.
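
A minimal sketch of such a ThreadSafeVarInfo wrapper (the field names, the constructor, and the use of a plain Vector are illustrative assumptions, not a final implementation; many more AbstractVarInfo methods would need to be forwarded):

using DynamicPPL: AbstractVarInfo
import DynamicPPL: getlogp, acclogp!

struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AbstractVector} <: AbstractVarInfo
    varinfo::V
    logps::L
end

# One accumulator slot per thread.
ThreadSafeVarInfo(vi::AbstractVarInfo) =
    ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads()))

# Each thread only ever touches its own slot, so accumulation is race-free.
function acclogp!(vi::ThreadSafeVarInfo, logp)
    vi.logps[Threads.threadid()] += logp
    return vi
end

# Total log probability = wrapped value + all per-thread contributions.
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)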

@mohamed82008
Contributor

I like the idea of ThreadSafeVarInfo, I will go with that, thanks! Channels would be interesting, but we would need to benchmark them to make sure they are not slower than the approach we have now.

@devmotion
Member Author

I already have an implementation.

@devmotion
Member Author

The only problem is the ridiculous number of methods one has to define.

@mohamed82008
Contributor

@devmotion
Member Author

No, they are not sufficient. Of course, I started with them but one has to add more.

@mohamed82008
Contributor

True, you have to change some signatures from VarInfo to AbstractVarInfo.

@mohamed82008
Contributor

mohamed82008 commented May 1, 2020

Please take a look at the mt/sparsefd branch.

@devmotion mentioned this pull request May 1, 2020
@devmotion
Member Author

That alone is not sufficient; Turing has to be updated as well (see the WIP PR).

@mohamed82008
Contributor

Sure, I will review it tomorrow.
