@@ -23,12 +23,13 @@ function _pobserve(expr::Expr)
2323 $ (process_tilde_statements (block))
2424 end
2525 end
26- total_likelihoods = sum (fetch .(likelihood_tasks))
26+ retvals_and_likelihoods = fetch .(likelihood_tasks)
27+ total_likelihoods = sum (last, retvals_and_likelihoods)
2728 # println("Total likelihoods: ", total_likelihoods)
2829 $ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
2930 $ (esc (:(__varinfo__))), total_likelihoods
3031 )
31- nothing
32+ map (first, retvals_and_likelihoods)
3233 end
3334 return return_expr
3435end
@@ -50,16 +51,34 @@ function process_tilde_statements(expr::Expr)
5051 @gensym loglike
5152 beginning_statement =
5253 :($ loglike = zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__))))))
53- transformed_statements = map (statements) do stmt
54- # skip non-tilde statements
55- # TODO : dot-tilde
56- @capture (stmt, lhs_ ~ rhs_) || return :($ (esc (stmt)))
57- # if the above matched, we transform the tilde statement
58- # TODO : We should probably perform some checks to make sure that this
59- # indeed was meant to be an observe statement.
60- :($ loglike += $ (Distributions. logpdf)($ (esc (rhs)), $ (esc (lhs))))
54+ n_statements = length (statements)
55+ transformed_statements:: Vector{Vector{Expr}} = map (enumerate (statements)) do (i, stmt)
56+ is_last = i == n_statements
57+ if @capture (stmt, lhs_ ~ rhs_)
58+ # TODO : We should probably perform some checks to make sure that this
59+ # indeed was meant to be an observe statement.
60+ @gensym left
61+ e = [
62+ :($ left = $ (esc (lhs))),
63+ :($ loglike += $ (Distributions. logpdf)($ (esc (rhs)), $ left)),
64+ ]
65+ is_last && push! (e, :(($ left, $ loglike)))
66+ e
67+ elseif @capture (stmt, lhs_ .~ rhs_)
68+ @gensym val
69+ e = [
70+ # TODO : dot-tilde
71+ :($ val = $ (esc (stmt))),
72+ ]
73+ is_last && push! (e, :(($ val, $ loglike)))
74+ e
75+ else
76+ @gensym val
77+ e = [:($ val = $ (esc (stmt)))]
78+ is_last && push! (e, :(($ val, $ loglike)))
79+ e
80+ end
6181 end
62- ending_statement = loglike
63- new_statements = [beginning_statement, transformed_statements... , ending_statement]
82+ new_statements = [beginning_statement, reduce (vcat, transformed_statements)... ]
6483 return Expr (:block , new_statements... )
6584end
0 commit comments