Small fixups to xtensor type and XRV #1503
Changes from all commits: 5761704, cd017f5, e3e4afe, 0978d11, 9ecfc10
```diff
@@ -5,15 +5,16 @@
 import pytensor.tensor.random.basic as ptr
 from pytensor.graph.basic import Variable
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.xtensor import as_xtensor
 from pytensor.xtensor.math import sqrt
+from pytensor.xtensor.type import as_xtensor
 from pytensor.xtensor.vectorization import XRV


 def _as_xrv(
     core_op: RandomVariable,
     core_inps_dims_map: Sequence[Sequence[int]] | None = None,
     core_out_dims_map: Sequence[int] | None = None,
+    name: str | None = None,
 ):
     """Helper function to define an XRV constructor.
```
```diff
@@ -41,7 +42,14 @@ def _as_xrv(
         core_out_dims_map = tuple(range(core_op.ndim_supp))

     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
+            (
+                max((entry + 1 for entry in dims_map), default=0)
+                for dims_map in core_inps_dims_map
+            ),
+            default=0,
+        ),
+        max((entry + 1 for entry in core_out_dims_map), default=0),
     )

     @wraps(core_op)
```

**Review comment:** Just to check my understanding: this returns how many core dims the "broadcasting" between the inputs and outputs will have? For each input "map", it's returning the largest core dim index, then the largest core dim among all inputs, then the largest between the inputs and the outputs.

**Reply:** Not quite. The mapping says which of the user-passed core dims correspond to each input/output, positionally (e.g. for MvNormal the user passes 2 core dims, the covariance dims). From this it is trivial to infer how many core dims the user has to pass, so we can give an automatic, useful message. With zero-based indices you need to pass a sequence that is as long as the largest index + 1. The problem is that there is a difference between 0 and empty in this case, which we weren't handling correctly before.
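A small standalone sketch (not part of the diff; the maps below are hypothetical) of why the index-based count matters: a map entry of `1` means two core dim names are required even though the map itself has length one, while an empty map contributes zero rather than one.

```python
def core_dims_needed(core_inps_dims_map, core_out_dims_map):
    # Largest zero-based core-dim index referenced by any input or by the output, plus one.
    return max(
        max(
            (
                max((entry + 1 for entry in dims_map), default=0)
                for dims_map in core_inps_dims_map
            ),
            default=0,
        ),
        max((entry + 1 for entry in core_out_dims_map), default=0),
    )

# Hypothetical maps: the second input uses only core dim index 1.
print(core_dims_needed(((), (1,)), (0,)))  # 2 -> two core dim names must be passed
# The old len-based rule would have reported max(0, 1, 1) == 1 here.
print(core_dims_needed(((), ()), ()))      # 0 -> all maps empty, no core dims needed
```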
```diff
@@ -76,7 +84,10 @@ def xrv_constructor(
             extra_dims = {}

         return XRV(
-            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
+            core_op,
+            core_dims=full_core_dims,
+            extra_dims=tuple(extra_dims.keys()),
+            name=name,
         )(rng, *extra_dims.values(), *params)

     return xrv_constructor
```
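A hypothetical usage sketch (the wrapped Op and the name are illustrative, not taken from the diff) of how the new `name` argument is threaded from `_as_xrv` into the `XRV` Op:

```python
import pytensor.tensor.random.basic as ptr

# Hypothetical: wrap the core normal RV as an xtensor random constructor.
# With this change, the resulting XRV Op carries the given name.
normal = _as_xrv(ptr.normal, name="normal")
```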
**Review comment:** Is `dtype="floatX"` being deprecated? (I'm trying to guess what "still" means here.)

**Reply:** When you create a `RandomVariable` Op you can specify `dtype="floatX"` at the Op level. But when we make an actual node we need to commit to one concrete dtype, since floatX is not a real thing. If you call `__call__` we already commit to a dtype, and this is where users can specify a custom one. But if you call `make_node` directly, like `XRV` does, it doesn't go through this step. It's a quirk of how we are wrapping RV Ops in xtensor, but in theory if you have an Op you should always be able to call `make_node` and get a valid graph.
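A minimal sketch of the `__call__` path described above (behavior assumed from the reply; the printed dtypes depend on `pytensor.config.floatX`):

```python
import pytensor
import pytensor.tensor.random.basic as ptr

# ptr.normal is declared with dtype="floatX" at the Op level.
# Going through __call__ commits the node to the configured float type.
x = ptr.normal(0, 1)
print(x.dtype, pytensor.config.floatX)  # e.g. "float64 float64"

# __call__ is also where a user can request a different dtype explicitly.
y = ptr.normal(0, 1, dtype="float32")
print(y.dtype)  # "float32"
```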