
Issue 307: MvNormal removal and tuning #369

Merged
merged 11 commits into from
Jul 18, 2024

Conversation

seabbs
Collaborator

@seabbs seabbs commented Jul 12, 2024

This PR closes #307 by moving from MvNormal to filldist.
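As a rough sketch of the swap (not the actual EpiAware code), an iid product distribution built with filldist has the same log-density as the equivalent MvNormal, while sampling and differentiating as a collection of univariate draws:

```julia
using LinearAlgebra  # for the identity matrix I
using Distributions
using Turing  # re-exports `filldist` (from DistributionsAD)

n = 5
x = randn(n)

# A product of n iid standard normals built with filldist...
d_fill = filldist(Normal(0.0, 1.0), n)
# ...matches the log-density of the equivalent standard MvNormal.
d_mv = MvNormal(zeros(n), I)

logpdf(d_fill, x) ≈ logpdf(d_mv, x)  # true
```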

It also:

  • Makes the random walk avoid reallocation (aiming at Improve Zygote compatibility #339) and updates it to not add a noise term for the first value, bringing it into line with the AR implementation (we can talk about other options here, but I think this is a sensible step ahead of a new issue discussion).
  • Introduces a vectorised format for the AR process via new accumulate_scan tooling.

accumulate_scan is a candidate to replace scan everywhere, with this being done by rewriting EpiData to be a step function with a callable. That would probably need to be connected to a rewrite of the infection modules to make step processes composable. When reviewing, please think about whether this approach has limitations versus the current scan method, probably in carry flexibility (though I note the expectation here is that more complex methods would need to redefine their inits).
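The underlying idea can be illustrated with plain Base.accumulate (names here are illustrative, not the EpiAware API): the step function takes the carry and the next noise term, and accumulate threads the carry through for us.

```julia
# Hypothetical AR(1) step: damp the previous value and add noise.
ar_step(ρ) = (x, ϵ) -> ρ * x + ϵ

ρ = 0.7
ϵ_t = [0.5, -0.2, 0.1, 0.3]

# Base.accumulate returns the whole trajectory without manual
# preallocation or in-place mutation, which is what makes this
# pattern friendlier to Zygote.
X = accumulate(ar_step(ρ), ϵ_t; init = 0.0)
# X[1] == 0.5, X[2] ≈ 0.7 * 0.5 - 0.2 ≈ 0.15, ...
```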

This PR also serves as another test of the benchmarks to try and isolate the failure in #358.

@seabbs seabbs requested a review from SamuelBrand1 July 12, 2024 19:10
@seabbs seabbs changed the title add a working version of accumulate scan and remove mvnormals Issue 307: MvNormal removal and tuning Jul 12, 2024
@codecov-commenter

codecov-commenter commented Jul 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.10%. Comparing base (38e37f9) to head (4d4231e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #369      +/-   ##
==========================================
+ Coverage   93.07%   93.10%   +0.02%     
==========================================
  Files          50       51       +1     
  Lines         520      522       +2     
==========================================
+ Hits          484      486       +2     
  Misses         36       36              


@seabbs
Collaborator Author

seabbs commented Jul 14, 2024

I've added a custom context (PredictContext) and our own version of Turing.predict that uses it, plus a new submodel IDD that has a PredictContext-based switch.

This is very close to working, but I am slightly stuck on a type conversion issue: somewhere it sees ϵ as a filldist function and can't replace it with the array. I think this is in the in-place VarInfo updating, but I am quite unclear where it sees the function type (unless it's stored in the chain object?). Or maybe there is something else going on?

Note: I spent a rather depressing amount of time wandering through the Turing issues from 2020 here, and have started to really wish that submodel just overloaded ~, as then we could have made this update with much less boilerplate.

Test code:

using EpiAware, Turing, Distributions

idd = IDD(Normal(0, 1))
idd_model = generate_latent(idd, 10)

samples = sample(idd_model, Prior(), 1000)

preds_turing = Turing.predict(generate_latent(idd, 20), samples)
preds = EpiAware.EpiAwareUtils.predict(generate_latent(idd, 20), samples)
[ Info: Getting types
[ Info: Predicting
[ Info: DynamicPPL.TypedVarInfo{@NamedTuple{ϵ_t::DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}((ϵ_t = DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}(Dict{AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}, Int64}(ϵ_t[9] => 9, ϵ_t[4] => 4, ϵ_t[6] => 6, ϵ_t[3] => 3, ϵ_t[17] => 17, ϵ_t[18] => 18, ϵ_t[19] => 19, ϵ_t[5] => 5, ϵ_t[11] => 11, ϵ_t[10] => 10, ϵ_t[13] => 13, ϵ_t[8] => 8, ϵ_t[2] => 2, ϵ_t[20] => 20, ϵ_t[1] => 1, ϵ_t[15] => 15, ϵ_t[16] => 16, ϵ_t[12] => 12, ϵ_t[7] => 7, ϵ_t[14] => 14), AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{Tuple{Int64}}}[ϵ_t[1], ϵ_t[2], ϵ_t[3], ϵ_t[4], ϵ_t[5], ϵ_t[6], ϵ_t[7], ϵ_t[8], ϵ_t[9], ϵ_t[10], ϵ_t[11], ϵ_t[12], ϵ_t[13], ϵ_t[14], ϵ_t[15], ϵ_t[16], ϵ_t[17], ϵ_t[18], ϵ_t[19], ϵ_t[20]], UnitRange{Int64}[1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 11:11, 12:12, 13:13, 14:14, 15:15, 16:16, 17:17, 18:18, 19:19, 20:20], [-0.7082101677945046, 0.5491115966525051, 2.3842630466986257, -2.1310827823622476, 1.4589684465556265, 1.128274758194063, -0.9437569622288362, 0.0670683933861612, -0.6281275543289423, 0.40844270897271584, 0.4500941947995636, -0.02053489398433327, -0.9654168977655161, 0.8234408488431235, -1.9024840086670902, -0.07285840907253294, -0.08607929613724663, 1.2735247017938127, 0.6520883239494453, 0.796122359493298], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), 
Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0)], Set{DynamicPPL.Selector}[Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set(), Set()], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], Dict{String, BitVector}("del" => [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "trans" => [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),), Base.RefValue{Float64}(-30.385569079389658), Base.RefValue{Int64}(0))
[ Info: 1
[ Info: 1
[ Info: Reseting value
ERROR: MethodError: no method matching subsumes(::Accessors.IndexLens{Tuple{Int64}}, ::ComposedFunction{Accessors.IndexLens{Tuple{Int64}}, typeof(identity)})

Closest candidates are:
  subsumes(::typeof(identity), ::Union{typeof(identity), Accessors.IndexLens, Accessors.PropertyLens, ComposedFunction})
   @ AbstractPPL ~/.julia/packages/AbstractPPL/kb4q5/src/varname.jl:296
  subsumes(::Union{typeof(identity), Accessors.IndexLens, Accessors.PropertyLens, ComposedFunction}, ::typeof(identity))
   @ AbstractPPL ~/.julia/packages/AbstractPPL/kb4q5/src/varname.jl:297
  subsumes(::Accessors.PropertyLens, ::ComposedFunction)
   @ AbstractPPL ~/.julia/packages/AbstractPPL/kb4q5/src/varname.jl:308
  ...

Stacktrace:
  [1] subsumes(u::AbstractPPL.VarName{:ϵ_t, Accessors.IndexLens{…}}, v::AbstractPPL.VarName{:ϵ_t, ComposedFunction{…}})
    @ AbstractPPL ~/.julia/packages/AbstractPPL/kb4q5/src/varname.jl:287
  [2] (::Base.Fix2{…})(y::AbstractPPL.VarName{…})
    @ Base ./operators.jl:1135
  [3] findnext
    @ ./array.jl:2155 [inlined]
  [4] findfirst
    @ ./array.jl:2206 [inlined]
  [5] _nested_setindex_maybe!
    @ ~/.julia/packages/DynamicPPL/i2EbF/src/varinfo.jl:1398 [inlined]
  [6] nested_setindex_maybe!(vi::DynamicPPL.TypedVarInfo{…}, val::Float64, vn::AbstractPPL.VarName{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/i2EbF/src/varinfo.jl:1384
  [7] setval_and_resample!(vi::DynamicPPL.TypedVarInfo{…}, chains::Chains{…}, sample_idx::Int64, chain_idx::Int64)
    @ DynamicPPL ~/.julia/packages/DynamicPPL/i2EbF/src/varinfo.jl:1952
  [8] (::EpiAware.EpiAwareUtils.var"#32#33"{})(::Tuple{…})
    @ EpiAware.EpiAwareUtils ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:201
  [9] iterate
    @ ./generator.jl:47 [inlined]
 [10] collect(itr::Base.Generator{Base.Iterators.ProductIterator{…}, EpiAware.EpiAwareUtils.var"#32#33"{…}})
    @ Base ./array.jl:834
 [11] map(f::Function, A::Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
    @ Base ./abstractarray.jl:3313
 [12] transitions_from_chain(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, chain::Chains{…}; sampler::DynamicPPL.SampleFromPrior, context::PredictContext{…})
    @ EpiAware.EpiAwareUtils ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:196
 [13] predict(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, chain::Chains{…}; include_all::Bool)
    @ EpiAware.EpiAwareUtils ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:104
 [14] predict
    @ ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:96 [inlined]
 [15] #predict#24
    @ ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:94 [inlined]
 [16] predict(model::DynamicPPL.Model{…}, chain::Chains{…})
    @ EpiAware.EpiAwareUtils ~/code/Rt-without-renewal/EpiAware/src/EpiAwareUtils/predict.jl:93
 [17] top-level scope
    @ ~/code/Rt-without-renewal/EpiAware/src/EpiLatentModels/models/test.jl:9
Some type information was truncated. Use `show(err)` to see complete types.

@SamuelBrand1

This comment was marked as outdated.

@SamuelBrand1

This comment was marked as outdated.

@seabbs

This comment was marked as outdated.

@seabbs
Collaborator Author

seabbs commented Jul 15, 2024

I've localised the problem to setval_and_resample! and in particular the findfirst call (https://github.com/TuringLang/DynamicPPL.jl/blob/123d7bf129ff75cd783402fe39e08f807c1fcb23/src/varinfo.jl#L1398), but I am not sure what I can actually do about this, as it's still unclear to me which part of the chains object (I think it has to be this) is causing the problem (well, obviously it is ϵ_t, but more what to do about it).

@seabbs
Collaborator Author

seabbs commented Jul 15, 2024

Any insights @SamuelBrand1 ?

@SamuelBrand1
Collaborator

...well obviously it is \epsilon_t but more what to do) of the chains object (I think it has to be this) is causing the problem.

Agreed. I'm looking at this, it seems odd to me that this doesn't work.

@seabbs
Collaborator Author

seabbs commented Jul 16, 2024

Plan for this PR is to pull out the predict changes into their own PR and then review/merge this and deal with the predict issues there.

@SamuelBrand1
Collaborator

Ongoing review: I like using Base.accumulate because using Base functions increases the chance that custom adjoints already exist in any AD system we might use, and they are maintained/developed by the Julia core devs. I'm in favour of swapping to accumulate_scan over our hand-written scan.

We should note that by default accumulate_scan works with stateful iterators where we want to return the last element; it can be made just as flexible as scan, but that requires writing a custom get_state method.
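The get_state point can be sketched as follows (all names here are hypothetical, not the actual EpiAware interface): carry whatever extra state is needed through Base.accumulate, then project out the part to return afterwards.

```julia
# Carry a tuple (value, running_max) through the scan; by default we
# only want the value, so a get_state-style accessor strips the rest.
step((x, m), ϵ) = begin
    x′ = 0.7 * x + ϵ          # illustrative AR(1)-style update
    (x′, max(m, x′))          # new carry keeps the running maximum too
end
get_state(carry) = first(carry)

carries = accumulate(step, [0.5, -0.2, 0.1]; init = (0.0, -Inf))
X = get_state.(carries)       # the trajectory
peak = last(carries)[2]       # the extra carried state
```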

@seabbs seabbs requested a review from SamuelBrand1 July 17, 2024 17:28
@seabbs
Collaborator Author

seabbs commented Jul 17, 2024

I have pulled the changes to predict into their own PR, and this is now ready to review. I guess once this is merged we need some issues to cover removing scan (those are probably tied into improving the infection models).

@seabbs seabbs enabled auto-merge July 17, 2024 21:12
@seabbs
Collaborator Author

seabbs commented Jul 17, 2024

It would be nice to wait for benchmarks (and we can see if #381 helps first), but I think we should aim to merge this ASAP and fix in post. I think it's very unlikely any changes here reduce performance (we can check this in detail when looking to swap out scan).

Collaborator

@SamuelBrand1 SamuelBrand1 left a comment


LGTM. I was happy with this side of the PR and remain so.

@seabbs seabbs added this pull request to the merge queue Jul 18, 2024
Merged via the queue into main with commit 6644ace Jul 18, 2024
10 of 11 checks passed
@seabbs seabbs deleted the optimise-mvnormal-scan branch July 18, 2024 09:21
Development

Successfully merging this pull request may close these issues.

Move to all Turing models in EpiAware to be non-vectorised
3 participants