WIP: experiment with first class dim objects #1517

Status: Open · wants to merge 2 commits into base: main
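
For orientation, here is a minimal sketch of the API this experiment reaches for, pieced together from the helpers added in this diff (`from_length` in pytensor/xtensor/dims.py, `zeros` in pytensor/xtensor/basic.py). The `uint64("n")` call for building the symbolic length is an assumption about pytensor's scalar API, not something this diff shows:

from pytensor.scalar.basic import uint64
from pytensor.xtensor.basic import zeros
from pytensor.xtensor.dims import from_length

# A dim is a first-class graph variable wrapping a symbolic length
# (DIM_LENGTH_SCALAR is uint64 on this branch).
n = uint64("n")
row = from_length(n, name="row")

# Tensors are then built from dim objects rather than string dim names.
x = zeros(row, dtype="float64", name="x")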
1 change: 1 addition & 0 deletions pytensor/xtensor/__init__.py
@@ -6,6 +6,7 @@
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
     as_xtensor,
+    dim,
     xtensor,
     xtensor_constant,
 )
81 changes: 58 additions & 23 deletions pytensor/xtensor/basic.py
@@ -1,9 +1,14 @@
 from collections.abc import Sequence
 
 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph import Apply, Op
+from pytensor.scalar.basic import uint64
+from pytensor.tensor.basic import ones as tensor_ones
+from pytensor.tensor.basic import zeros as tensor_zeros
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import TensorType
-from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor
+
+
+DIM_LENGTH_SCALAR = uint64
 
 
 class XOp(Op):
@@ -32,6 +37,7 @@ def make_node(self, x):
         return Apply(self, [x], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
         return [xtensor_from_tensor(g_out, dims=x.type.dims)]
@@ -41,46 +47,49 @@ def L_op(self, inputs, outs, g_outs):
 
 
 class XTensorFromTensor(XTypeCastOp):
-    __props__ = ("dims",)
-
-    def __init__(self, dims: Sequence[str]):
-        super().__init__()
-        self.dims = tuple(dims)
+    __props__ = ()
 
-    def make_node(self, x):
+    def make_node(self, x, *dims):
         if not isinstance(x.type, TensorType):
             raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
-        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
-        return Apply(self, [x], [output])
+        output = xtensor(dtype=x.type.dtype, dims=dims)
+        return Apply(self, [x, *dims], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]
 
 
-def xtensor_from_tensor(x, dims, name=None):
-    return XTensorFromTensor(dims=dims)(x, name=name)
+def xtensor_from_tensor(x, dims, name=None, check: bool = True):
+    if check:
+        x = specify_shape(x, [dim.size for dim in dims])
+    return XTensorFromTensor()(x, *dims, name=name)
 
 
-class Rename(XTypeCastOp):
-    __props__ = ("new_dims",)
+class MapDims(XTypeCastOp):
+    __props__ = ("new_dim_indices",)
 
-    def __init__(self, new_dims: tuple[str, ...]):
-        super().__init__()
-        self.new_dims = new_dims
+    def __init__(self, new_dim_indices: tuple[int, ...]):
+        self.new_dim_indices = new_dim_indices
 
-    def make_node(self, x):
+    def make_node(self, x, *new_dims):
         x = as_xtensor(x)
-        output = x.type.clone(dims=self.new_dims)()
+        out_dims = list(x.dims)
+        for i, idx in enumerate(self.new_dim_indices):
+            out_dims[idx] = new_dims[i]
+
+        output = x.type.clone(dims=out_dims)()
         return Apply(self, [x], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
-        return [rename(g_out, dims=x.type.dims)]
+        return [map_dims(g_out, dims=x.type.dims)]
 
 
-def rename(x, name_dict: dict[str, str] | None = None, **names: str):
+def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
     if name_dict is not None:
         if names:
             raise ValueError("Cannot use both positional and keyword names in rename")
@@ -97,4 +106,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
                     f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
                 )
 
-    return Rename(tuple(new_names))(x)
+    return MapDims(tuple(new_names))(x)
+
+
+def zeros(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with zeros."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
+
+
+def ones(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with ones."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
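
As a reading aid: with check=True (the default), `xtensor_from_tensor` now guards the conversion with a `specify_shape` on each dim's `size`, while `zeros`/`ones` skip the check because the shape is taken from the dims directly. A hypothetical round trip, assuming `row` and `col` are DimVariables built as in the sketch above and that DimVariable exposes the `.size` used in this diff:

import pytensor.tensor as pt
from pytensor.xtensor.basic import ones, xtensor_from_tensor

raw = pt.matrix("raw")
# Conversion is shape-checked against the dim lengths via specify_shape.
xt = xtensor_from_tensor(raw, dims=(row, col), name="xt")

# Constructors take dim objects directly; no separate shape argument needed.
y = ones(row, col, dtype="float32")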
169 changes: 169 additions & 0 deletions pytensor/xtensor/dims.py
@@ -0,0 +1,169 @@
from __future__ import annotations

from uuid import uuid4

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, Variable
from pytensor.scalar.basic import ScalarVariable
from pytensor.xtensor.type import (
    DIM_LENGTH_SCALAR,
    BasicDim,
    CloneDim,
    DimType,
    DimVariable,
    XTensorVariable,
)


class DimOp(Op):
    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


# Not a dim op, because it doesn't return a DimVariable
class Length(Op):
    __props__ = ()

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [DIM_LENGTH_SCALAR()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0]


def _dim_size(dim: DimVariable) -> ScalarVariable:
    return Length()(dim)


class FromLength(DimOp):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (length,) = inputs
        if not isinstance(length, ScalarVariable):
            raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
        if length.type != DIM_LENGTH_SCALAR:
            raise TypeError(
                f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
            )
        return Apply(self, [length], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Forward the length value."""
        outputs[0][0] = inputs[0]


def from_length(length: ScalarVariable, name: str | None = None) -> DimVariable:
    # TODO add check for dtype
    if not isinstance(length, ScalarVariable):
        raise TypeError(f"length must be a ScalarVariable, got {type(length.type)}")
    if length.type != DIM_LENGTH_SCALAR:
        raise TypeError(
            f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
        )

    uuid = uuid4()
    dim_type = BasicDim(uuid=uuid, name=name)
    op = FromLength(dim_type)
    return op(length, name=name)


class FromTensor(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, XTensorVariable):
            raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Extract the dimension's length from the tensor's runtime shape."""
        (x,) = inputs
        (x_var,) = node.inputs
        for i, dim in enumerate(x_var.type.dims):
            if dim == self.dim_type:
                outputs[0][0] = x.shape[i]
                return
        raise ValueError(f"Dimension {self.dim_type} not found in tensor {x.type.dims}")


def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable:
    op = FromTensor(dim_type=x.type.dims[idx])
    return op(x, name=x.type.dims[idx].name)


class Clone(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0]


def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable:
    """Clone a dimension variable under a new identity.

    Args:
        name: The name for the cloned dimension.

    Returns:
        A new DimVariable whose type is a CloneDim based on the input's type.
    """
    dim_type = CloneDim(uuid=uuid4(), base=dim.type)
    return Clone(dim_type)(dim, name=name)


class Product(Op):
    __props__ = ()

    def make_node(self, *dims: Variable) -> Apply:
        if not all(isinstance(dim, DimVariable) for dim in dims):
            raise TypeError("All inputs must be DimVariables.")
        # TODO: `dim_type` is not defined here yet (WIP); the output type for
        # a product of dims still needs to be constructed.
        out = dim_type()
        return Apply(self, list(dims), [out])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_SCALAR.dtype).item()


def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
    return Product()(*dims, name=name)


def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable:
    if not isinstance(dim, DimVariable):
        raise TypeError(f"dim must be a DimVariable, got {type(dim)}")

    if not tensors:
        raise ValueError("At least one tensor must be provided for rebasing.")

    for tensor in tensors:
        for i, tensor_dim in enumerate(tensor.type.dims):
            if dim.type == tensor_dim:
                return _dim_from_tensor(tensor, idx=i)
    raise ValueError(f"Dimension {dim.type} not found in any of the provided tensors.")
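
The helpers above compose as follows: `rebase_dim` looks a dim up among a tensor's dims and re-derives it from that tensor's runtime shape via `FromTensor`, while `_clone_dim` mints a fresh identity (a CloneDim with a new uuid) for an existing dim. A hedged sketch, reusing `row` and `xt` from the earlier examples:

from pytensor.xtensor.dims import _clone_dim, _dim_size, rebase_dim

# Re-derive `row` from a tensor that carries it; at runtime its length is
# read off xt's shape instead of the original from_length input.
row_rebased = rebase_dim(row, xt)

# A clone keeps the base dim's lineage but is a distinct dimension.
row2 = _clone_dim(row, name="row2")

# Length recovers the symbolic size of a dim.
n_again = _dim_size(row_rebased)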
42 changes: 30 additions & 12 deletions pytensor/xtensor/rewriting/basic.py
@@ -1,12 +1,16 @@
 from pytensor.graph import node_rewriter
-from pytensor.tensor.basic import register_infer_shape
-from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
+from pytensor.tensor.rewriting.basic import (
+    register_canonicalize,
+    register_infer_shape,
+    register_useless,
+)
 from pytensor.xtensor.basic import (
-    Rename,
+    MapDims,
     TensorFromXTensor,
     XTensorFromTensor,
     xtensor_from_tensor,
 )
+from pytensor.xtensor.dims import FromLength, Length
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@@ -22,30 +26,44 @@ def useless_tensor_from_xtensor(fgraph, node):
         return [x.owner.inputs[0]]
 
 
-@register_infer_shape
-@register_useless
-@register_canonicalize
-@register_lower_xtensor
-@node_rewriter(tracks=[XTensorFromTensor])
+# TODO
+# @register_infer_shape
+# @register_useless
+# @register_canonicalize
+# @register_lower_xtensor
+# @node_rewriter(tracks=[XTensorFromTensor])
 def useless_xtensor_from_tensor(fgraph, node):
     """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
-    [x] = node.inputs
+    # TODO
+    [x, *dims] = node.inputs
     if x.owner and isinstance(x.owner.op, TensorFromXTensor):
         return [x.owner.inputs[0]]
 
 
+@register_infer_shape
+@register_useless
+@register_canonicalize
+@register_lower_xtensor
+@node_rewriter(tracks=[Length])
+def useless_length(fgraph, node):
+    """Length(FromLength(x)) -> x"""
+    [dim] = node.inputs
+    if dim.owner and isinstance(dim.owner.op, FromLength):
+        return [dim.owner.inputs[0]]
+
+
 @register_lower_xtensor
 @node_rewriter(tracks=[TensorFromXTensor])
 def useless_tensor_from_xtensor_of_rename(fgraph, node):
     """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
     [renamed_x] = node.inputs
-    if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
+    if renamed_x.owner and isinstance(renamed_x.owner.op, MapDims):
         [x] = renamed_x.owner.inputs
         return node.op(x, return_list=True)
 
 
 @register_lower_xtensor
-@node_rewriter(tracks=[Rename])
+@node_rewriter(tracks=[MapDims])
 def useless_rename(fgraph, node):
     """
 
@@ -54,7 +72,7 @@ def useless_rename(fgraph, node):
     """
     [renamed_x] = node.inputs
     if renamed_x.owner:
-        if isinstance(renamed_x.owner.op, Rename):
+        if isinstance(renamed_x.owner.op, MapDims):
             [x] = renamed_x.owner.inputs
             return [node.op(x)]
         elif isinstance(renamed_x.owner.op, TensorFromXTensor):
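
Schematically, the new `useless_length` rewrite cancels the Length/FromLength round trip, so a graph like the one below collapses back to the original scalar length. A sketch under the same assumptions as the earlier examples:

from pytensor.scalar.basic import uint64
from pytensor.xtensor.dims import _dim_size, from_length

n = uint64("n")
dim_n = from_length(n, name="dim_n")
# Before rewriting: Length(FromLength(n)); after useless_length: just `n`.
n_roundtrip = _dim_size(dim_n)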