Opened on Oct 30, 2024
In 5SSD0, we have a simple binary classification model:
using RxInfer, LinearAlgebra

@model function linear_classification(y, X)
    # D (number of features) is assumed to be defined outside the model
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end
results = infer(
    model = linear_classification(),
    data = (y = y, X = X),
    returnvars = (θ = KeepLast(),),   # trailing comma makes this a NamedTuple, not an assignment
    predictvars = (y = KeepLast(),),
    iterations = 10,
)
Requesting a prediction will throw:
RuleMethodError: no method matching rule for the given arguments
Possible fix, define:
@rule Probit(:out, Marginalisation) (q_in::NormalMeanVariance, meta::ProbitMeta) = begin
    return ...
end
ReactiveMP complains that the Probit node's rule for making output predictions does not exist. Inspecting the ReactiveMP source code, I find:
@rule Probit(:out, Marginalisation) (m_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(m_in) / sqrt(1 + var(m_in)))
    return Bernoulli(p)
end
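The closed form in this rule is the probit likelihood averaged over a Gaussian input. Writing m = mean(m_in), v = var(m_in) and Φ for the standard normal CDF (normcdf), it follows from introducing an auxiliary standard normal variable ε, independent of x:

$$
\begin{aligned}
p(\text{out} = 1)
  &= \int \Phi(x)\,\mathcal{N}(x \mid m, v)\,\mathrm{d}x
   = P(\varepsilon \le x), \qquad \varepsilon \sim \mathcal{N}(0, 1),\; x \sim \mathcal{N}(m, v) \\
  &= P(\varepsilon - x \le 0), \qquad \varepsilon - x \sim \mathcal{N}(-m,\, 1 + v) \\
  &= \Phi\!\left(\frac{m}{\sqrt{1 + v}}\right)
\end{aligned}
$$

which is exactly normcdf(mean(m_in) / sqrt(1 + var(m_in))).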
So the rule does exist, but for m_in, not q_in.
Seeing as binary classification is a pretty important use case, I think we need a rule for q_in as well. Can I just copy-paste the m_in rule, or is there some reason we don't have a q_in rule?
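For reference, the copy-paste option would look roughly like the sketch below. It assumes that `@rule`, `Probit`, `ProbitMeta`, `Marginalisation`, `UnivariateNormalDistributionsFamily`, `NormalMeanVariance` and `Bernoulli` are all available after `using RxInfer` (which re-exports ReactiveMP), and that `normcdf` comes from StatsFuns; the helper `probit_predictive` is a hypothetical name. It only reuses the closed form: applied to the marginal q_in it returns the predictive probability Φ(m/√(1+v)), which is not necessarily the message a mean-field variational update would send (that would involve E_q[log Φ(x)] rather than log E_q[Φ(x)]), so treat it as a sketch of the question rather than a confirmed fix.

```julia
using RxInfer                 # assumed to re-export @rule, Probit, ProbitMeta, Marginalisation, ...
using StatsFuns: normcdf
using Statistics: mean, var

# Hypothetical helper: probability that out = 1 when the Probit input is
# Gaussian with mean m and variance v, i.e. Φ(m / sqrt(1 + v)).
# This is the same closed form as the existing m_in rule.
function probit_predictive(d)
    p = normcdf(mean(d) / sqrt(1 + var(d)))
    return Bernoulli(p)
end

# Sketch of the missing rule: identical to the m_in rule, but dispatching on
# the marginal q_in instead of the incoming message m_in.
@rule Probit(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    return probit_predictive(q_in)
end
```

For example, `probit_predictive(NormalMeanVariance(2.0, 3.0))` gives a Bernoulli with success probability Φ(2/√4) = Φ(1) ≈ 0.84.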