|  | 
| 565 | 565 |         end | 
| 566 | 566 |     end | 
| 567 | 567 | 
 | 
|  | 568 | +    @testset "PG with variable number of observations" begin | 
|  | 569 | +        # When sampling from a model with Particle Gibbs, it is mandatory for | 
|  | 570 | +        # the number of observations to be the same in all particles, since the | 
|  | 571 | +        # observations trigger particle resampling. | 
|  | 572 | +        # | 
|  | 573 | +        # Up until Turing v0.39, `x ~ dist` statements where `x` was the | 
|  | 574 | +        # responsibility of a different (non-PG) Gibbs subsampler used to not | 
|  | 575 | +        # count as an observation. Instead, the log-probability `logpdf(dist, x)` | 
|  | 576 | +        # would be manually added to the VarInfo's `logp` field and included in the | 
|  | 577 | +        # weighting for the _following_ observation. | 
|  | 578 | +        # | 
|  | 579 | +        # In Turing v0.40, this is now changed: `x ~ dist` uses tilde_observe!! | 
|  | 580 | +        # which thus triggers resampling. Thus, for example, the following model | 
|  | 581 | +        # does not work any more: | 
|  | 582 | +        # | 
|  | 583 | +        #   @model function f() | 
|  | 584 | +        #       a ~ Poisson(2.0) | 
|  | 585 | +        #       x = Vector{Float64}(undef, a) | 
|  | 586 | +        #       for i in eachindex(x) | 
|  | 587 | +        #           x[i] ~ Normal() | 
|  | 588 | +        #       end | 
|  | 589 | +        #   end | 
|  | 590 | +        #   sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000) | 
|  | 591 | +        #  | 
|  | 592 | +        # because the number of observations in each particle depends on the value | 
|  | 593 | +        # of `a`. | 
|  | 594 | +        # | 
|  | 595 | +        # This testset checks that ways of working around such a situation. | 
|  | 596 | + | 
|  | 597 | +        function test_dynamic_bernoulli(chain) | 
|  | 598 | +            means = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 2.0) | 
|  | 599 | +            stds = Dict(:b => 0.5, "x[1]" => 1.0, "x[2]" => 1.0) | 
|  | 600 | +            for vn in keys(means) | 
|  | 601 | +                @test isapprox(mean(skipmissing(chain[:, vn, 1])), means[vn]; atol=0.1) | 
|  | 602 | +                @test isapprox(std(skipmissing(chain[:, vn, 1])), stds[vn]; atol=0.1) | 
|  | 603 | +            end | 
|  | 604 | +        end | 
|  | 605 | + | 
|  | 606 | +        @testset "Coalescing multiple observations into one" begin | 
|  | 607 | +            # Instead of observing x[1] and x[2] separately, we lump them into a | 
|  | 608 | +            # single distribution. | 
|  | 609 | +            @model function dynamic_bernoulli() | 
|  | 610 | +                b ~ Bernoulli() | 
|  | 611 | +                if b | 
|  | 612 | +                    dists = [Normal(1.0)] | 
|  | 613 | +                else | 
|  | 614 | +                    dists = [Normal(1.0), Normal(2.0)] | 
|  | 615 | +                end | 
|  | 616 | +                return x ~ product_distribution(dists) | 
|  | 617 | +            end | 
|  | 618 | +            model = dynamic_bernoulli() | 
|  | 619 | +            # This currently fails because if the global varinfo has `x` with length 2, | 
|  | 620 | +            # and the particle sampler has `b = true`, it attempts to calculate the | 
|  | 621 | +            # log-likelihood of a length-2 vector with respect to a length-1 | 
|  | 622 | +            # distribution. | 
|  | 623 | +            @test_throws DimensionMismatch chain = sample( | 
|  | 624 | +                StableRNG(468), | 
|  | 625 | +                model, | 
|  | 626 | +                Gibbs(:b => PG(10), :x => ESS()), | 
|  | 627 | +                2000; | 
|  | 628 | +                discard_initial=100, | 
|  | 629 | +            ) | 
|  | 630 | +            # test_dynamic_bernoulli(chain) | 
|  | 631 | +        end | 
|  | 632 | + | 
|  | 633 | +        @testset "Inserting @addlogprob!" begin | 
|  | 634 | +            # On top of observing x[i], we also add in extra 'observations' | 
|  | 635 | +            @model function dynamic_bernoulli_2() | 
|  | 636 | +                b ~ Bernoulli() | 
|  | 637 | +                x_length = b ? 1 : 2 | 
|  | 638 | +                x = Vector{Float64}(undef, x_length) | 
|  | 639 | +                for i in 1:x_length | 
|  | 640 | +                    x[i] ~ Normal(i, 1.0) | 
|  | 641 | +                end | 
|  | 642 | +                if length(x) == 1 | 
|  | 643 | +                    # This value is the expectation value of logpdf(Normal(), x) where x ~ Normal(). | 
|  | 644 | +                    # See discussion in | 
|  | 645 | +                    # https://github.com/TuringLang/Turing.jl/pull/2629#discussion_r2237323817 | 
|  | 646 | +                    @addlogprob!(-1.418849) | 
|  | 647 | +                end | 
|  | 648 | +            end | 
|  | 649 | +            model = dynamic_bernoulli_2() | 
|  | 650 | +            chain = sample( | 
|  | 651 | +                StableRNG(468), | 
|  | 652 | +                model, | 
|  | 653 | +                Gibbs(:b => PG(10), :x => ESS()), | 
|  | 654 | +                2000; | 
|  | 655 | +                discard_initial=100, | 
|  | 656 | +            ) | 
|  | 657 | +            test_dynamic_bernoulli(chain) | 
|  | 658 | +        end | 
|  | 659 | +    end | 
|  | 660 | + | 
| 568 | 661 |     @testset "Demo model" begin | 
| 569 | 662 |         @testset verbose = true "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS | 
| 570 | 663 |             vns = DynamicPPL.TestUtils.varnames(model) | 
|  | 
0 commit comments