JAX backend fails for latent scan variables #6718
Comments
Possibly related to #6351. The issue is that there is no 1-to-1 map between the Scan RV and the Scan value variable (I think because the output of Scan is actually a Slice).
It works with the default backend just fine, I was hoping it was something to do with how the [...]
Ah okay, that sounds different then.
Regardless of this issue, we should definitely clean up the [...]
How complex of a fix would that be?
By the way, this seems to be triggered by the [...]
I don't quite know, but it's worth a look. One option is to grab the user-provided mode and exclude rewrites that are incompatible with JAX (since we then know that we are compiling to JAX). Otherwise we could have an optional kwarg to the dispatch function with the [...]
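For illustration, the first option (keep the user's mode, but exclude rewrites that do not work with JAX) could look roughly like the sketch below. It is plain Python with toy stand-ins, not the real PyTensor classes: `RewriteQuery`, `mode_for_jax`, and the tag names in `JAX_INCOMPATIBLE` are all hypothetical, and the thread does not say which rewrites would actually need to be excluded.

```python
from dataclasses import dataclass

# Hypothetical tag names, for illustration only. The actual set of
# JAX-incompatible rewrites is not stated in this thread.
JAX_INCOMPATIBLE = frozenset({"inplace", "cxx_only"})


@dataclass(frozen=True)
class RewriteQuery:
    """Toy stand-in for the rewrite selection carried by a compilation mode."""

    include: frozenset
    exclude: frozenset = frozenset()

    def excluding(self, *tags):
        # Return a new query; the user's original query is left untouched.
        return RewriteQuery(self.include, self.exclude | frozenset(tags))


def mode_for_jax(user_query):
    """Keep the user's choices, but also exclude the JAX-incompatible tags."""
    return user_query.excluding(*JAX_INCOMPATIBLE)


# The user asked for the default rewrites and switched off "fusion" themselves.
user_query = RewriteQuery(
    include=frozenset({"fast_run"}),
    exclude=frozenset({"fusion"}),
)
jax_query = mode_for_jax(user_query)
```

The point of returning a new object is that the user's own exclusions survive and their original mode is never mutated, so the same mode can still be used with the default backend.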
Yeah, it is trying to feed a numpy generator as input. This would also fix #6697, which would be a big improvement.
An immediate solution to your problem is to pass a valid initval: model.register_rv(traj, name='traj', initval=np.zeros(100))
Good to know! I can try to have a look at the mode problem as well over the next couple of days if you're busy with other stuff.
Describe the issue:
Not sure if this belongs here or in the pytensor repo. Putting it here because the minimal example I can come up with uses PyMC. If you make a scan variable, register it without observations, then use it for further computation, the graph will fail to compile.
Reproducible code example:
Error message:
PyMC version information: