Description
Hi @VasylHafych, hi @oschulz. Sorry for the long message, but I want to explain what I did and where I got stuck.
As you asked on Friday, I'm opening this issue to explain what I found regarding type instability in `bat_sample`. The problem is simple:
- This works:
  ```julia
  using BAT, Distributions, Test  # `@inferred` comes from Test
  @inferred(bat_sample(MvNormal([0.4, 0.6], [2.0 1.2; 1.2 3.0]), IIDSampling(nsamples=10^6)))
  ```
- This doesn't work:
  ```julia
  iid_distribution = NamedTupleDist(a = mixture_model,)
  @inferred(bat_sample(iid_distribution, IIDSampling(nsamples=10^6)))
  ```
- This doesn't work either, for some `posterior`:
  ```julia
  ps = PartitionedSampling(sampler = mcmc, npartitions=4, exploration_sampler=mcmc_exp, integrator = ahmi, nmax_resampling=5)
  @inferred(bat_sample(posterior, ps))
  ```
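To make clear what "doesn't work" means here: `@inferred` (from the `Test` standard library) runs the call and throws an error if the actual runtime return type differs from the type the compiler inferred from the argument types alone. A minimal toy illustration, with hypothetical functions `stable` and `unstable` that are unrelated to BAT:

```julia
using Test

# Type-stable: always returns a Float64 for a Float64 input.
stable(x::Float64) = 2x

# Type-unstable: returns an Int or a Float64 depending on a runtime value,
# so inference can only give Union{Float64, Int}.
unstable(x::Float64) = x > 0 ? 1 : 2.0

@inferred stable(1.0)  # passes and returns the result

try
    # Throws an ErrorException: the runtime type `Int` differs from the
    # inferred `Union{Float64, Int}`, even though the union is small.
    @inferred unstable(1.0)
catch err
    println("unstable failed @inferred: ", err)
end
```

So a failing `@inferred` on `bat_sample` means that inference cannot predict the exact return type from the argument types.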
The output of, e.g., `@code_warntype bat_sample(iid_distribution, IIDSampling(nsamples=10^6))` is:
```
Variables
  #self#::Core.Const(BAT.bat_sample)
  target::NamedTupleDist{(:a,), Tuple{MixtureModel{Multivariate, Continuous, MvNormal, Categorical{Float64, Vector{Float64}}}}, Tuple{ValueAccessor{ArrayShape{Real, 1}}}}
  algorithm::IIDSampling

Body::NamedTuple{(:result, :optargs, :kwargs), _A} where _A<:Tuple
1 ─ nothing
│ %2 = Core.NamedTuple()::Core.Const(NamedTuple())
│ %3 = Base.pairs(%2)::Core.Const(Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│ %4 = BAT.:(var"#bat_sample#166")(%3, #self#, target, algorithm)::NamedTuple{(:result, :optargs, :kwargs), _A} where _A<:Tuple
└── return %4
```
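The return type `Body::NamedTuple{(:result, :optargs, :kwargs), _A} where _A<:Tuple` is highlighted in red, both in the `Body` line right before `1 ─ nothing` and at `%4`. So the compiler knows that the result is a `NamedTuple` with the fields `result`, `optargs` and `kwargs`, but it cannot infer the types of those fields.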
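To illustrate how to read that return type, here is a small sketch with two hypothetical functions (not BAT code) that return a `NamedTuple` with the same field names. `Base.return_types` gives the type inferred for the given argument types; as soon as one field's type is unknown, the inferred `NamedTuple` type is no longer concrete:

```julia
# Hypothetical toy functions returning a NamedTuple with the same field
# names as the `bat_sample` result.
stable_nt(v::Vector{Float64}, i::Int) =
    (result = v[i], optargs = nothing, kwargs = NamedTuple())

# `v[i]` is inferred as `Any` here, so the field types of the result are unknown.
unstable_nt(v::Vector{Any}, i::Int) =
    (result = v[i], optargs = nothing, kwargs = NamedTuple())

rt_stable = only(Base.return_types(stable_nt, (Vector{Float64}, Int)))
rt_unstable = only(Base.return_types(unstable_nt, (Vector{Any}, Int)))

println(rt_stable)    # a concrete NamedTuple type
println(rt_unstable)  # a NamedTuple type whose field types are only partially known
println(isconcretetype(rt_stable), " ", isconcretetype(rt_unstable))
```

The red `_A<:Tuple` in the `@code_warntype` output corresponds to the second case: at least one of the three fields has a type the compiler could not pin down.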
Calling `@which bat_sample(iid_distribution, IIDSampling(nsamples=10^6))` shows that the method being executed is defined at `PATH/TO/DIR/.julia/dev/BAT/src/algotypes/sampling_algorithm.jl:57`, which is this code:
```julia
@inline function bat_sample(target::AnySampleable, algorithm::AbstractSamplingAlgorithm; kwargs...)
    rng = bat_default_withinfo(bat_sample, Val(:rng), target)
    bat_sample(rng, target, algorithm; kwargs...)
end
```
So this method only chooses the default random number generator and forwards the call to `bat_sample(rng, target, algorithm; kwargs...)`, and I guess it is not the source of the instability. This is where I'm stuck. I tried running `@code_warntype` on the `IIDSampling` and `PartitionedSampling` implementations, but I didn't get anywhere, probably because I'm doing something wrong.
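One possible way to narrow this down, sketched on a toy model of the same wrapper pattern (`ToyTarget`, `ToyAlgorithm` and `toy_sample` are hypothetical and only mimic the structure of the `bat_sample` method above): `@code_warntype` on the outer call mostly shows the call into the generated keyword-argument body (like `var"#bat_sample#166"` in the output above), so one can instead check the inner call with an explicitly passed RNG, and keep descending until the first call whose inferred type is not concrete.

```julia
using Random, Test

# Hypothetical stand-ins for a target and an algorithm (not BAT types).
struct ToyTarget
    params::Vector{Any}   # abstract element type: the deliberate source of instability
end

struct ToyAlgorithm
    n::Int
end

# Outer convenience method: picks a default RNG and forwards everything,
# mirroring the structure of the `bat_sample` wrapper quoted above.
toy_sample(target::ToyTarget, algorithm::ToyAlgorithm; kwargs...) =
    toy_sample(Random.default_rng(), target, algorithm; kwargs...)

# Inner method that does the actual work.
function toy_sample(rng::AbstractRNG, target::ToyTarget, algorithm::ToyAlgorithm; kwargs...)
    scale = target.params[1]                    # inferred as Any
    result = scale .* randn(rng, algorithm.n)   # therefore also Any
    return (result = result, optargs = (algorithm = algorithm,), kwargs = values(kwargs))
end

target = ToyTarget(Any[2.0])
algorithm = ToyAlgorithm(3)

# Step 1: the outer call fails @inferred, but @code_warntype on it mostly
# shows the call into the compiler-generated keyword-argument body.
try
    @inferred toy_sample(target, algorithm)
catch err
    println("outer call: ", err)
end

# Step 2: call the inner method directly with an explicit RNG. If this also
# fails, the problem is inside (or below) the inner method, and
# `@code_warntype toy_sample(rng, target, algorithm)` shows `scale::Any`.
rng = MersenneTwister(1234)
try
    @inferred toy_sample(rng, target, algorithm)
catch err
    println("inner call: ", err)
end
```

In this toy, giving `params` a concrete element type (e.g. `Vector{Float64}`) should make both calls pass `@inferred`. For BAT, the analogous step would be to call the four-argument `bat_sample(rng, target, algorithm)` that the wrapper forwards to, with an explicitly passed RNG, and repeat the check one level further down each time.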