Skip to content

Commit e241dca

Browse files
sunxd3yebai
andauthored
AbstractPPL@0.6 Compatibility Fix (#494)
* Bump `AbstractPPL` version. * Remove version 0.5. * Concretize `Colon`s * Relax `AbstractPPL` version * Correct test `Project.toml` * Formatting * Fix formatter mistake. * Apply formatting suggestions * Correct wrong dimension of `s` in test * use `@test_throws` instead of `@test` * use `Setfield.set` in `set!!` function, experiment * Revert "use `Setfield.set` in `set!!` function, experiment" This reverts commit d7cf4ca. * hacky `possible` with `ConcretizedSlice` * add some doc for `need * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent 0d5c463 commit e241dca

File tree

6 files changed

+49
-10
lines changed

6 files changed

+49
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.6"
3+
version = "0.23.7"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2222

2323
[compat]
2424
AbstractMCMC = "2, 3.0, 4"
25-
AbstractPPL = "0.5.3"
25+
AbstractPPL = "0.6"
2626
BangBang = "0.3"
2727
Bijectors = "0.13"
2828
ChainRulesCore = "0.9.7, 0.10, 1"

src/compiler.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
22

3+
"""
4+
need_concretize(expr)
5+
6+
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
7+
requires a dynamic lens.
8+
9+
# Examples
10+
11+
```jldoctest; setup=:(using Setfield)
12+
julia> DynamicPPL.need_concretize(:(x[1, :]))
13+
true
14+
15+
julia> DynamicPPL.need_concretize(:(x[1, end]))
16+
true
17+
18+
julia> DynamicPPL.need_concretize(:(x[1, 1]))
19+
false
20+
"""
21+
function need_concretize(expr)
22+
return Setfield.need_dynamic_lens(expr) || begin
23+
flag = false
24+
MacroTools.postwalk(expr) do ex
25+
# Concretise colon by default
26+
ex == :(:) && (flag = true) && return ex
27+
end
28+
flag
29+
end
30+
end
31+
332
"""
433
isassumption(expr[, vn])
534
@@ -16,10 +45,13 @@ When `expr` is not an expression or symbol (i.e., a literal), this expands to `f
1645
1746
If `vn` is specified, it will be assumed to refer to a expression which
1847
evaluates to a `VarName`, and this will be used in the subsequent checks.
19-
If `vn` is not specified, `AbstractPPL.drop_escape(varname(expr))` will be
48+
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
2049
used in its place.
2150
"""
22-
function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr)))
51+
function isassumption(
52+
expr::Union{Expr,Symbol},
53+
vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))),
54+
)
2355
return quote
2456
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
2557
# Considered an assumption by `__context__` which means either:
@@ -194,7 +226,7 @@ function unwrap_right_left_vns(
194226
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
195227
# and we therefore add the `Colon()` below.
196228
vns = map(axes(left, 2)) do i
197-
return vn Setfield.IndexLens((Colon(), i))
229+
return AbstractPPL.concretize(vn Setfield.IndexLens((Colon(), i)), left)
198230
end
199231
return unwrap_right_left_vns(right, left, vns)
200232
end
@@ -372,7 +404,7 @@ function generate_tilde(left, right)
372404
return quote
373405
$dist = $right
374406
$vn = $(DynamicPPL.resolve_varnames)(
375-
$(AbstractPPL.drop_escape(varname(left))), $dist
407+
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
376408
)
377409
$isassumption = $(DynamicPPL.isassumption(left, vn))
378410
if $(DynamicPPL.isfixed(left, vn))
@@ -433,7 +465,7 @@ function generate_dot_tilde(left, right)
433465
@gensym vn isassumption value
434466
return quote
435467
$vn = $(DynamicPPL.resolve_varnames)(
436-
$(AbstractPPL.drop_escape(varname(left))), $right
468+
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
437469
)
438470
$isassumption = $(DynamicPPL.isassumption(left, vn))
439471
if $(DynamicPPL.isfixed(left, vn))

src/test_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,8 @@ function logprior_true_with_logabsdet_jacobian(
543543
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
544544
end
545545
function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)})
546-
return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)]
546+
s = zeros(1, 2) # used for varname concretization only
547+
return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)]
547548
end
548549

549550
@model function demo_assume_matrix_dot_observe_matrix(

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,12 @@ function BangBang.possible(
500500
return BangBang.implements(setindex!, C) &&
501501
promote_type(eltype(C), eltype(T)) <: eltype(C)
502502
end
503+
function BangBang.possible(
504+
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer
505+
) where {C<:AbstractMatrix,T<:AbstractVector}
506+
return BangBang.implements(setindex!, C) &&
507+
promote_type(eltype(C), eltype(T)) <: eltype(C)
508+
end
503509

504510
# HACK(torfjelde): This makes it so it works on iterators, etc. by default.
505511
# TODO(torfjelde): Do better.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2222

2323
[compat]
2424
AbstractMCMC = "2.1, 3.0, 4"
25-
AbstractPPL = "0.5"
25+
AbstractPPL = "0.6"
2626
Bijectors = "0.13"
2727
Distributions = "0.25"
2828
DistributionsAD = "0.6.3"

test/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
@test inspace(@varname(z[1][:]), space)
5656
@test inspace(@varname(z[1][2:3:10]), space)
5757
@test inspace(@varname(M[[2, 3], 1]), space)
58-
@test inspace(@varname(M[:, 1:4]), space)
58+
@test_throws ErrorException inspace(@varname(M[:, 1:4]), space)
5959
@test inspace(@varname(M[1, [2, 4, 6]]), space)
6060
@test !inspace(@varname(z[2]), space)
6161
@test !inspace(@varname(z), space)

0 commit comments

Comments
 (0)