
Probit node :out marginalisation rule not defined for q_in, which is needed for binary linear classification #425

Description

In 5SSD0, we have a simple binary classification model:

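using RxInfer, LinearAlgebra

# D (number of features) is assumed to be defined in the enclosing scope
@model function linear_classification(y, X)
    # Gaussian prior on the weight vector
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    # Probit likelihood for each binary label
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end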

results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X),
    returnvars  = (θ = KeepLast(),),   # trailing comma makes this a NamedTuple
    predictvars = (y = KeepLast(),),
    iterations  = 10,
)

Requesting a prediction for y throws:

RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule Probit(:out, Marginalisation) (q_in::NormalMeanVariance, meta::ProbitMeta) = begin 
    return ...
end

ReactiveMP reports that the Probit node has no rule for computing the output prediction. Inspecting the source code, I find:

@rule Probit(:out, Marginalisation) (m_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(m_in) / sqrt(1 + var(m_in)))
    return Bernoulli(p)
end
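The closed form in this rule is the standard probit-Gaussian identity. Writing μ = mean(m_in) and σ² = var(m_in) for the incoming Gaussian and Φ for the standard normal CDF (normcdf), the predictive probability is:

```latex
\begin{aligned}
p(y = 1) &= \int \Phi(x)\,\mathcal{N}(x \mid \mu, \sigma^2)\,\mathrm{d}x \\
         &= \Pr(z \le x), \qquad z \sim \mathcal{N}(0, 1) \text{ independent of } x \sim \mathcal{N}(\mu, \sigma^2) \\
         &= \Pr(z - x \le 0), \qquad z - x \sim \mathcal{N}(-\mu,\, 1 + \sigma^2) \\
         &= \Phi\!\left(\frac{\mu}{\sqrt{1 + \sigma^2}}\right)
\end{aligned}
```

Only the mean and variance of the input Gaussian enter the result.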

So the rule does exist, but only for the message m_in, not for the marginal q_in.
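The m_in rule reads only mean(·) and var(·) of its Gaussian argument. The same arithmetic could therefore be evaluated from a marginal q_in as from a message m_in. The Base-only Julia sketch below is a hypothetical numerical check, not ReactiveMP code. It defines three illustrative names: phi_cdf, a Simpson's-rule stand-in for StatsFuns.normcdf; probit_predictive; and mc_predictive. It compares Φ(μ/√(1+σ²)) with a Monte Carlo estimate of E[Φ(x)] for x ~ N(μ, σ²). The sketch does not settle whether reusing the formula for q_in is the right choice within variational message passing.

```julia
using Random

# Standard normal CDF Φ(x) via composite Simpson's rule on [0, |x|].
# Base-only stand-in for StatsFuns.normcdf (which the ReactiveMP rule calls).
function phi_cdf(x::Real; n::Int = 2_000)   # n must be even for Simpson's rule
    a = abs(float(x))
    a == 0 && return 0.5
    h = a / n
    f(t) = exp(-t^2 / 2)                     # unnormalised standard normal density
    s = f(0.0) + f(a)
    for k in 1:n-1
        s += (isodd(k) ? 4 : 2) * f(k * h)
    end
    mass = s * h / 3 / sqrt(2π)              # ∫₀^|x| φ(t) dt
    return x >= 0 ? 0.5 + mass : 0.5 - mass
end

# Hypothetical helper: predictive P(y = 1) when the probit input is N(μ, v).
# Mirrors normcdf(mean(m_in) / sqrt(1 + var(m_in))) from the rule above.
probit_predictive(μ::Real, v::Real) = phi_cdf(μ / sqrt(1 + v))

# Monte Carlo estimate of E[Φ(x)], x ~ N(μ, v), using Φ(x) = P(z ≤ x), z ~ N(0, 1).
function mc_predictive(μ::Real, v::Real; samples::Int = 200_000, seed::Int = 1)
    rng = MersenneTwister(seed)
    hits = 0
    for _ in 1:samples
        x = μ + sqrt(v) * randn(rng)   # draw the Gaussian input
        z = randn(rng)                 # independent standard normal
        hits += z <= x                 # Bool counts as 0 or 1
    end
    return hits / samples
end

μ, v = 0.8, 2.0
println("closed form: ", probit_predictive(μ, v))
println("Monte Carlo: ", mc_predictive(μ, v))
```

With 200 000 samples, the two printed values should agree to roughly two decimal places.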

Since binary classification is an important use case, I think we need a rule for q_in as well. Can I simply copy-paste the m_in rule, or is there a reason a q_in rule doesn't exist?


Metadata

Labels: enhancement (New feature or request), question (Further information is requested)
