
Off-by-one error in non-stationary HMM transitions #438

@umeshksingla


A recent PR, #436 by @colecitrenbaum, fixed an off-by-one error in the non-stationary HMM transitions code. However, the code still seems to have another bug. If anyone has encountered it before, please let me know.

In abstractions.py, at L341:

    lp = jnp.sum(expected_transitions * log_trans_matrix)

It raises the following error:

     lp = jnp.sum(expected_transitions * log_trans_matrix)
                 ~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~
    return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
           ~~~~~~~^^^^^^
TypeError: mul got incompatible shapes for broadcasting: (9, 3, 3), (8, 3, 3).

This happens with a setup of T=10 timesteps.
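The shape arithmetic behind this error can be reproduced with placeholder arrays. This is a minimal sketch, assuming K=3 hidden states and assuming the extra pytree_slice at L298 drops one more step from the leading axis (modeled here as a hypothetical `[:-1]`); the array values are dummies, not what the library computes.

```python
import jax.numpy as jnp

T, K = 10, 3  # timesteps from the report; K=3 hidden states (assumed)

# One K x K log transition matrix per transition t -> t+1, i.e. T - 1 of them.
log_trans_matrix = jnp.log(jnp.full((T - 1, K, K), 1.0 / K))

# Expected transition counts from the E-step: also one K x K block per transition.
expected_transitions = jnp.full((T - 1, K, K), 1.0 / (K * K))

# Hypothetical second slice (the suspected extra pytree_slice) drops another step.
double_sliced = log_trans_matrix[:-1]

print(expected_transitions.shape, double_sliced.shape)  # (9, 3, 3) (8, 3, 3)

# Leading axes 9 vs 8 cannot broadcast, so the elementwise product fails.
raised = False
try:
    jnp.sum(expected_transitions * double_sliced)
except (TypeError, ValueError) as err:  # JAX reports this as a TypeError in the traceback above
    raised = True
    print(type(err).__name__, err)
```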

I have attached a screenshot below showing where I think the bug is (highlighted in pink, L276 and L298). I can also provide a minimal example to reproduce the error if needed. It seems to me that the pytree_slice at L298 is unnecessary, since the slicing is already done at L276 when the transition matrices are computed.

[Screenshot: abstractions.py with L276 and L298 highlighted in pink]
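The suggested fix (slice once, when the transition matrices are built, and not again afterwards) can be sketched as follows. `log_transition_matrices` is a hypothetical helper, not the library's code, and keeping steps 0..T-2 rather than 1..T-1 is an assumption about the indexing convention; the expected counts are a uniform stand-in for the E-step output.

```python
import jax
import jax.numpy as jnp

T, K = 10, 3  # timesteps from the report; K=3 hidden states (assumed)


def log_transition_matrices(logits_per_step):
    """Hypothetical helper: map per-timestep logits of shape (T, K, K) to
    row-normalized log transition matrices for the T - 1 transitions.
    The slice along the time axis happens here, exactly once."""
    # Assumed convention: step t parameterizes the transition z_t -> z_{t+1}.
    return jax.nn.log_softmax(logits_per_step[:-1], axis=-1)


logits = jax.random.normal(jax.random.PRNGKey(0), (T, K, K))
log_trans_matrix = log_transition_matrices(logits)  # shape (T - 1, K, K)

# Stand-in E-step output: each transition's K x K posterior sums to 1.
expected_transitions = jnp.full((T - 1, K, K), 1.0 / (K * K))

# No second slice, so both operands have T - 1 leading entries.
lp = jnp.sum(expected_transitions * log_trans_matrix)
print(log_trans_matrix.shape, lp.shape)  # (9, 3, 3) ()
```

With only one slice, both arrays keep T - 1 = 9 leading entries, so the product at L341 broadcasts and `lp` reduces to a scalar.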
