Should branching logps accept constants #7711

Open
@ricardoV94

Description

The following example illustrates a restriction in the current logp derivations when a branch includes constants:

```python
import pytensor.tensor as pt
import pymc as pm

t = pt.arange(10)
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
# cat_fixed = pt.where(t > 5, cat, -1)  # Not accepted because -1 is not measurable
cat_fixed = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))  # fine
# Derive the logp graph, evaluated at a symbolic value of the same type
pm.logp(cat_fixed, cat_fixed.type())
```
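To make the question concrete, here is a plain-Python sketch of what the elementwise logp of `where(t > 5, cat, -1)` would presumably be if the constant branch were accepted. It is not PyMC's implementation, and the function names are hypothetical. The assumption is that a constant branch is scored exactly like the `DiracDelta` workaround: 0 where the value matches the constant, -inf otherwise.

```python
import math

def categorical_logp(value, p):
    """Log-probability of a single categorical draw; -inf outside the support."""
    if 0 <= value < len(p):
        return math.log(p[value])
    return -math.inf

def where_const_logp(cond, value, p, const):
    """Elementwise logp of where(cond, Categorical(p), const) (hypothetical sketch).

    Where cond is True the entry is scored by the categorical; where it is
    False the branch is a point mass at `const`, scored like DiracDelta:
    0 if the entry matches, -inf otherwise.
    """
    out = []
    for c, v in zip(cond, value):
        if c:
            out.append(categorical_logp(v, p))
        else:
            out.append(0.0 if v == const else -math.inf)
    return out

t = range(10)
cond = [ti > 5 for ti in t]
value = [-1] * 6 + [0, 1, 1, 0]
logps = where_const_logp(cond, value, [0.5, 0.5], const=-1)
print(sum(logps))  # approximately 4 * log(0.5): only the categorical entries contribute
```

Under this assumption, accepting the bare constant would be a convenience only: the resulting logp is identical to the `DiracDelta.dist(-1, shape=cat.shape)` version.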

Should we allow it? The same question applies to operations like `join` and `make_vector`, where measurable and constant inputs may be combined.
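For `join`, a constant input would presumably play the same role as in the `where` case. The following is a hypothetical plain-Python sketch, not PyMC code. It assumes the constant block contributes 0 to the logp if the corresponding slice of the value reproduces it exactly and -inf otherwise, while the measurable slice is scored as usual.

```python
import math

def join_logp(value, measurable_logp, n_measurable, const_block):
    """logp of concatenate([measurable, const_block]) (hypothetical sketch).

    The first n_measurable entries are scored by the measurable component;
    the trailing constant block contributes 0 if the value reproduces it
    exactly and -inf otherwise (including a length mismatch).
    """
    head, tail = value[:n_measurable], value[n_measurable:]
    if list(tail) != list(const_block):
        return -math.inf
    return sum(measurable_logp(v) for v in head)

def cat_logp(v):
    # Categorical with p=[0.5, 0.5], as in the example above
    return math.log(0.5) if v in (0, 1) else -math.inf

print(join_logp([0, 1, 1, -1, -1], cat_logp, 3, [-1, -1]))  # approximately 3 * log(0.5)
print(join_logp([0, 1, 1, -1, 0], cat_logp, 3, [-1, -1]))   # -inf: constant block not reproduced
```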

If we allow it, should we also allow broadcasting? Broadcasting is currently not allowed (hence the need for `shape=cat.shape`) because the logp of broadcasted operations can be tricky to handle systematically. For constants, however, it may be fine.
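One way to see why constants may be the easy case is the following plain-Python sketch. The function names are hypothetical and this is not how PyMC handles broadcasting. When a single random draw is broadcast to `k` positions, naively summing the elementwise logps counts that one draw `k` times. A broadcast constant, by contrast, contributes only 0 or -inf at each position, so repeating it cannot double-count anything, and the naive elementwise sum already gives the right answer.

```python
import math

LOG_HALF = math.log(0.5)

def naive_broadcast_logp(value, elem_logp):
    # Treats every broadcast copy as an independent draw: sums elementwise logps.
    return sum(elem_logp(v) for v in value)

def broadcast_rv_logp(value, elem_logp):
    # A single draw repeated k times: all copies must agree, and the draw is
    # scored once, not k times.
    if any(v != value[0] for v in value):
        return -math.inf
    return elem_logp(value[0])

def broadcast_const_logp(value, const):
    # A broadcast constant is a point mass at every position: each entry
    # contributes 0 (match) or -inf (mismatch).
    return 0.0 if all(v == const for v in value) else -math.inf

def cat_logp(v):
    return LOG_HALF if v in (0, 1) else -math.inf

def const_elem_logp(v):
    return 0.0 if v == -1 else -math.inf

print(naive_broadcast_logp([1, 1, 1], cat_logp))            # 3 * log(0.5): overcounts
print(broadcast_rv_logp([1, 1, 1], cat_logp))               # log(0.5): correct
print(naive_broadcast_logp([-1, -1, -1], const_elem_logp))  # 0.0
print(broadcast_const_logp([-1, -1, -1], -1))               # 0.0: naive sum agrees
```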
