Skip to content

Commit 03e0b6f

Browse files
authored
Try #231:
2 parents d2678d5 + 7c1176b commit 03e0b6f

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/varinfo.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,11 @@ function link!(vi::UntypedVarInfo, spl::Sampler)
722722
end
723723
end
724724
function link!(vi::TypedVarInfo, spl::AbstractSampler)
725+
return link!(vi, spl, Val(getspace(spl)))
726+
end
727+
function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val)
725728
vns = _getvns(vi, spl)
726-
return _link!(vi.metadata, vi, vns, Val(getspace(spl)))
729+
return _link!(vi.metadata, vi, vns, spaceval)
727730
end
728731
@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
729732
expr = Expr(:block)
@@ -770,8 +773,11 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
770773
end
771774
end
772775
function invlink!(vi::TypedVarInfo, spl::AbstractSampler)
776+
return invlink!(vi, spl, Val(getspace(spl)))
777+
end
778+
function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val)
773779
vns = _getvns(vi, spl)
774-
return _invlink!(vi.metadata, vi, vns, Val(getspace(spl)))
780+
return _invlink!(vi.metadata, vi, vns, spaceval)
775781
end
776782
@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
777783
expr = Expr(:block)

test/turing/varinfo.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@
6767
@test all(x -> !istrans(vi, x), meta.m.vns)
6868
@test meta.s.vals == v_s
6969
@test meta.m.vals == v_m
70+
71+
# Transforming only a subset of the variables
72+
link!(vi, spl, Val((:m, )))
73+
@test all(x -> !istrans(vi, x), meta.s.vns)
74+
@test all(x -> istrans(vi, x), meta.m.vns)
75+
invlink!(vi, spl, Val((:m, )))
76+
@test all(x -> !istrans(vi, x), meta.s.vns)
77+
@test all(x -> !istrans(vi, x), meta.m.vns)
78+
@test meta.s.vals == v_s
79+
@test meta.m.vals == v_m
7080
end
7181
@testset "orders" begin
7282
csym = gensym() # unique per model
@@ -329,4 +339,4 @@
329339
@test vi.metadata.w.gids[1] == Set([hmc.selector])
330340
@test vi.metadata.u.gids[1] == Set([hmc.selector]) =#
331341
end
332-
end
342+
end

0 commit comments

Comments
 (0)