-
Notifications
You must be signed in to change notification settings - Fork 37
Fix prob macros #147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix prob macros #147
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -279,7 +279,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. | |||||||||||||||||||
|
|
||||||||||||||||||||
| The values may or may not be transformed to Euclidean space. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val | ||||||||||||||||||||
| setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| """ | ||||||||||||||||||||
| getval(vi::VarInfo, vns::Vector{<:VarName}) | ||||||||||||||||||||
|
|
@@ -1144,3 +1144,49 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) | |||||||||||||||||||
| setgid!(vi, spl.selector, vn) | ||||||||||||||||||||
| end | ||||||||||||||||||||
| end | ||||||||||||||||||||
|
|
||||||||||||||||||||
| setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x)) | ||||||||||||||||||||
| function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) | ||||||||||||||||||||
| return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) | ||||||||||||||||||||
| end | ||||||||||||||||||||
|
|
||||||||||||||||||||
| function _setval!(vi::AbstractVarInfo, values, keys) | ||||||||||||||||||||
| for vn in Base.keys(vi) | ||||||||||||||||||||
| _setval_kernel!(vi, vn, values, keys) | ||||||||||||||||||||
| end | ||||||||||||||||||||
| return vi | ||||||||||||||||||||
| end | ||||||||||||||||||||
| _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys) | ||||||||||||||||||||
| @generated function _typed_setval!( | ||||||||||||||||||||
| vi::TypedVarInfo, | ||||||||||||||||||||
| metadata::NamedTuple{names}, | ||||||||||||||||||||
| values, | ||||||||||||||||||||
| keys | ||||||||||||||||||||
| ) where {names} | ||||||||||||||||||||
| updates = map(names) do n | ||||||||||||||||||||
| quote | ||||||||||||||||||||
| for vn in metadata.$n.vns | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this always going to be just a single element now, which is the same as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK Lines 490 to 498 in 275ccc8
TypedVarInfo objects from untyped VarInfo objects (usually after running the model once).
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see. So if Thanks man:) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, exactly. |
||||||||||||||||||||
| _setval_kernel!(vi, vn, values, keys) | ||||||||||||||||||||
| end | ||||||||||||||||||||
| end | ||||||||||||||||||||
| end | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return quote | ||||||||||||||||||||
| $(updates...) | ||||||||||||||||||||
| return vi | ||||||||||||||||||||
| end | ||||||||||||||||||||
| end | ||||||||||||||||||||
|
|
||||||||||||||||||||
| function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) | ||||||||||||||||||||
| sym = Symbol(vn) | ||||||||||||||||||||
| regex = Regex("^$sym\$|^$sym\\[") | ||||||||||||||||||||
| indices = findall(x -> match(regex, string(x)) !== nothing, keys) | ||||||||||||||||||||
| if !isempty(indices) | ||||||||||||||||||||
| sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural) | ||||||||||||||||||||
| val = mapreduce(vcat, sorted_indices) do i | ||||||||||||||||||||
| values[i] | ||||||||||||||||||||
| end | ||||||||||||||||||||
| setval!(vi, val, vn) | ||||||||||||||||||||
| settrans!(vi, false, vn) | ||||||||||||||||||||
| end | ||||||||||||||||||||
| end | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think simple generated functions like this can be replaced by a map do-block on the named tuple directly. Last I checked the Julia compiler inferred and inlined it just fine with a do-block.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be done in many places in DPPL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's what I assumed as well but it's not true in general. At least in March I found the generated function to be much more efficient on a simple example (TuringLang/Turing.jl#1167 (comment)), and hence IMO one really has to benchmark every possible switch from generated function to regular
map(which is a bit unfortunate since I'd like to just use regular functions wherever possible...).I'll benchmark this example to see if we could get rid of the generated function here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably ok to keep as-is here, and perform refactoring to change all places in another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this one of those things where the compiler will "inline" everything up-to some fixed threshold? Same as with recursions. I remember looking into this stuff when working on Bijectors.jl, and there was essentially a fixed depth (I think ~20 or something) at which point the recursion (even though the methods were type-stable) wouldn't be unrolled. I bet it's the same here, where if
namesis sufficiently small thenmapwill be the same asgenerated, but ifnamesis large it won't since this can cause issues, e.g. a Turing model with millions of univariate parameters will probably not be too comfortable for the compiler. With that said, I agree with forcing "inlining" by the use ofgeneratedfunctions since if someone is running a model with millions of parameters, it's likely that you're still just looking at <30 different symbols.