Skip to content

Commit 19f9782

Browse files
committed
Fix for unwrap_left_right_vns (#297)
I just noticed a bug I introduced in a recent PR when looking at #295 . This PR fixes it. I'll add tests, a sec. @yebai
1 parent 5472d9d commit 19f9782

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.13.0"
3+
version = "0.13.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,35 @@ left-hand side of a `.~` expression such as `x .~ Normal()`.
9090
9191
This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
9292
variables.
93+
94+
# Example
95+
```jldoctest; setup=:(using Distributions)
96+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns
97+
2-element Vector{VarName{:x, Tuple{Tuple{Colon, Int64}}}}:
98+
x[:,1]
99+
x[:,2]
100+
101+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns
102+
1×2 Matrix{VarName{:x, Tuple{Tuple{Colon}, Tuple{Int64, Int64}}}}:
103+
x[:][1,1] x[:][1,2]
104+
105+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns
106+
3-element Vector{VarName{:x, Tuple{Tuple{Int64}, Tuple{Int64}}}}:
107+
x[1][1]
108+
x[1][2]
109+
x[1][3]
110+
111+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); vns
112+
1×2×3 Array{VarName{:x, Tuple{Tuple{Int64, Int64, Int64}}}, 3}:
113+
[:, :, 1] =
114+
x[1,1,1] x[1,2,1]
115+
116+
[:, :, 2] =
117+
x[1,1,2] x[1,2,2]
118+
119+
[:, :, 3] =
120+
x[1,1,3] x[1,2,3]
121+
```
93122
"""
94123
unwrap_right_left_vns(right, left, vns) = right, left, vns
95124
function unwrap_right_left_vns(right::NamedDist, left, vns)
@@ -103,7 +132,7 @@ function unwrap_right_left_vns(
103132
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
104133
# and we therefore add the `Colon()` below.
105134
vns = map(axes(left, 2)) do i
106-
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
135+
return VarName(vn, (vn.indexing..., (Colon(), i)))
107136
end
108137
return unwrap_right_left_vns(right, left, vns)
109138
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2020

2121
[compat]
2222
AbstractMCMC = "2.1, 3.0"
23-
AbstractPPL = "0.1.4, 0.2"
23+
AbstractPPL = "0.2"
2424
Bijectors = "0.9.5"
2525
Distributions = "< 0.25.11"
2626
DistributionsAD = "0.6.3"

0 commit comments

Comments
 (0)