Description
I ran into an issue that I think many people will hit when they first try out the rxinference() function.
For the simple model
@model function observation_model(Hs, Qs, R)
    # specify data variables
    μ_s = datavar(Vector{Float64})
    Σ_s = datavar(Matrix{Float64})
    # specify priors
    s_prev ~ MvNormalMeanCovariance(μ_s, Σ_s)
    # add process noise
    s ~ MvNormalMeanCovariance(s_prev, Qs)
    # form observation
    y = datavar(Float64)
    y ~ NormalMeanVariance(dot(Hs, s), R)
    # return variables (the original also returned `n`, which is not defined in this model)
    return y, s
end
with auto updates:
autoupdates = @autoupdates begin
    μ_s, Σ_s = mean_cov(q(s))
end;
I started off with the following implementation:
rxinference(
    model = observation_model(Hs, Qs, R),
    data = (y = data, ),
    autoupdates = autoupdates,
    initmarginals = (
        s_prev = vague(MvNormalMeanCovariance, deployed_model_signal.dim_in),
    ),
    returnvars = (:s,),
    keephistory = length(data),
    historyvars = (s = KeepLast(),),
    autostart = true,
)
which did not start the inference procedure: it failed with a MethodError on iterate(::Nothing), caused by mean_cov(::Missing). It turned out that the marginals in initmarginals are set first, then autoupdates is called, and only then does inference actually start. Because autoupdates reads q(s), it is q(s) that needs an initial marginal, not q(s_prev). The code below did run:
rxinference(
    model = observation_model(Hs, Qs, R),
    data = (y = data, ),
    autoupdates = autoupdates,
    initmarginals = (
        s = vague(MvNormalMeanCovariance, deployed_model_signal.dim_in),
    ), # specifies the initial q(s) to get the inference starting (i.e. first autoupdates, then message passing)
    returnvars = (:s,),
    keephistory = length(data),
    historyvars = (s = KeepLast(),),
    autostart = true,
)
I understand why it is implemented this way, but the error message does not make clear what is going wrong. Perhaps we can improve the error handling by first checking whether all marginals used in autoupdates are specified in initmarginals. If they are not, we should throw an error explaining that the marginals are set first, then autoupdates is called, and only then is data fed into the model.
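
To make the proposal concrete, here is a rough sketch of what such a check could look like. It is plain Julia and does not touch any RxInfer internals: the helper names `missing_initmarginals` and `check_initmarginals` are hypothetical, and the list of variables that autoupdates depends on is written out by hand, because I don't know how best to extract it from the `@autoupdates` object.

```julia
# Hypothetical helpers, not part of RxInfer.
# `required` lists the variables whose marginals q(...) are read in autoupdates;
# here it is supplied by hand rather than derived from the @autoupdates object.
function missing_initmarginals(required, initmarginals::NamedTuple)
    return [name for name in required if !haskey(initmarginals, name)]
end

function check_initmarginals(required, initmarginals::NamedTuple)
    absent = missing_initmarginals(required, initmarginals)
    if !isempty(absent)
        names = join(("q($name)" for name in absent), ", ")
        throw(ArgumentError(
            "autoupdates depend on $names, but initmarginals does not set them. " *
            "Marginals are set first, then autoupdates is called, " *
            "and only then is data fed into the model."
        ))
    end
    return nothing
end

# Placeholder values stand in for the actual initial distributions.
# First attempt from this issue: only s_prev is initialised, so s is reported as missing.
missing_initmarginals((:s,), (s_prev = 1.0,))

# Working version: s is initialised, so the check passes silently.
check_initmarginals((:s,), (s = 1.0,))
```

With something like this run before inference starts, the first attempt above would fail with a message that names q(s) directly instead of the MethodError.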