A recent PR, #436, by @colecitrenbaum fixed an off-by-one error in the non-stationary HMM transition code. However, the code seems to contain another bug. If anyone has encountered it before, please let me know.
In abstractions.py, at line L341:

```python
lp = jnp.sum(expected_transitions * log_trans_matrix)
```

With T=10 timesteps, this line raises the following error:

```
    lp = jnp.sum(expected_transitions * log_trans_matrix)
                 ~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~
    return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
           ~~~~~~~^^^^^^
TypeError: mul got incompatible shapes for broadcasting: (9, 3, 3), (8, 3, 3).
```
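The shapes in the traceback are consistent with a doubled off-by-one. With T=10 timesteps there are T-1=9 transitions z_t → z_{t+1}, so `expected_transitions` holds 9 matrices, while slicing the per-timestep transition matrices twice leaves only 8. Below is a minimal NumPy sketch of that arithmetic. It assumes, hypothetically, that each slice simply drops the last timestep, and it uses random matrices rather than anything from the actual code. `jax.numpy` follows the same broadcasting rules; it just reports the mismatch as the `TypeError` above instead of NumPy's `ValueError`.

```python
import numpy as np

T, K = 10, 3  # 10 timesteps, 3 hidden states, matching the traceback shapes

rng = np.random.default_rng(0)

# Hypothetical per-timestep transition matrices: one (K, K) row-stochastic
# matrix per timestep, so shape (T, K, K).
trans_all = rng.dirichlet(np.ones(K), size=(T, K))
log_trans_all = np.log(trans_all)

# Expected transition counts from the smoother: one (K, K) matrix per
# consecutive pair (z_t, z_{t+1}), i.e. T - 1 of them.
expected_transitions = np.full((T - 1, K, K), 1.0 / K**2)

# Assumption: each slice drops the final timestep.
sliced_once = log_trans_all[:-1]  # (9, 3, 3): aligned with expected_transitions
sliced_twice = sliced_once[:-1]   # (8, 3, 3): one matrix short

lp = np.sum(expected_transitions * sliced_once)  # shapes agree, so this works

try:
    np.sum(expected_transitions * sliced_twice)
    mismatch = False
except ValueError:
    # Leading dimensions 9 and 8 differ and neither is 1, so they cannot broadcast.
    mismatch = True

print(sliced_once.shape, sliced_twice.shape, mismatch)  # (9, 3, 3) (8, 3, 3) True
```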
I have attached a screenshot below showing where I think the bug could be (highlighted in pink, L276 and L298), and I can also provide a minimal example to reproduce the error if needed. It seems to me that there is no need to call pytree_slice at L298, since the slicing is already handled at L276 when the transition matrices are computed.
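The proposed fix is to slice exactly once, where the transition matrices are computed, and use the result as-is afterwards. The following is a minimal sketch of that pattern, not the library's code. `compute_log_transition_matrices` and `transition_log_prior_term` are hypothetical stand-ins for the code around L276 and L341. Plain array slicing replaces `pytree_slice`, and the sketch assumes the slice drops the final timestep.

```python
import numpy as np

def compute_log_transition_matrices(log_trans_all):
    """Hypothetical stand-in for the code around L276: slice once so there is
    exactly one matrix per transition z_t -> z_{t+1}."""
    return log_trans_all[:-1]

def transition_log_prior_term(expected_transitions, log_trans_matrix):
    """Hypothetical stand-in for the code around L341: no second slice, because
    log_trans_matrix is already aligned with expected_transitions."""
    if expected_transitions.shape != log_trans_matrix.shape:
        raise ValueError(
            f"shape mismatch: {expected_transitions.shape} vs {log_trans_matrix.shape}"
        )
    return np.sum(expected_transitions * log_trans_matrix)

T, K = 10, 3
log_trans_all = np.log(np.full((T, K, K), 1.0 / K))        # uniform transitions, (T, K, K)
expected_transitions = np.full((T - 1, K, K), 1.0 / K**2)  # (T - 1, K, K)

log_trans_matrix = compute_log_transition_matrices(log_trans_all)  # sliced once: (9, 3, 3)
lp = transition_log_prior_term(expected_transitions, log_trans_matrix)
print(log_trans_matrix.shape, lp)
```

Because the slice lives in a single place, the downstream sum can check shape agreement instead of re-deriving the timestep alignment.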
