Skip to content

Commit ad2b1ef

Browse files
authored
Try #261:
2 parents 9083299 + b993b97 commit ad2b1ef

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.11.2"
3+
version = "0.11.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ end
3636
# failsafe: a literal is never an assumption
3737
isassumption(expr) = :(false)
3838

39+
"""
40+
isliteral(expr)
41+
42+
Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise.
43+
"""
44+
isliteral(e) = false
45+
isliteral(::Number) = true
46+
isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)
47+
3948
"""
4049
check_tilde_rhs(x)
4150
@@ -240,7 +249,7 @@ variables.
240249
"""
241250
function generate_tilde(left, right)
242251
# If the LHS is a literal, it is always an observation
243-
if !(left isa Symbol || left isa Expr)
252+
if isliteral(left)
244253
return quote
245254
$(DynamicPPL.tilde_observe)(
246255
__context__,
@@ -290,7 +299,7 @@ Generate the expression that replaces `left .~ right` in the model body.
290299
"""
291300
function generate_dot_tilde(left, right)
292301
# If the LHS is a literal, it is always an observation
293-
if !(left isa Symbol || left isa Expr)
302+
if isliteral(left)
294303
return quote
295304
$(DynamicPPL.dot_tilde_observe)(
296305
__context__,

test/compiler.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,4 +423,21 @@ end
423423
x = [Laplace(), Normal(), MvNormal(3, 1.0)]
424424
@test DynamicPPL.check_tilde_rhs(x) === x
425425
end
426+
427+
@testset "array literals" begin
428+
# Verify that we indeed can parse this.
429+
@test @model(function array_literal_model()
430+
# `assume` and literal `observe`
431+
m ~ MvNormal(2, 1.0)
432+
return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
433+
end) isa Function
434+
435+
@model function array_literal_model()
436+
# `assume` and literal `observe`
437+
m ~ MvNormal(2, 1.0)
438+
return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
439+
end
440+
441+
@test array_literal_model()() == [10.0, 10.0]
442+
end
426443
end

0 commit comments

Comments
 (0)