
Sampling issue when adding variational chains #290

Closed
sefffal opened this issue Oct 21, 2024 · 40 comments · Fixed by #292


@sefffal commented Oct 21, 2024

Hello all,
I have been seeing occasional issues when adding a variational reference. It seems to be quite sensitive to the model and the number of chains used. I have finally found an MWE that consistently reproduces the effect.

When looking at the corner plots below, note that the second column plots the log posterior density. The second example includes samples that appear to lie far outside the typical set.

I will include the model below.

Without Variational Reference

First, a sampling run with 10 chains and 0 variational chains:
[corner plot images]

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.29     0.0133      6e+04       -111   5.57e-14      0.701          1          1 
        4          0       4.27     0.0184   8.76e+04       -164   1.68e-52      0.612          1          1 
        8          0          5     0.0347   1.48e+05       -103     0.0434      0.545          1          1 
       16          0       5.53     0.0665   2.62e+05       -106     0.0758      0.497          1          1 
       32          0       5.39      0.135   4.53e+05       -108      0.305       0.51          1          1 
       64          6       2.19      0.336   3.16e+07       -109      0.584      0.801          1          1 
      128         16       2.34      0.619   6.34e+07       -109      0.649      0.787          1          1 
      256         33       2.34       1.59   1.26e+08       -109      0.727      0.787          1          1 
      512         77       2.28       2.59   2.51e+08       -109      0.753      0.792          1          1 
 1.02e+03        120       2.38       5.15   5.04e+08       -109      0.754      0.784          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

With Variational Reference

Now, a sampling run with 10 chains and 6 variational chains:
[corner plot images]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       2.48       2.74     0.0178    9.4e+04       -128   1.61e-31      0.693          1          1 
        4          0       4.47       2.75     0.0289   1.27e+05       -181   2.87e-66      0.575          1          1 
        8          0       5.04       3.49     0.0552   2.07e+05       -109   0.000299      0.499          1          1 
       16          0          6        3.5      0.109   3.68e+05       -113   4.64e-05      0.441          1          1 
       32          0       5.59       4.27      0.223   6.83e+05       -111     0.0311       0.42          1          1 
       64          4       5.67       3.84      0.447   1.36e+07       -108     0.0224      0.441          1          1 
      128          9       5.59       4.17      0.903   2.74e+07       -111      0.062      0.426          1          1 
      256         19        5.7       4.13        1.9   5.52e+07       -108     0.0407      0.422          1          1 
      512         37       5.76       4.03       3.76   1.11e+08       -109     0.0803      0.424          1          1 
 1.02e+03         67        5.8       4.02       7.84   2.23e+08       -109      0.144      0.422          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

HMC, for reference

[corner plot images]

Model Code

To reproduce this example, use the latest `#main` commit of Octofitter and OctofitterRadialVelocity (e.g. `] add Octofitter#main OctofitterRadialVelocity#main`) and then run the following:

using Octofitter
using OctofitterRadialVelocity
using Distributions
using PlanetOrbits


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)

rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)


first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)

    # Remove these, we don't need 'em
    # i ~ Sine()
    # Ω ~ UniformCircular()
    ω ~ UniformCircular()
    θ ~ UniformCircular()

    τ ~ UniformCircular(1.0)
    P = √(b.a^3/system.M) # period in years (Kepler's third law)
    tp =  b.τ*b.P*365.25 + $first_epoch_for_tp_ref # reference epoch for τ. Choose to be near data
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

using Random
rng = Xoshiro(0)


results, pt = octofit_pigeons(model, n_rounds=10, explorer=SliceSampler(), n_chains=12, n_chains_variational=0)

Plotting:

using PairPlots, CairoMakie
octocorner(model, results, small=false, includecols=(:iter,:logpost,:b_ωx,:b_ωy,:b_τx,:b_τy), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))

Additional notes: the columns ω, τ, and P are computed from other variables and not sampled directly from the model.

@sefffal (Author) commented Oct 21, 2024

One complication is that the variables ω, τ, and θ are computed from their counterparts :θx, :θy and so on using atan(θy, θx).

In the likelihood, I place a constraint sqrt(θx^2 + θy^2) ~ Normal(1, 0.1) on the radius in order to prevent numerical issues when θx and θy are both approximately 0.

So of course that sqrt(θx^2 + θy^2) ~ Normal(1, 0.1) constraint is not applied in the reference models.
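For reference, the reparameterization described above can be sketched as follows (a minimal illustration of the idea, not Octofitter's exact implementation; the variable names follow the comment above):

```julia
using Distributions

# An angle θ is sampled via two unconstrained coordinates (θx, θy):
# θ = atan(θy, θx) is uniform on (-π, π] when (θx, θy) is isotropic.
θx, θy = randn(), randn()
θ = atan(θy, θx)

# A soft constraint on the radius keeps (θx, θy) away from the origin,
# where atan(θy, θx) becomes numerically ill-conditioned:
r = sqrt(θx^2 + θy^2)
logprior_radius = logpdf(Normal(1, 0.1), r)
```

A Gaussian variational reference fit to (θx, θy) does not include this radius term, which is the asymmetry noted above.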

@alexandrebouchard (Member) commented:

Thanks William! Will discuss in more detail in our meeting today, but quick question: is the second type of plot showing goodness of fit comparing data to prediction?

@sefffal (Author) commented Oct 21, 2024

Prior only model, sampled with HMC:

[corner plot image]

@sefffal (Author) commented Oct 21, 2024

Variational Chains Only

Here is a corner plot of a sampling run with 6 variational chains and no fixed reference.

[corner plot image]

As you can see, the problem goes away! Very interesting.

@miguelbiron (Collaborator) commented:

Thanks for running that @sefffal ! Definitely helps us narrow down the issue further.

@miguelbiron (Collaborator) commented Oct 21, 2024

Btw how did you solve the Chains/sample_array/get_sample issue you mentioned with plain Variational PT?

@sefffal (Author) commented Oct 21, 2024

I just pushed a fix to the Octofitter / Pigeons integration: sefffal/Octofitter.jl@49a77c1

@sefffal (Author) commented Oct 21, 2024

Here is a further-reduced example that removes those angular "UniformCircular" parameters.

I note also that the balance between the number of regular chains and variational chains seems to be important.

using Octofitter
using OctofitterRadialVelocity
using CairoMakie
using PairPlots
using Distributions
using PlanetOrbits


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)
# Makie.lines(orb_template)


rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)

first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)
    ω ~ Uniform(0,2pi)
    τ ~ Uniform(0.0, 1.0)
    tp =  b.τ*√(b.a^3/system.M)*365.25 + $first_epoch_for_tp_ref 
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

results, pt = octofit_pigeons(model, n_rounds=10, explorer=SliceSampler(), n_chains=10, n_chains_variational=6)

octocorner(model, results, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
[corner plot image]

@sefffal (Author) commented Oct 21, 2024

Re:

Variational Chains Only

Here is a corner plot of a sampling run with 6 variational chains and no fixed reference.
...
As you can see, the problem goes away! Very interesting.

Actually, I am not sure this example is right. I get exactly the same values with n_chains=8, n_chains_variational=0 as I do with n_chains=0, n_chains_variational=8. How can I be sure this was using the variational reference at all?

@miguelbiron (Collaborator) commented Oct 21, 2024

You can tell the variational reference is being used because Λ changes at the 5th round. For example:

Traditional PT

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
)
results, pt = octofit_pigeons(inp)

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.07     0.0478   5.42e+05       -928          0      0.659          1          1 
        4          0       3.87     0.0428   5.46e+05       -122   5.95e-17       0.57          1          1 
        8          0       4.79     0.0493    8.3e+04       -120   3.94e-12      0.468          1          1 
       16          0       4.13     0.0969   1.53e+05       -107       0.25      0.541          1          1 
       32          0       4.79      0.184   2.58e+05       -105      0.277      0.468          1          1 
       64          0       5.15      0.358   4.72e+05       -106      0.314      0.428          1          1 
      128          1       4.76      0.746    8.8e+05       -108      0.171      0.471          1          1 
      256          5       4.93       1.48    1.7e+06       -108      0.369      0.452          1          1 
      512          9       5.05       3.01   3.25e+06       -108      0.362      0.439          1          1 
 1.02e+03         14       5.05       6.03   6.43e+06       -108      0.285      0.439          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

Variational (non-stabilized) PT

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    variational = GaussianReference(first_tuning_round = 5)
)
results, pt = octofit_pigeons(inp)

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.07      0.333    1.7e+07       -105   1.71e-08      0.659          1          1 
        4          0       3.87     0.0486   5.48e+05       -113   9.57e-09       0.57          1          1 
        8          0       4.79     0.0525   8.86e+04       -152   1.56e-40      0.468          1          1 
       16          0       4.12     0.0944   1.63e+05       -105      0.361      0.542          1          1 
       32          0       4.72      0.182   2.76e+05       -108      0.185      0.476          1          1 
       64          6       1.82      0.674   2.53e+07       -108      0.623      0.797          1          1 
      128         17       1.82      0.871    3.5e+07       -108      0.698      0.798          1          1 
      256         32       1.75       1.63   6.97e+07       -108      0.757      0.805          1          1 
      512         72       1.77       3.43    1.4e+08       -108      0.687      0.803          1          1 
 1.02e+03        121       2.17       6.81   2.77e+08       -108       0.71      0.759          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

Stabilized PT: note that Λ is roughly the same as for standard PT, whereas Λ_var is similar to the final value of (non-stabilized) variational PT.

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
results, pt = octofit_pigeons(inp)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.54       3.07      0.101   1.12e+06       -105  7.61e-163      0.652          1          1 
        4          0       3.89       3.79     0.0443   1.12e+05       -113   9.57e-09      0.596          1          1 
        8          0       3.71        5.1      0.102   1.89e+05       -149   1.21e-39      0.536          1          1 
       16          0       4.07       4.57      0.172   3.25e+05       -108      0.055      0.545          1          1 
       32          0       4.51        4.3      0.335   5.55e+05       -107       0.26      0.536          1          1 
       64         12       4.85       1.72      0.872   1.97e+07       -107      0.288      0.654          1          1 
      128         18       4.94        1.9        1.5   3.57e+07       -108      0.368       0.64          1          1 
      256         37       5.02       1.82       3.08    7.1e+07       -108      0.144       0.64          1          1 
      512         71       4.98       2.14       6.13    1.4e+08       -108      0.361      0.625          1          1 
 1.02e+03        164          5       1.89       12.5   2.82e+08       -108      0.402      0.637          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

@sefffal (Author) commented Oct 21, 2024

Ah, okay thanks @miguelbiron. Then I am a bit confused about the interaction between n_chains_variational and n_chains.

Is the following correct?
Regular PT: n_chains=8, n_chains_variational=0, variational=nothing
Stabilized Variational PT: n_chains=8, n_chains_variational=8, variational=GaussianReference....
Non-stabilized variational PT: n_chains=8, n_chains_variational=0, variational=GaussianReference....
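In Pigeons' `Inputs` API, these three configurations look like the following (a condensed sketch of the examples earlier in this thread; other keyword arguments omitted, and `model` assumed defined as above):

```julia
# Regular PT: a single leg with a fixed reference.
inp_pt = Pigeons.Inputs(; target = model, n_chains = 8)

# Non-stabilized variational PT: the single leg's reference is learned.
inp_vpt = Pigeons.Inputs(; target = model, n_chains = 8,
    variational = GaussianReference(first_tuning_round = 5))

# Stabilized variational PT: two legs, one with a fixed reference
# and one with a variational reference.
inp_stab_vpt = Pigeons.Inputs(; target = model, n_chains = 8,
    n_chains_variational = 8,
    variational = GaussianReference(first_tuning_round = 5))
```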

@miguelbiron (Collaborator) commented:

Correct!

@alexandrebouchard (Member) commented Oct 21, 2024

I agree it's pretty confusing; we should improve the naming at some point :(

@miguelbiron (Collaborator) commented:

Ok so I'm not getting any issues whatsoever. First, as you can see in my post, the log-normalization constant is the same across the three alternatives. Second, the following plots are very similar:

PT: [image pt.png]

Variational PT: [image vpt.png]

Stabilized PT: [image stab_vpt.png]

@miguelbiron (Collaborator) commented:

Btw, I'm using Julia 1.10.5 because 1.11.1 is not working right now.

julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 4 default, 0 interactive, 2 GC (on 8 virtual cores)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_DEBUG = 
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

@sefffal (Author) commented Oct 21, 2024

Oh, @miguelbiron my experiments were with Julia 1.11. I know there is an Enzyme failure with 1.11 but are there other issues too?

@miguelbiron (Collaborator) commented Oct 21, 2024

I'm not sure, but 1.11 has been a bumpy ride so I decided to wait until it's on firmer ground. I'm gonna rerun with 1.11 to check if it makes any difference.

@miguelbiron (Collaborator) commented Oct 21, 2024

Hmm, I can't run this line on 1.11.1:

julia> Octofitter._kepsolve_use_threads[] = false
ERROR: UndefVarError: `_kepsolve_use_threads` not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:31
 [2] top-level scope
   @ ~/projects/Pigeons.jl/test/temp.jl:10

@sefffal (Author) commented Oct 21, 2024

I was able to reproduce the problem on Julia 1.10.

@miguelbiron are you using the latest commit from Octofitter?

julia> versioninfo()
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)

@miguelbiron (Collaborator) commented:

Yes, #main like you suggested. Gonna go back to 1.10 then.

@sefffal (Author) commented Oct 21, 2024

I went through this again.

Blue is stabilized variational PT, gold is regular PT, and green is non-stabilized variational PT.

Is it possible the problem is caused by how I'm using get_sample(pt, 0)? In some situations, could I be getting samples from the wrong chain? IMO that could explain why varying the number of each kind of chain seems to interpolate between a prior-like distribution and the target.

[corner plot image comparing the three runs]
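One way to sanity-check the extraction (a sketch, assuming a completed `pt` run as above): Pigeons reports which chain indices target the posterior, so the traces can be pulled from a confirmed target chain rather than a hard-coded index.

```julia
# Confirm which chain indices actually target the posterior.
tchains = Pigeons.target_chains(pt)  # indices of chains annealed all the way to the target
@show tchains

# Extract samples from a confirmed target chain instead of a fixed index.
samples = get_sample(pt, first(tchains))
```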

@miguelbiron (Collaborator) commented:

That makes sense... Why are you doing get_sample(pt, 0) again?

Btw, I reran and cannot reproduce for the life of me. This is what I'm running (from Pigeons):

Details

include("activate_test_env.jl")

using Pkg
Pkg.add([PackageSpec(name="Octofitter", rev="main"), PackageSpec(name="OctofitterRadialVelocity", rev="main")])

using Octofitter
using OctofitterRadialVelocity
using CairoMakie
using PairPlots
using Distributions
using PlanetOrbits

Octofitter._kepsolve_use_threads[] = false


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)
# Makie.lines(orb_template)


rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)

first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)
    ω ~ Uniform(0,2pi)
    τ ~ Uniform(0.0, 1.0)
    tp =  b.τ*√(b.a^3/system.M)*365.25 + $first_epoch_for_tp_ref 
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

inp_pt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    # n_chains_variational=10,
    # variational = GaussianReference(first_tuning_round = 5)
)
inp_vpt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    # n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
inp_stab_vpt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
results_pt, pt = octofit_pigeons(inp_pt)
results_vpt, pt = octofit_pigeons(inp_vpt)
results_stab_vpt, pt = octofit_pigeons(inp_stab_vpt)

p = octocorner(model, results_pt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("pt.png", p)
p = octocorner(model, results_vpt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("vpt.png", p)
p = octocorner(model, results_stab_vpt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("stab_vpt.png", p)

versioninfo()
Pkg.status()

Here's the relevant output (excluding plots, since the logZ values are already correct):

(...)
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000014 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 7.09e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000009 seconds
[ Info: Determining initial positions and metric using pathfinder
┌ Info: Found a sample of initial positions
└   initial_logpost_range = (-108.8666220931488, -97.86911327709223)
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          3      0.048   5.42e+05       -443  7.64e-304      0.667          1          1 
        4          0       3.83     0.0469   5.49e+05       -102     0.0871      0.574          1          1 
        8          0       4.53     0.0466   8.86e+04       -107   4.03e-06      0.497          1          1 
       16          0       4.47     0.0983   1.59e+05       -109      0.138      0.503          1          1 
       32          0       4.54      0.187   2.56e+05       -107      0.294      0.495          1          1 
       64          0        4.7      0.365   4.84e+05       -107       0.36      0.478          1          1 
      128          2       4.97       0.77   8.79e+05       -108      0.301      0.448          1          1 
      256          2       4.87       1.56    1.7e+06       -107      0.274      0.459          1          1 
      512          4       4.97          3   3.24e+06       -107      0.394      0.448          1          1 
 1.02e+03         16       4.97       6.11   6.43e+06       -108      0.402      0.447          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000001 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 8.39e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000001 seconds
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          3      0.333    1.7e+07       -104   1.87e-09      0.667          1          1 
        4          0       4.21     0.0442   5.47e+05       -113   2.57e-10      0.532          1          1 
        8          0       4.03     0.0464   9.75e+04       -156   5.41e-41      0.552          1          1 
       16          0       5.09      0.113   1.53e+05       -117      0.126      0.434          1          1 
       32          0       4.34      0.201   2.68e+05       -108      0.117      0.517          1          1 
       64          9       1.84      0.622    2.5e+07       -108      0.689      0.795          1          1 
      128         20       1.64      0.907   3.49e+07       -108      0.753      0.817          1          1 
      256         37       1.66       1.72      7e+07       -108      0.757      0.816          1          1 
      512         69       1.83       3.49    1.4e+08       -108      0.717      0.796          1          1 
 1.02e+03        128       1.95       7.02   2.79e+08       -108      0.701      0.783          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000004 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 7.98e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000002 seconds
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000003 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 8.08e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000002 seconds
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.38          3     0.0909   1.12e+06       -104  1.59e-164      0.664          1          1 
        4          0       3.99       4.01     0.0455   1.12e+05       -111   2.57e-10      0.579          1          1 
        8          0       3.89       3.55     0.0893   1.96e+05       -147    1.5e-39      0.608          1          1 
       16          0       4.71       3.99      0.174   3.23e+05       -104    0.00798      0.542          1          1 
       32          0       4.72       4.61      0.362   5.42e+05       -117    0.00419      0.509          1          1 
       64          9       4.62        1.8      0.917   1.97e+07      -97.2     0.0644      0.662          1          1 
      128         14       4.76       2.24       1.58   3.54e+07       -108      0.342      0.632          1          1 
      256         37       4.87       2.03       3.04   7.08e+07       -108      0.229      0.637          1          1 
      512         57       4.98       2.33       6.23    1.4e+08       -108      0.354      0.615          1          1 
 1.02e+03        157       5.04       2.13       12.1    2.8e+08       -108      0.371      0.623          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
CairoMakie.Screen{IMAGE}


julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 4 default, 0 interactive, 2 GC (on 8 virtual cores)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_DEBUG = 
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

julia> Pkg.status()
Status `~/projects/Pigeons.jl/test/Project.toml`
  [0bf59076] AdvancedHMC v0.6.2
  [dbc42088] ArgMacros v0.2.4
  [76274a88] Bijectors v0.13.18
  [c88b6f0a] BridgeStan v2.5.0
  [13f3f980] CairoMakie v0.12.14
⌃ [99d987ce] Comrade v0.10.5
  [8bb1440f] DelimitedFiles v1.9.1
  [31c24e10] Distributions v0.25.112
  [ced4e74d] DistributionsAD v0.6.57
⌃ [366bfd00] DynamicPPL v0.28.6
⌅ [7da242da] Enzyme v0.12.36
  [7a1cc6ca] FFTW v1.8.0
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v0.10.36
  [09f84164] HypothesisTests v0.11.3
  [682c06a0] JSON v0.21.4
  [92481ed7] LinearRegression v0.2.1
  [6fdf6af0] LogDensityProblems v2.1.2
⌃ [996a588d] LogDensityProblemsAD v1.10.1
  [2ab3a3ac] LogExpFunctions v0.3.28
  [c7f686f2] MCMCChains v6.0.6
  [be115224] MCMCDiagnosticTools v0.3.10
  [da04e1cc] MPI v0.20.22
  [3da0fdf6] MPIPreferences v0.1.11
  [daf3887e] Octofitter v4.0.0 `https://github.com/sefffal/Octofitter.jl.git#main`
  [c6a353d9] OctofitterRadialVelocity v4.0.0 `https://github.com/sefffal/Octofitter.jl.git:OctofitterRadialVelocity#main`
  [a15396b6] OnlineStats v1.7.1
  [43a3c2be] PairPlots v2.9.2
  [0eb8d820] Pigeons v0.4.6 `~/projects/Pigeons.jl`
  [fd6f9641] PlanetOrbits v0.10.1
  [91a5bcdd] Plots v1.40.8
  [c3e4b0f8] Pluto v0.20.0
  [7f904dfe] PlutoUI v0.7.60
  [37e2e3b7] ReverseDiff v1.15.3
  [276daf66] SpecialFunctions v2.4.0
  [8efc31e9] SplittableRandoms v0.1.2
⌅ [b1ba175b] VLBIImagePriors v0.8.4
  [a5390f91] ZipFile v0.10.1
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [d6f4376e] Markdown
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random
  [9e88b42a] Serialization
  [10745b16] Statistics v1.10.0
  [8dfed614] Test
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`

@sefffal
Author

sefffal commented Oct 21, 2024

Here is some code to expose the calls to Inputs, pigeons, and get_sample without going through my convenience wrapper octofit_pigeons:

inputs = Pigeons.Inputs(
       target=model,
       explorer=SliceSampler(),
       n_rounds=10,
       n_chains=10,
       n_chains_variational=6,
       record=[traces; round_trip; record_default(); index_process],
       multithreaded=true,
       variational=GaussianReference()
);
pt = pigeons(inputs);
samples = get_sample(pt,10)
chn = Octofitter.result2mcmcchain(model.arr2nt.(model.invlink.(s[1:model.D] for s in samples)));

If we look at the two target chains, they are quite different:

julia> Pigeons.target_chains(pt2)
2-element Vector{Int64}:
 10
 11

julia> chn_10 = Octofitter.result2mcmcchain(model.arr2nt.(model.invlink.(s[1:model.D] for s in get_sample(pt2, 10))));

julia> chn_11 = Octofitter.result2mcmcchain(model.arr2nt.(model.invlink.(s[1:model.D] for s in get_sample(pt2, 11))));

julia> octocorner(model, chn_11, chn_10, small=false, viz=(PairPlots.Scatter(markersize=4), PairPlots.MarginHist()))

image

@sefffal
Author

sefffal commented Oct 21, 2024

@miguelbiron I ran your script, and agree it did not reproduce the error. Curious.

@sefffal
Author

sefffal commented Oct 21, 2024

Okay I found the problem. Whatever the underlying issue, it is quite sensitive to the number of chains.

This does not have the issue:

inp_stab_vpt = Pigeons.Inputs(;
   target = model,
   record = [traces; round_trip; record_default(); index_process],
   multithreaded=true,
   show_report=true,
   n_rounds=10, 
   explorer=SliceSampler(), 
   n_chains=10, 
   n_chains_variational=10,
   variational = GaussianReference(first_tuning_round = 5)
)

But this does reproduce it:

inp_stab_vpt = Pigeons.Inputs(;
   target = model,
   record = [traces; round_trip; record_default(); index_process],
   multithreaded=true,
   show_report=true,
   n_rounds=10, 
   explorer=SliceSampler(), 
   n_chains=10, 
   n_chains_variational=6,
   variational = GaussianReference(first_tuning_round = 5)
)

See n_chains_variational=6.

@miguelbiron
Collaborator

Ah yes, I see: at some point I changed the number in my script. Rerunning again...

@miguelbiron
Collaborator

Ah yes now I see it. Sorry for the confusion.

@miguelbiron
Collaborator

Oof I think I found the issue: around July 2023, the order of the fixed/var legs was swapped in the linear indexing of the combined chains

# Note: 2023/07/20: changed order to have variational first (as depicted below)
# to simplify log(Z) code for 2-legged
# <--- variational ----> <----- fixed ------>
# reference ----- target -- target ---- reference
# 1 ----- N -- N + 1 ---- 2N

But this was not propagated to the swap graph definition

is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1)

This should instead be

is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1)

(is_reference is correct)

Can confirm that with this change, the results look correct. I'm making a PR with a small test for this issue.
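To make the off-by-leg behaviour concrete, here is a small self-contained sketch (not the actual Pigeons.jl internals; the struct name `VariationalDEOSketch` is hypothetical, though the field names `n_chains_var`/`n_chains_fixed` mirror the ones in the snippets above). With the variational leg first, the two copies of the target sit at linear positions `n_chains_var` and `n_chains_var + 1`; the pre-fix version instead pointed at `n_chains_fixed` and `n_chains_fixed + 1`:

```julia
# Sketch of the combined-chain layout with the variational leg first:
# chains 1..n_var are the variational leg, chains n_var+1..n_var+n_fixed
# are the fixed leg, so the two target copies are adjacent in the middle.
struct VariationalDEOSketch
    n_chains_var::Int
    n_chains_fixed::Int
end

# Correct: last chain of the variational leg and first chain of the fixed leg.
is_target(deo::VariationalDEOSketch, chain::Int) =
    (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1)

# Buggy pre-fix version, for comparison: wrong whenever
# n_chains_var != n_chains_fixed.
is_target_buggy(deo::VariationalDEOSketch, chain::Int) =
    (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1)

deo = VariationalDEOSketch(6, 10)  # matches the failing MWE
println(filter(c -> is_target(deo, c), 1:16))        # [6, 7]
println(filter(c -> is_target_buggy(deo, c), 1:16))  # [10, 11]
```

Note that when `n_chains_var == n_chains_fixed` the two definitions coincide, which is why the equal-chain-count runs were unaffected.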

@trevorcampbell
Collaborator

trevorcampbell commented Oct 21, 2024

@miguelbiron can we add a unit test for this too? This is a super subtle/insidious bug.

(woops, just saw that you are doing a test already. good stuff :-) )

Thanks again @sefffal for the issue thread!

@sefffal
Author

sefffal commented Oct 21, 2024

Wow, great spot @miguelbiron !

I'm left wondering though: why do these two target legs look so different? I.e. how did these samples with way lower posterior density end up in the second target chain to begin with?

@sefffal
Author

sefffal commented Oct 21, 2024

I guess that all my previous results with stabilized VPT also had this issue, and I only noticed there was a problem when the target chain from the variational leg was misbehaving?

@miguelbiron
Collaborator

@sefffal because chains (10, 11) were actually both in the interior of the fixed leg, so neither was actually a target. But since we use adaptive schedules, two consecutive chains can target surprisingly different distributions, if that arrangement is the one that leads to equirejection.

Concerning your past experiments with SVPT, I think you were most likely getting incorrect samples too, from the interior of one or the other leg (depending on the number of chains in each). Sometimes the samples can look ok because the interior points can be arbitrarily close to the endpoints thanks to adaptive scheduling.

@sefffal
Author

sefffal commented Oct 21, 2024

The error was sensitive to the number of chains; was that because some numbers of chains didn't have this issue, or because in some cases the samples were arbitrarily close to the endpoints due to the adaptive scheduling?

Sorry for all the questions; I just want to know how this will impact previous results.

@miguelbiron
Copy link
Collaborator

No worries! It's probably best to look at the indexer in question for your example:

julia> pt_2_legs.shared.tempering.indexer.i2t
16-element Vector{@NamedTuple{chain::Int64, leg::Symbol}}:
 (chain = 1, leg = :variational)
 (chain = 2, leg = :variational)
 (chain = 3, leg = :variational)
 (chain = 4, leg = :variational)
 (chain = 5, leg = :variational)
 (chain = 6, leg = :variational)
 (chain = 10, leg = :fixed)
 (chain = 9, leg = :fixed)
 (chain = 8, leg = :fixed)
 (chain = 7, leg = :fixed)
 (chain = 6, leg = :fixed)
 (chain = 5, leg = :fixed)
 (chain = 4, leg = :fixed)
 (chain = 3, leg = :fixed)
 (chain = 2, leg = :fixed)
 (chain = 1, leg = :fixed)

This is a vector that, for each i in 1:(total number of chains), tells us to which leg it belongs and to which index it corresponds within that leg. Now, the true target legs are entries 6 and 7

 (chain = 6, leg = :variational)
 (chain = 10, leg = :fixed)

But the failed target_chains was reporting 10 and 11, corresponding to

 (chain = 7, leg = :fixed)
 (chain = 6, leg = :fixed)

I.e., 2 interior chains of the fixed leg. You can see how different combinations of n_chains and n_chains_variational parameters would result in different output. But most likely, target_chains was returning a pair of chains in the interior of a leg instead of the true 2 targets in different legs.

How different those two interior distributions were from the actual target is impossible to say a priori, given that their beta parameters could've been close to 1---and therefore the samples could've looked ok.
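The i2t layout above can be reconstructed with a few lines (a sketch, not the Pigeons.jl implementation): the variational leg is listed first in increasing order, then the fixed leg in decreasing order, so the two true targets end up adjacent at positions n_var and n_var + 1, while positions 10 and 11 (what the buggy target_chains reported) land in the interior of the fixed leg:

```julia
# Reconstruct the 16-entry i2t vector for the failing MWE:
# 6 variational chains followed by 10 fixed chains in reverse order.
n_var, n_fixed = 6, 10
i2t = vcat(
    [(chain = c, leg = :variational) for c in 1:n_var],
    [(chain = c, leg = :fixed)       for c in n_fixed:-1:1],
)

# True targets are adjacent at linear positions n_var and n_var + 1:
println(i2t[n_var])      # (chain = 6, leg = :variational)
println(i2t[n_var + 1])  # (chain = 10, leg = :fixed)

# What the buggy target_chains reported (positions 10 and 11):
println(i2t[10])  # (chain = 7, leg = :fixed)
println(i2t[11])  # (chain = 6, leg = :fixed)
```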

@sefffal
Author

sefffal commented Oct 21, 2024

Thanks @miguelbiron, that really clears it up.

So can we say this issue did not occur when the numbers of regular and variational chains were the same?

@miguelbiron
Collaborator

For sure; in that special case, the bug does not appear.

@sefffal
Author

sefffal commented Oct 21, 2024

Excellent! I am relieved, since our published results had so far used the same numbers of regular and variational chains.

I appreciate your help on this very much.

@miguelbiron
Collaborator

OMG I'm so glad that that is the case---phew!!! 😌

@alexandrebouchard
Member

Phhheewww.. thank you so much William for reporting. And Miguel, you have eagle eyes.. amazing job

@trevorcampbell
Collaborator

I am relieved, since our published results had so far used the same numbers of regular and variational chains

Thank goodness!....

and thanks again @miguelbiron and @sefffal for hunting this very bad bug down 👍
