Small fixups to xtensor type and XRV #1503

Merged: 5 commits, Jun 30, 2025
7 changes: 7 additions & 0 deletions pytensor/tensor/random/op.py
@@ -392,6 +392,13 @@ def make_node(self, rng, size, *dist_params):
out_type = TensorType(dtype=self.dtype, shape=static_shape)
outputs = (rng.type(), out_type())

if self.dtype == "floatX":
# Commit to a specific float type if the Op is still using "floatX"
Member:

Is dtype = 'floatX' being deprecated? (I'm trying to guess what "still" means here.)

Member Author:

When you create a RandomVariable Op you can specify dtype="floatX" at the Op level. But when we make an actual node we need to commit to one dtype, since floatX is not a real dtype.

If you call __call__ we already commit to a dtype, and that is where users can specify a custom one. But if you call make_node directly, like XRV does, it doesn't go through this step. It's a quirk of how we are wrapping RV ops in xtensor, but in theory if you have an Op you should always be able to call make_node and get a valid graph.
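For illustration, a minimal sketch of that behaviour (assumed usage, not code from this PR; it relies on make_node accepting a plain list for size, as the lowering rewrite in this PR does):

import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import config
from pytensor.tensor.random.type import random_generator_type

rng = random_generator_type("rng")
loc, scale = pt.scalar("loc"), pt.scalar("scale")  # scalars default to config.floatX

# ptr.normal is declared with dtype="floatX" at the Op level; building the node
# via make_node now commits the output to the concrete config.floatX dtype.
next_rng, draw = ptr.normal.make_node(rng, [3], loc, scale).outputs
assert draw.type.dtype == config.floatX

The added branch (continued in the diff below) re-creates the Op with the concrete dtype in its props, so downstream code never sees "floatX" on an actual node.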

dtype = config.floatX
props = self._props_dict()
props["dtype"] = dtype
self = type(self)(**props)

return Apply(self, inputs, outputs)

def batch_ndim(self, node: Apply) -> int:
2 changes: 1 addition & 1 deletion pytensor/xtensor/__init__.py
@@ -1,7 +1,7 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg
from pytensor.xtensor import linalg, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
2 changes: 1 addition & 1 deletion pytensor/xtensor/math.py
@@ -101,7 +101,7 @@ def _as_xelemwise(core_op: ScalarOp) -> XElemwise:
maximum = _as_xelemwise(ps.scalar_maximum)
minimum = _as_xelemwise(ps.scalar_minimum)
second = _as_xelemwise(ps.second)
sigmoid = _as_xelemwise(ps.sigmoid)
sigmoid = expit = _as_xelemwise(ps.sigmoid)
sign = _as_xelemwise(ps.sign)
sin = _as_xelemwise(ps.sin)
sinh = _as_xelemwise(ps.sinh)
17 changes: 14 additions & 3 deletions pytensor/xtensor/random.py
@@ -5,15 +5,16 @@
import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.math import sqrt
from pytensor.xtensor.type import as_xtensor
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
core_op: RandomVariable,
core_inps_dims_map: Sequence[Sequence[int]] | None = None,
core_out_dims_map: Sequence[int] | None = None,
name: str | None = None,
):
"""Helper function to define an XRV constructor.

@@ -41,7 +42,14 @@ def _as_xrv(
core_out_dims_map = tuple(range(core_op.ndim_supp))

core_dims_needed = max(
(*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
max(
Member:

Just to check my understanding:

This returns how many core dims the "broadcasting" between the inputs and outputs will have? For each input "map", it's returning the largest core dim index, then the largest core dim among all inputs, then the largest between the inputs and the outputs.

Member Author (@ricardoV94, Jun 30, 2025):

Not quite. The mapping tells, when the user passes a list of n core dims (say 2 for the MvNormal, the covariance dims), which of these correspond to each input/output, positionally.

From this it is trivial to infer how many dims the user has to pass, so we can give an automatic, useful error message. With zero-based indices you need to pass a sequence that is as long as the largest index + 1. The problem is that there is a difference between 0 and empty in this case, which we weren't handling correctly before.
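To make that concrete, here is a small hypothetical illustration (made-up maps for an MvNormal-style RV, mirroring the expression added below):

# Hypothetical maps: the user passes one core dim; the mean uses it at position 0,
# the covariance uses it twice, and the draw uses it once.
core_inps_dims_map = [(0,), (0, 0)]
core_out_dims_map = (0,)

# Largest index + 1, never confusing "largest index is 0" with "no core dims"
# (max over an empty sequence falls back to the default of 0).
core_dims_needed = max(
    max(
        (
            max((entry + 1 for entry in dims_map), default=0)
            for dims_map in core_inps_dims_map
        ),
        default=0,
    ),
    max((entry + 1 for entry in core_out_dims_map), default=0),
)
assert core_dims_needed == 1  # the user must pass exactly one core dim name

For an Op with no core dims at all, both maps are empty and core_dims_needed correctly comes out as 0 instead of being conflated with "largest index is 0".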

(
max((entry + 1 for entry in dims_map), default=0)
for dims_map in core_inps_dims_map
),
default=0,
),
max((entry + 1 for entry in core_out_dims_map), default=0),
)

@wraps(core_op)
@@ -76,7 +84,10 @@ def xrv_constructor(
extra_dims = {}

return XRV(
core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
core_op,
core_dims=full_core_dims,
extra_dims=tuple(extra_dims.keys()),
name=name,
)(rng, *extra_dims.values(), *params)

return xrv_constructor
12 changes: 11 additions & 1 deletion pytensor/xtensor/rewriting/utils.py
@@ -1,6 +1,7 @@
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.basic import NodeRewriter, in2out
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion


lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -14,6 +15,15 @@
position=0.1,
)

# Register OFG inline again after lowering xtensor
optdb.register(
"inline_ofg_expansion_xtensor",
in2out(inline_ofg_expansion),
"fast_run",
"fast_compile",
position=0.11,
)


def register_lower_xtensor(
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
2 changes: 1 addition & 1 deletion pytensor/xtensor/rewriting/vectorization.py
@@ -116,7 +116,7 @@ def lower_rv(fgraph, node):
size = [*extra_dim_lengths, *param_batch_shape]

# RVs are their own core Op
new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
new_next_rng, tensor_out = core_op.make_node(rng, size, *tensor_params).outputs

# Convert output Tensors to XTensors
new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
8 changes: 7 additions & 1 deletion pytensor/xtensor/type.py
@@ -71,6 +71,8 @@
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
self.filter_checks_isfinite = False
# broadcastable is here just for code that would work fine with XTensorType but checks for it
self.broadcastable = (False,) * self.ndim

def clone(
self,
@@ -93,6 +95,10 @@
self, value, strict=strict, allow_downcast=allow_downcast
)

@staticmethod
def may_share_memory(a, b):
return TensorType.may_share_memory(a, b)

Codecov warning: added line pytensor/xtensor/type.py#L100 was not covered by tests.

def filter_variable(self, other, allow_convert=True):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
@@ -160,7 +166,7 @@
return None

def __repr__(self):
return f"XTensorType({self.dtype}, {self.dims}, {self.shape})"
return f"XTensorType({self.dtype}, shape={self.shape}, dims={self.dims})"

Codecov warning: added line pytensor/xtensor/type.py#L169 was not covered by tests.

def __hash__(self):
return hash((type(self), self.dtype, self.shape, self.dims))
13 changes: 13 additions & 0 deletions pytensor/xtensor/vectorization.py
@@ -142,8 +142,12 @@
core_op,
core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
extra_dims: tuple[str, ...],
name: str | None = None,
):
super().__init__()
if name is None:
name = getattr(core_op, "name", None)
self.name = name
self.core_op = core_op
inps_core_dims, out_core_dims = core_dims
for operand_dims in (*inps_core_dims, out_core_dims):
@@ -154,6 +158,15 @@
raise ValueError("size_dims must be unique")
self.extra_dims = tuple(extra_dims)

def __str__(self):
if self.name is not None:
name = self.name
attrs = f"(core_dims={self.core_dims}, extra_dims={self.extra_dims})"

Codecov warning: added lines pytensor/xtensor/vectorization.py#L163-L164 were not covered by tests.
else:
name = self.__class__.__name__
attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
return f"{name}({attrs})"

Codecov warning: added lines pytensor/xtensor/vectorization.py#L166-L168 were not covered by tests.

def update(self, node):
# RNG input and update are the first input and output respectively
return {node.inputs[0]: node.outputs[0]}
8 changes: 8 additions & 0 deletions tests/sparse/test_basic.py
@@ -1159,6 +1159,10 @@ def test_csm_grad(self):
structured=True,
)

@pytest.mark.skipif(
version.parse(sp.__version__) >= version.parse("1.16.0"),
reason="Scipy 1.16 introduced some changes that make this test fail",
)
def test_csm_sparser(self):
# Test support for gradients sparser than the input.

@@ -1191,6 +1195,10 @@ def test_csm_sparser(self):

assert len(spmat.data) == len(res)

@pytest.mark.skipif(
version.parse(sp.__version__) >= version.parse("1.16.0"),
reason="Scipy 1.16 introduced some changes that make this test fail",
)
def test_csm_unsorted(self):
# Test support for gradients of unsorted inputs.

15 changes: 14 additions & 1 deletion tests/xtensor/test_random.py
@@ -7,7 +7,7 @@

import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr
from pytensor import function, shared
from pytensor import config, function, shared
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations
from pytensor.tensor import broadcast_arrays, tensor
@@ -112,6 +112,19 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
)


def test_dtype():
x = normal(0, 1)
assert x.type.dtype == config.floatX

with config.change_flags(floatX="float64"):
x = normal(0, 1)
assert x.type.dtype == "float64"

with config.change_flags(floatX="float32"):
x = normal(0, 1)
assert x.type.dtype == "float32"


def test_normal():
rng = random_generator_type("rng")
c_size = tensor("c_size", shape=(), dtype=int)