
Conversation

@torfjelde (Member) commented Nov 9, 2021

This is a sibling-PR of TuringLang/DynamicPPL.jl#309.

The aim here is to make the bare-minimum changes to ensure that everything works as before. I've not made any significant attempt at making the samplers compatible with immutable implementations of `AbstractVarInfo`, e.g. `SimpleVarInfo`; instead I've just marked the relevant lines with `# TODO: make compatible with immutable vi`, which we can then address in a separate PR.

Related: #1725.

EDIT: Tests are passing locally, but they won't pass here until the PR in DPPL has been merged and released.

bors bot pushed a commit to TuringLang/DynamicPPL.jl that referenced this pull request Dec 15, 2021
On master we have the following behavior for a test-case in Turing.jl:

```julia
julia> @macroexpand @model empty_model() = begin x = 1; end
quote
    function empty_model(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::DynamicPPL.AbstractContext; )
        #= REPL[5]:1 =#
        begin
            #= REPL[5]:1 =#
            #= REPL[5]:1 =#
            return (x = 1, __varinfo__)
        end
    end
    begin
        $(Expr(:meta, :doc))
        function empty_model(; )
            #= REPL[5]:1 =#
            return (DynamicPPL.Model)(:empty_model, empty_model, NamedTuple(), NamedTuple())
        end
    end
end
```

Notice the `return` statement: the statement `x = 1`, which returns `1`, has been converted into an attempt at constructing a `NamedTuple{(:x, :__varinfo__)}`. On Julia 1.6 we don't really notice much of a difference, because `first` and `last` will have the same behavior, but on Julia 1.3 the tests would fail in TuringLang/Turing.jl#1726 since "implicit" names in the construction of a `NamedTuple` aren't supported.

This PR addresses the issue by simply capturing the return-value in a separate variable, which is then combined with `__varinfo__` in a `Tuple` at the end. This should fail and succeed exactly when standard Julia code would.
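
Roughly, the expansion now does the following (a sketch only; `__retval__` stands in for whatever intermediate name the macro actually generates):

```julia
using DynamicPPL

function empty_model(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::DynamicPPL.AbstractContext;)
    # Capture the value of the model body in a separate variable...
    __retval__ = begin
        x = 1
    end
    # ...and only then combine it with `__varinfo__` in a `Tuple`,
    # so no `NamedTuple` construction is attempted.
    return __retval__, __varinfo__
end
```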
```diff
 acclogp!!(vi, lp)
 end
-return r, 0
+return r, 0, _vi
```
@torfjelde (Member Author):
Is this the right thing to do @yebai @devmotion ?

Essentially the issue is that when we're using PG with threads, `_vi` is a `ThreadSafeVarInfo` while `vi = AdvancedPS.current_trace().f.varinfo` is not. So if we return `vi` instead of `_vi`, subsequent code in `evaluate_threadsafe!!` will fail.

I'm not too familiar with the AdvancedPS stuff so it's a bit unclear to me what I should be returning here, e.g. is the above okay?

@torfjelde (Member Author):
Or should I overload `evaluate!!` and all the corresponding methods?

If we're doing this, we might want to consider making the "outer" calls, e.g. `(::Model)(...)`, a bit more general, e.g. dispatching on `AbstractProbabilisticProgram` instead and then just overloading `evaluate_threadsafe!!` and `evaluate_threadunsafe!!` for `TracedModel`.
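
Roughly something like the following, perhaps (a sketch only, as if written inside DynamicPPL where these names are in scope; the signatures are assumptions, not the actual API):

```julia
# Sketch: make the "outer" call generic over probabilistic programs and
# route to the thread-(un)safe evaluators in one place.
function (model::AbstractProbabilisticProgram)(varinfo, context)
    return if Threads.nthreads() > 1
        evaluate_threadsafe!!(model, varinfo, context)
    else
        evaluate_threadunsafe!!(model, varinfo, context)
    end
end

# `TracedModel` would then only need its own methods for
# `evaluate_threadsafe!!` and `evaluate_threadunsafe!!`.
```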

@yebai (Member) commented Dec 15, 2021:
Suggested change:

```diff
-return r, 0, _vi
+return r, 0, vi
```

Let's perhaps disallow the use of `ThreadSafeVarInfo` with PG/SMC for now. These samplers are not yet ready to run with thread parallelism due to the interaction between particles (i.e. the resampling step). I suggest that we simply throw an error here explaining that `ThreadSafeVarInfo` support for PG will be added in the near future.

@torfjelde (Member Author):
Well, there currently isn't a way to allow/disallow specific combinations like this. And erroring, or even just telling people not to use PG/SMC just because they're running with `--threads=4`, seems wrong since they might not even be using threading inside the model.

We could add an `isthreadsafe(varinfo, context)` check in https://github.com/TuringLang/DynamicPPL.jl/blob/57c50f1f46aa5e51ec2125f3ae4883c61e59e9c3/src/model.jl#L391 and then implement this specifically for `SamplingContext{<:PG}`, etc.
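
A minimal sketch of that check, written as if inside DynamicPPL (all names and signatures here are hypothetical, mirroring the proposal above; `isthreadsafe` is not existing API):

```julia
# Default: use the thread-safe wrapper whenever Julia runs with threads.
isthreadsafe(varinfo, context) = Threads.nthreads() > 1

# Particle Gibbs opts out, so evaluation always uses the plain `VarInfo`
# (the exact `SamplingContext` type parameters are schematic).
isthreadsafe(varinfo, context::SamplingContext{<:Sampler{<:PG}}) = false

function evaluate!!(model, varinfo, context)
    return if isthreadsafe(varinfo, context)
        evaluate_threadsafe!!(model, varinfo, context)
    else
        evaluate_threadunsafe!!(model, varinfo, context)
    end
end
```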

Member:
> We could add an `isthreadsafe(varinfo, context)` check in https://github.com/TuringLang/DynamicPPL.jl/blob/57c50f1f46aa5e51ec2125f3ae4883c61e59e9c3/src/model.jl#L391 and then implement this specifically for `SamplingContext{<:PG}`, etc.

Yes, I think this would be a better approach. Then we can always choose the non-threadsafe VarInfo for these samplers instead of erroring if they run Julia with multiple threads (e.g. I always do, so I think it would be very inconvenient to throw an error here).

The only possible problem is that it would then be possible to use `@threads` with a non-threadsafe `VarInfo`, which can silently lead to incorrect results. I don't think there's a clean way to detect whether a model tries to use multiple threads (it's not as simple as detecting a `@threads` macro call). I wonder if it's possible to add a locking mechanism to the default non-threadsafe `VarInfo` that makes it threadsafe without impacting performance when run with a single thread. It would surely be slower when multiple threads are used, but with multiple threads we would use the more performant `ThreadSafeVarInfo` anyway - if we don't use PG etc. The main motivation would be to avoid silent problems with multithreaded models when PG + a non-threadsafe `VarInfo` is used; it seems fine for that to be less performant since it's not recommended anyway.
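
For concreteness, a rough sketch of what such a locking mechanism could look like (the wrapper type is purely illustrative, not the actual `VarInfo` internals):

```julia
# Illustrative only: guard the logp accumulator with a lock so concurrent
# `acclogp!` calls are safe. An uncontended `ReentrantLock` stays cheap,
# so the single-threaded case should be barely affected.
mutable struct LockedLogp
    logp::Float64
    lck::ReentrantLock
end
LockedLogp(logp=0.0) = LockedLogp(logp, ReentrantLock())

function acclogp!(x::LockedLogp, lp)
    lock(x.lck) do
        x.logp += lp
    end
    return x
end
```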

@torfjelde (Member Author):
> Yes, I think this would be a better approach. Then we can always choose the non-threadsafe VarInfo for these samplers instead of erroring if they run Julia with multiple threads (e.g. I always do, so I think it would be very inconvenient to throw an error here).

I'll do that then 👍

> The only possible problem is that it would then be possible to use `@threads` with a non-threadsafe `VarInfo`, which can silently lead to incorrect results.

Yup 😕 We could potentially address this by using atomics (though doing this properly would probably require very recent Julia versions), which I expect would behave just like a `Ref` in the single-threaded case - or even just make the default `VarInfo` threadsafe by using an array instead of a `Ref` for the `logp` field.
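
Something like this, maybe (illustrative only, not the actual `VarInfo` field; `@atomic` fields need Julia 1.7+, which is what I mean by very recent versions):

```julia
# Hypothetical stand-in for the `logp` storage.
mutable struct AtomicLogp
    @atomic logp::Float64
end

# Atomic read-modify-write; behaves just like a `Ref` update when only a
# single thread is running.
acclogp!(x::AtomicLogp, lp) = (@atomic x.logp += lp; x)
getlogp(x::AtomicLogp) = @atomic x.logp
```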

But it's worth mentioning that PG etc. currently already do the wrong thing silently, don't they?

codecov bot commented Dec 15, 2021

Codecov Report

Merging #1726 (4c1db31) into master (60b7d3f) will increase coverage by 0.12%.
The diff coverage is 84.84%.


```diff
@@            Coverage Diff             @@
##           master    #1726      +/-   ##
==========================================
+ Coverage   81.16%   81.29%   +0.12%
==========================================
  Files          24       24
  Lines        1476     1470       -6
==========================================
- Hits         1198     1195       -3
+ Misses        278      275       -3
```
| Impacted Files | Coverage Δ |
|---|---|
| src/core/container.jl | 76.19% <0.00%> (ø) |
| src/inference/Inference.jl | 84.55% <ø> (ø) |
| src/inference/emcee.jl | 91.66% <66.66%> (+3.43%) ⬆️ |
| src/inference/hmc.jl | 76.36% <71.42%> (+0.03%) ⬆️ |
| src/modes/ModeEstimation.jl | 83.73% <84.61%> (ø) |
| src/inference/mh.jl | 85.82% <86.66%> (+0.10%) ⬆️ |
| src/inference/ess.jl | 97.91% <90.00%> (-0.05%) ⬇️ |
| src/contrib/inference/dynamichmc.jl | 100.00% <100.00%> (ø) |
| src/core/ad.jl | 81.01% <100.00%> (ø) |
| src/inference/AdvancedSMC.jl | 97.43% <100.00%> (+0.02%) ⬆️ |
| ... and 3 more | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

coveralls commented Dec 15, 2021

Pull Request Test Coverage Report for Build 1588726417

  • 84 of 99 (84.85%) changed or added relevant lines in 12 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.1%) to 81.293%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/core/container.jl | 0 | 1 | 0.0% |
| src/inference/emcee.jl | 2 | 3 | 66.67% |
| src/inference/ess.jl | 9 | 10 | 90.0% |
| src/inference/mh.jl | 13 | 15 | 86.67% |
| src/modes/ModeEstimation.jl | 22 | 26 | 84.62% |
| src/inference/hmc.jl | 15 | 21 | 71.43% |
Totals:

  • Change from base Build 1566645174: 0.1%
  • Covered Lines: 1195
  • Relevant Lines: 1470


@torfjelde (Member Author):
Tests are now passing 👍 This should now be ready for a review.

@yebai (Member) left a comment:

Looks good! To clarify, this PR only updates Turing to the new DynamicPPL API, but does not make use of `SimpleVarInfo` yet, right?

@torfjelde (Member Author) commented Dec 16, 2021 via email

@torfjelde (Member Author):
@devmotion you happy with this?

Btw, should we also re-export `condition`, `decondition`, and `conditioned` from Turing now?

@devmotion (Member):
I'll have a look this afternoon 👍

I guess we don't have any prominent documentation yet and don't use it in any tutorials or examples? I think maybe we could add those and export these functions in a single PR?

@torfjelde (Member Author):
Sounds good 👍

And yeah, that's a fair point. Next on my TODO list is writing a tutorial/some docs on the recent features, e.g. `@submodel`, `condition`, etc., so I guess we can export them then. The only catch is that those PRs will be going to projects other than Turing itself. But anyhow, it should probably be a separate PR, so let's leave it out of this one 👍

@devmotion (Member) left a comment:

Just had some minor questions, but looks good overall 👍 I guess most questions are related to the `SimpleVarInfo` integration, which will be done in a separate PR.

```diff
-setlogp!(vi, ForwardDiff.value(logp))
+# Don't need to capture the resulting `vi` since this is only
+# needed if `vi` is mutable.
+setlogp!!(vi, ForwardDiff.value(logp))
```
@devmotion (Member):
I guess ideally we would want a no-op here when `setlogp!!` returns a new object - if we just call `setlogp!!` with a `SimpleVarInfo`, it would create a new `VarInfo` object even though it's never used. Maybe one could check if `vi === new_vi` and only call `setlogp!!` if it is true (assuming that in this case `setlogp!!` would not return a new object)?
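
For illustration, a hypothetical guard (both `is_mutable_varinfo` and `maybe_setlogp!!` are made-up names, not DynamicPPL API):

```julia
using DynamicPPL

# Made-up trait: which implementations mutate in place?
is_mutable_varinfo(::DynamicPPL.VarInfo) = true
is_mutable_varinfo(::DynamicPPL.AbstractVarInfo) = false

function maybe_setlogp!!(vi, logp)
    # Only mutable implementations need the side effect; for an immutable
    # `vi` the freshly allocated result would just be discarded anyway.
    if is_mutable_varinfo(vi)
        DynamicPPL.setlogp!!(vi, logp)
    end
    return vi
end
```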

But I guess this should be left for the SimpleVarInfo PR?

@torfjelde (Member Author):
Will this not be optimized away in the case where `vi` is immutable?
I sort of assumed so. But I guess we might as well give the compiler a helping hand just to be sure.


```diff
 function AdvancedPS.reset_logprob!(f::TracedModel)
-    DynamicPPL.resetlogp!(f.varinfo)
+    DynamicPPL.resetlogp!!(f.varinfo)
```
@devmotion (Member):
This also relies on the fact that it mutates `f.varinfo`.

@torfjelde (Member Author):
Yup, but `TracedModel` isn't compatible with immutable `AbstractVarInfo` 😕

Comment on lines +113 to +114:

```julia
resetlogp!!(vi)
empty!!(vi)
```
@devmotion (Member):
Does `empty!!` not call `resetlogp!!`? Maybe that would be reasonable? In any case, here too it is essential that the calls actually mutate.

@devmotion (Member):
Similar to the calls below.

@torfjelde (Member Author):
Same response as above: these samplers aren't compatible with an immutable `VarInfo` 😕

Comment on lines 54 to 55:

```julia
# Transform to unconstrained space.
DynamicPPL.link!.(vis, Ref(spl))
```
@devmotion (Member):
Why did you extract it? It seems there are no performance gains; rather, there are now both a broadcasting and a map operation, whereas before we just used a single map?

@torfjelde (Member Author):
I extracted it because at first I intended to make it compatible with immutable `AbstractVarInfo`, but then I realized this shouldn't be done in this PR 🙃 I'll revert this change, I guess 👍

@torfjelde (Member Author):
Thanks! I'll merge once tests pass.

@torfjelde merged commit 5c4b4d5 into master Dec 16, 2021
delete-merged-branch bot deleted the tor/dynamicppl-update branch December 16, 2021 19:39