@@ -355,10 +355,12 @@ end
355355
356356function generate_tilde_literal (left, right)
357357 # If the LHS is a literal, it is always an observation
358+ @gensym value
358359 return quote
359- $ (DynamicPPL. tilde_observe!)(
360+ $ value, __varinfo__ = $ (DynamicPPL. tilde_observe! !)(
360361 __context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
361362 )
363+ $ value
362364 end
363365end
364366
@@ -373,7 +375,7 @@ function generate_tilde(left, right)
373375
374376 # Otherwise it is determined by the model or its value,
375377 # if the LHS represents an observation
376- @gensym vn isassumption
378+ @gensym vn isassumption value
377379
378380 # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379381 # that in DynamicPPL we the entire function body. Instead we should be
@@ -389,32 +391,38 @@ function generate_tilde(left, right)
389391 $ left = $ (DynamicPPL. getvalue_nested)(__context__, $ vn)
390392 end
391393
392- $ (DynamicPPL. tilde_observe!)(
394+ $ value, __varinfo__ = $ (DynamicPPL. tilde_observe! !)(
393395 __context__,
394396 $ (DynamicPPL. check_tilde_rhs)($ right),
395397 $ (maybe_view (left)),
396398 $ vn,
397399 __varinfo__,
398400 )
401+ $ value
399402 end
400403 end
401404end
402405
403406function generate_tilde_assume (left, right, vn)
404- expr = :(
405- $ left = $ (DynamicPPL. tilde_assume!)(
407+ # HACK: Because the Setfield.jl macro does not support assignment
408+ # with multiple arguments on the LHS, we need to capture the return-values
409+ # and then update the LHS variables one by one.
410+ @gensym value
411+ expr = :($ left = $ value)
412+ if left isa Expr
413+ expr = AbstractPPL. drop_escape (
414+ Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415+ )
416+ end
417+
418+ return quote
419+ $ value, __varinfo__ = $ (DynamicPPL. tilde_assume!!)(
406420 __context__,
407421 $ (DynamicPPL. unwrap_right_vn)($ (DynamicPPL. check_tilde_rhs)($ right), $ vn). .. ,
408422 __varinfo__,
409423 )
410- )
411-
412- return if left isa Expr
413- AbstractPPL. drop_escape (
414- Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415- )
416- else
417- return expr
424+ $ expr
425+ $ value
418426 end
419427end
420428
@@ -428,7 +436,7 @@ function generate_dot_tilde(left, right)
428436
429437 # Otherwise it is determined by the model or its value,
430438 # if the LHS represents an observation
431- @gensym vn isassumption
439+ @gensym vn isassumption value
432440 return quote
433441 $ vn = $ (AbstractPPL. drop_escape (varname (left)))
434442 $ isassumption = $ (DynamicPPL. isassumption (left))
@@ -440,13 +448,14 @@ function generate_dot_tilde(left, right)
440448 $ left .= $ (DynamicPPL. getvalue_nested)(__context__, $ vn)
441449 end
442450
443- $ (DynamicPPL. dot_tilde_observe!)(
451+ $ value, __varinfo__ = $ (DynamicPPL. dot_tilde_observe! !)(
444452 __context__,
445453 $ (DynamicPPL. check_tilde_rhs)($ right),
446454 $ (maybe_view (left)),
447455 $ vn,
448456 __varinfo__,
449457 )
458+ $ value
450459 end
451460 end
452461end
@@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn)
455464 # We don't need to use `Setfield.@set` here since
456465 # `.=` is always going to be inplace + needs `left` to
457466 # be something that supports `.=`.
458- return :(
459- $ left .= $ (DynamicPPL. dot_tilde_assume!)(
467+ @gensym value
468+ return quote
469+ $ value, __varinfo__ = $ (DynamicPPL. dot_tilde_assume!!)(
460470 __context__,
461471 $ (DynamicPPL. unwrap_right_left_vns)(
462472 $ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
463473 ). .. ,
464474 __varinfo__,
465475 )
466- )
476+ $ left .= $ value
477+ $ value
478+ end
479+ end
480+
481+ # Note that we cannot use `MacroTools.isdef` because
482+ # of https://github.com/FluxML/MacroTools.jl/issues/154.
483+ """
484+ isfuncdef(expr)
485+
486+ Return `true` if `expr` is any form of function definition, and `false` otherwise.
487+ """
488+ function isfuncdef (e:: Expr )
489+ return if Meta. isexpr (e, :function )
490+ # Classic `function f(...)`
491+ true
492+ elseif Meta. isexpr (e, :-> )
493+ # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs.
494+ true
495+ elseif Meta. isexpr (e, :(= )) && Meta. isexpr (e. args[1 ], :call )
496+ # Short function defs, e.g. `f(args...) = ...`.
497+ true
498+ else
499+ false
500+ end
501+ end
502+
503+ """
504+ replace_returns(expr)
505+
506+ Return `Expr` with all `return ...` statements replaced with
507+ `return ..., DynamicPPL.return_values(__varinfo__)`.
508+
509+ Note that this method will _not_ replace `return` statements within function
510+ definitions. This is checked using [`isfuncdef`](@ref).
511+ """
512+ replace_returns (e) = e
513+ function replace_returns (e:: Expr )
514+ if isfuncdef (e)
515+ return e
516+ end
517+
518+ if Meta. isexpr (e, :return )
519+ # NOTE: `return` always has an argument. In the case of
520+ # an empty `return`, the lowered expression will be `return nothing`.
521+ # Hence we don't need any special handling for empty returns.
522+ retval_expr = if length (e. args) > 1
523+ Expr (:tuple , e. args... )
524+ else
525+ e. args[1 ]
526+ end
527+
528+ return :(return ($ retval_expr, __varinfo__))
529+ end
530+
531+ return Expr (e. head, map (replace_returns, e. args)... )
532+ end
533+
534+ # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`.
535+ make_returns_explicit! (body) = Expr (:return , body)
536+ function make_returns_explicit! (body:: Expr )
537+ # If the last statement is a return-statement, we don't do anything.
538+ # Otherwise we replace the last statement with a `return` statement.
539+ if ! Meta. isexpr (body. args[end ], :return )
540+ body. args[end ] = Expr (:return , body. args[end ])
541+ end
542+ return body
467543end
468544
469545const FloatOrArrayType = Type{<: Union{AbstractFloat,AbstractArray} }
@@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode)
496572 # Replace the user-provided function body with the version created by DynamicPPL.
497573 # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
498574 # that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
499- # to the call site
575+ # to the call site.
576+ # NOTE: We need to replace statements of the form `return ...` with
577+ # `return (..., __varinfo__)` to ensure that the second
578+ # element in the returned value is always the most up-to-date `__varinfo__`.
579+ # See the docstrings of `replace_returns` for more info.
500580 evaluatordef[:body ] = MacroTools. @q begin
501581 $ (linenumbernode)
502- $ (modelinfo[:body ])
582+ $ (replace_returns ( make_returns_explicit! ( modelinfo[:body ])) )
503583 end
504584
505585 # # Build the model function.
0 commit comments