Probit node :out marginalisation not defined for q_in, which is needed for binary linear classification. #425

@wmkouw


In 5SSD0, we have a simple binary classification model:

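```julia
using RxInfer, LinearAlgebra   # `dot` comes from LinearAlgebra

@model function linear_classification(y, X)
    # Standard normal prior on the D-dimensional weights (D is assumed to be defined globally)
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))

    # Probit likelihood for each binary label y[i] with feature vector X[i]
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end
```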

```julia
results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X),
    returnvars  = (θ = KeepLast(),),   # trailing comma makes this a one-element NamedTuple
    predictvars = (y = KeepLast(),),
    iterations  = 10,
)
```

Requesting a prediction for `y` throws:

```
RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule Probit(:out, Marginalisation) (q_in::NormalMeanVariance, meta::ProbitMeta) = begin 
    return ...
end
```

ReactiveMP reports that the Probit node has no rule for computing the outgoing message on `:out`, which is what output predictions need. Inspecting the source code, I find:

```julia
@rule Probit(:out, Marginalisation) (m_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(m_in) / sqrt(1 + var(m_in)))
    return Bernoulli(p)
end
```

So the rule does exist, but it is defined for the incoming message `m_in`, not for the marginal `q_in`.
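For reference, the `m_in` rule evaluates the standard probit-Gaussian integral: the probability that $y = 1$ when the probit input is Gaussian with mean $\mu$ and variance $\sigma^2$. Writing $\Phi(x) = P(\varepsilon \le x)$ with $\varepsilon \sim \mathcal{N}(0, 1)$ independent of $x$ gives a short derivation:

$$
\begin{aligned}
p(y = 1) &= \int \Phi(x)\, \mathcal{N}(x \mid \mu, \sigma^2)\, \mathrm{d}x
          = P(\varepsilon - x \le 0), \qquad \varepsilon - x \sim \mathcal{N}(-\mu,\, 1 + \sigma^2) \\
         &= \Phi\!\left(\frac{\mu}{\sqrt{1 + \sigma^2}}\right).
\end{aligned}
$$

This matches `normcdf(mean(m_in) / sqrt(1 + var(m_in)))` in the rule above and depends only on the mean and variance of the Gaussian.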

Since binary classification is an important use case, I think we need a rule for `q_in`. Can I simply copy the `m_in` rule, or is there a reason a `q_in` rule is missing?

Labels: enhancement (New feature or request), question (Further information is requested)
