Description
The following example illustrates a restriction in the current logp derivations when one branch of a `where` is a constant:
```python
import pytensor.tensor as pt
import pymc as pm

t = pt.arange(10)
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))

# cat_fixed = pt.where(t > 5, cat, -1)  # Not accepted because -1 is not measurable
cat_fixed = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))  # fine

pm.logp(cat_fixed, cat_fixed.type())
```
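To make the question concrete, here is a plain NumPy sketch (not PyMC's implementation) of the elementwise logp one would expect if the constant branch were accepted: where the condition holds, the Categorical logp; elsewhere, the logp of a point mass at the constant, which is 0 when the value equals the constant and -inf otherwise. This mirrors what the DiracDelta workaround above is meant to produce. The helpers `categorical_logp`, `dirac_logp` and `where_logp` are hypothetical names used only for illustration.

```python
import numpy as np


def categorical_logp(value, p):
    """Elementwise Categorical log-probability; -inf outside {0, ..., len(p) - 1}."""
    value = np.asarray(value)
    p = np.asarray(p, dtype=float)
    in_support = (value >= 0) & (value < len(p))
    # Clip so out-of-support entries can still be indexed; they are masked to -inf.
    safe = np.clip(value, 0, len(p) - 1)
    return np.where(in_support, np.log(p[safe]), -np.inf)


def dirac_logp(value, c):
    """Elementwise log-probability of a point mass at the constant c."""
    return np.where(np.asarray(value) == c, 0.0, -np.inf)


def where_logp(cond, value, p, c):
    """Hypothetical logp of where(cond, Categorical(p), c) evaluated at `value`."""
    # Each position is scored by the branch that produced it.
    return np.where(cond, categorical_logp(value, p), dirac_logp(value, c))


t = np.arange(10)
cond = t > 5
value = np.where(cond, t % 2, -1)  # six -1 entries, then [0, 1, 0, 1]
print(where_logp(cond, value, [0.5, 0.5], -1).sum())  # approximately 4 * log(0.5) = -2.7726
```

Any position where the condition is false and the value differs from -1 makes the total logp -inf, exactly as a point mass should.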
Should we allow it? This also applies to operations like `join` and `make_vector`, where one may combine measurable and constant inputs.
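For `join`, the same idea would split the joined value into slices: the measurable slice is scored by its own logp and the constant slice as a point mass. Below is a minimal 1-D NumPy sketch, assuming a Categorical block followed by a block of constants along axis 0; `join_logp` and the helpers are hypothetical names, not PyMC functions. `make_vector` would be the special case where every input is a scalar.

```python
import numpy as np


def categorical_logp(value, p):
    """Elementwise Categorical log-probability; -inf outside {0, ..., len(p) - 1}."""
    value = np.asarray(value)
    p = np.asarray(p, dtype=float)
    in_support = (value >= 0) & (value < len(p))
    safe = np.clip(value, 0, len(p) - 1)  # masked to -inf below when out of support
    return np.where(in_support, np.log(p[safe]), -np.inf)


def dirac_logp(value, const):
    """Elementwise log-probability of point masses at the entries of `const`."""
    return np.where(np.asarray(value) == np.asarray(const), 0.0, -np.inf)


def join_logp(value, n_measurable, p, const):
    """Hypothetical logp of join(0, Categorical(p, shape=n_measurable), const)."""
    value = np.asarray(value)
    head, tail = value[:n_measurable], value[n_measurable:]
    # Each slice is scored by the logp of the input it came from.
    return np.concatenate([categorical_logp(head, p), dirac_logp(tail, const)])


const = np.array([-1, 7])
print(join_logp([0, 1, 1, -1, 7], 3, [0.5, 0.5], const).sum())  # approximately 3 * log(0.5)
print(join_logp([0, 1, 1, -1, 8], 3, [0.5, 0.5], const).sum())  # -inf: last entry is not 7
```

Because the constants contribute only 0 or -inf, they never change the logp of a value that matches them; they only rule out values that do not.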
If we allow it, should we also allow broadcasting? This is currently not allowed (hence the need for `shape=cat.shape`) because the logp of broadcast operations can be tricky to handle systematically, but for constants it may be fine?
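One way to see why constants may be the easy case: when a measurable scalar is broadcast, all copies are the same draw, so summing an elementwise logp over the copies counts that draw once per copy, whereas one reasonable convention counts it once (and gives -inf unless all copies agree). For a broadcast constant, each copy contributes either 0 or -inf, so repeating the check changes nothing. The NumPy sketch below illustrates this under those assumptions; the function names are hypothetical and this is not how PyMC implements broadcasting.

```python
import numpy as np


def categorical_logp(value, p):
    """Elementwise Categorical log-probability; -inf outside {0, ..., len(p) - 1}."""
    value = np.asarray(value)
    p = np.asarray(p, dtype=float)
    in_support = (value >= 0) & (value < len(p))
    safe = np.clip(value, 0, len(p) - 1)  # masked to -inf below when out of support
    return np.where(in_support, np.log(p[safe]), -np.inf)


def naive_broadcast_logp(value, p):
    """Wrong for a measurable input: treats every broadcast copy as an independent draw."""
    return categorical_logp(value, p).sum()


def broadcast_logp(value, p):
    """Logp of broadcast_to(Categorical(p), value.shape): one draw, counted once."""
    value = np.asarray(value)
    if not np.all(value == value.flat[0]):
        return -np.inf  # copies of a single draw must be identical
    return float(categorical_logp(value.flat[0], p))


def broadcast_constant_logp(value, c):
    """Logp of broadcast_to(c, value.shape): each copy contributes 0 or -inf."""
    return float(np.where(np.asarray(value) == c, 0.0, -np.inf).sum())


print(naive_broadcast_logp([1, 1, 1], [0.5, 0.5]))  # 3 * log(0.5): overcounts the draw
print(broadcast_logp([1, 1, 1], [0.5, 0.5]))        # log(0.5)
print(broadcast_constant_logp([-1, -1, -1], -1))    # 0.0, same as checking once
```

For the constant, the naive elementwise sum and the "count once" answer coincide, which is why broadcasting constants seems safe even if broadcasting measurable inputs is not.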