Adding expand_dims for xtensor #1449

Open
wants to merge 10 commits into
base: labeled_tensors

@AllenDowney AllenDowney commented Jun 6, 2025

Add expand_dims operation for labeled tensors

This PR adds support for the expand_dims operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.

Key Features

  • New ExpandDims operation that adds a new dimension to an XTensorVariable
  • Support for both static and symbolic dimension sizes
  • Automatic broadcasting when size > 1
  • Integration with existing tensor operations
  • Full compatibility with xarray's expand_dims behavior

Implementation Details

The implementation includes:

  1. New ExpandDims class in pytensor/xtensor/shape.py that handles:

    • Adding new dimensions with specified names
    • Support for both static and symbolic sizes
    • Shape inference and validation
  2. Rewriting rule in pytensor/xtensor/rewriting/shape.py that:

    • Converts labeled tensor operations to standard tensor operations
    • Handles broadcasting when needed
    • Validates symbolic sizes
  3. Comprehensive test suite in tests/xtensor/test_shape.py covering:

    • Basic dimension expansion
    • Static and symbolic sizes
    • Error cases and edge cases
    • Compatibility with xarray operations
    • Integration with other labeled tensor operations
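
For readers unfamiliar with the lowering step, here is a rough NumPy sketch of what "convert to standard tensor operations, broadcasting when needed" can look like. The helper name `lower_expand_dims` and its signature are made up for illustration; this is not the rewrite in pytensor/xtensor/rewriting/shape.py.

```python
import numpy as np


def lower_expand_dims(values: np.ndarray, size: int = 1, axis: int = 0) -> np.ndarray:
    """Hypothetical helper: add one axis, then broadcast it to length `size`."""
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    # Step 1: insert a length-1 axis at the requested position
    expanded = np.expand_dims(values, axis)
    if size == 1:
        return expanded
    # Step 2: broadcast only the new axis to the requested length
    target_shape = list(expanded.shape)
    target_shape[axis] = size
    return np.broadcast_to(expanded, target_shape)
```

With `x = np.arange(3)`, `lower_expand_dims(x)` has shape `(1, 3)` and `lower_expand_dims(x, size=4)` has shape `(4, 3)`. The broadcast result is a read-only view, so no data is copied.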

Usage Example

from pytensor.xtensor import xtensor
from pytensor.xtensor.shape import expand_dims  # assumed import path; ExpandDims lives in pytensor/xtensor/shape.py

# Create a labeled tensor
x = xtensor("x", dims=("city",), shape=(3,))

# Add a new dimension
y = expand_dims(x, "country")  # Adds a new dimension of size 1
z = expand_dims(x, "country", size=4)  # Adds a new dimension of size 4

Testing

The implementation includes extensive tests that verify:

  • Correct behavior with various input shapes
  • Proper handling of symbolic sizes
  • Error cases (invalid dimensions, sizes, etc.)
  • Compatibility with xarray's expand_dims
  • Integration with other labeled tensor operations

📚 Documentation preview 📚: https://pytensor--1449.org.readthedocs.build/en/1449/

@AllenDowney
Author

Now that we have this PR based on the right commit, @ricardoV94 it is ready for a first look.

One question: my first draft of this was based on a later commit -- this draft goes back to an earlier commit, and it looks like @register_xcanonicalize doesn't exist yet, so I've replaced it with @register_lower_xtensor, which seems to be its predecessor. Is that the right thing to do for now?

@ricardoV94
Member

That's the new name, it better represents the kind of rewrites it holds

@AllenDowney
Author

@ricardoV94 I think this is a step toward handling symbolic sizes, but there are a couple of places where I'm not sure what the right behavior is. See the comments in test_shape.py, test_expand_dims_implicit.

Do those tests make sense? Are there more cases that should be covered?

@ricardoV94
Member

The simplest test for symbolic expand_dims is:

size_new_dim = xtensor("size_new_dim", shape=(), dtype=int)
x = xtensor("x", shape=(3,))
y = x.expand_dims(new_dim=size_new_dim)
xr_function = function([x, size_new_dim], y)

x_test = xr_arange_like(x)
size_new_dim_test = DataArray(np.array(5, dtype=int))
result = xr_function(x_test, size_new_dim_test)
expected_result = x_test.expand_dims(new_dim=size_new_dim_test)
xr_assert_allclose(result, expected_result)

You can parametrize the test to try default and explicit non-default axis as well.

Sidenote, what is an implicit expand_dims? I don't think that's a thing.

@AllenDowney
Author

@ricardoV94 I've addressed most of your comments on the previous round, and made a first pass at adding support for multiple dimensions. Please take a look at the expand_dims wrapper function, which canonicalizes the inputs and loops through them to make a series of Ops.

Assuming that adding multiple dimensions is rare, what do you think of the loop option, as opposed to making a single Op that adds multiple dimensions?

@ricardoV94
Member

Assuming that adding multiple dimensions is rare, what do you think of the loop option, as opposed to making a single Op that adds multiple dimensions?

That's fine. We used that for other Ops and we can revisit later if we want it to be fused.
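
To make the loop option concrete, a minimal sketch on plain NumPy arrays with a tuple of dimension names. `expand_one` and `expand_many` are hypothetical stand-ins for a single-dimension Op and the wrapper that chains it; they are not the PR's code.

```python
import numpy as np


def expand_one(dims, values, name, size=1):
    """Hypothetical single-dimension step: prepend `name` with length `size`."""
    if name in dims:
        raise ValueError(f"Dimension {name!r} already exists")
    new_values = np.broadcast_to(values[np.newaxis, ...], (size, *values.shape))
    return (name, *dims), new_values


def expand_many(dims, values, new_dims):
    """Apply expand_one once per entry instead of one fused multi-dimension step."""
    # Each step prepends, so iterate in reverse to keep the order given by the caller
    for name, size in reversed(list(new_dims.items())):
        dims, values = expand_one(dims, values, name, size)
    return dims, values
```

Calling `expand_many(("city",), np.arange(3), {"country": 2, "year": 4})` yields dims `("country", "year", "city")` and an array of shape `(2, 4, 3)`. A fused version would compute the final shape once, which only matters if chains of expansions turn out to be common.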

@AllenDowney
Author

@ricardoV94 This is ready for another look.

The rewrite was a shambles, but I think I have a clearer idea now.

@ricardoV94
Member

ricardoV94 commented Jun 10, 2025

I left some comments above.

Rewrite looks good. As we discussed we should redo the tests to use expand_dims as a method (like xarray users would).

Also, I suspected xarray allows specifying the size like x.expand_dims(dim_a=1, dim_b=2) which is equivalent to x.expand_dims({"dim_a":1, "dim_b":2}). At least that was a pattern I noticed in other xarray methods. I saw you had a test for multiple dims with dict, but I didn't see one with kwargs.
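
As a sketch of how the dict and kwargs forms could be folded into one mapping before any Ops are built (the function name `normalize_dims` is hypothetical, and rejecting a positional `dim` together with kwargs is an assumption modelled on the error message quoted in the diff):

```python
def normalize_dims(dim=None, **dim_kwargs):
    """Hypothetical normalizer: return a {name: size} mapping."""
    if dim is not None and dim_kwargs:
        # Assumption: mixing both styles is an error
        raise ValueError("cannot specify both keyword and positional arguments")
    if dim is None:
        return dict(dim_kwargs)
    if isinstance(dim, str):
        return {dim: 1}
    if isinstance(dim, dict):
        return dict(dim)
    # Otherwise treat `dim` as a list/tuple of names, each of size 1
    return dict.fromkeys(dim, 1)
```

Under this sketch `normalize_dims(dim_a=1, dim_b=2)` and `normalize_dims({"dim_a": 1, "dim_b": 2})` return the same mapping, so a kwargs test can reuse the expectations of the dict test.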

@AllenDowney
Author

@ricardoV94 I cleaned up the code as suggested and took a first cut at handling the axis parameter. Please take a look.

Member

@ricardoV94 ricardoV94 left a comment

This looks great. Just the question about the size kwarg and small notes.

@AllenDowney
Author

@ricardoV94 This is ready for another look

Member

@ricardoV94 ricardoV94 left a comment

I think we're nearly there. This proved to be more complex than I anticipated.

Small changes to the error checks and one question. Also, I don't think there's any test for passing sequences as the length of the dimensions?

Comment on lines +437 to +442
if not create_index_for_new_dim:
    warnings.warn(
        "create_index_for_new_dim=False has no effect in pytensor.xtensor",
        UserWarning,
        stacklevel=2,
    )
Member

Neither option has an effect to be fair, just don't mention?

)

# Extract size from dim_kwargs if present
size = dim_kwargs.pop("size", 1) if dim_kwargs else 1
Member

Still here, did you not push yet?

Comment on lines +447 to +449
# xarray compatibility: error if a sequence (list/tuple) of dims and size are given
if (isinstance(dim, list | tuple)) and ("size" in locals() and size != 1):
    raise ValueError("cannot specify both keyword and positional arguments")
Member

also not a thing?


# Normalize to a dimension-size mapping
if isinstance(dim, str):
    dims_dict = {dim: size}
Member

Suggested change
- dims_dict = {dim: size}
+ dims_dict = {dim: 1}

elif isinstance(dim, dict):
    dims_dict = {}
    for name, val in dim.items():
        if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
Member

How does xarray treat expand_dims(new_dim=np.array(5))? Does it treat as coordinates or the size? I suppose the latter?

In that case the check here should be (isinstance(val, np.ndarray) and val.ndim > 0). We could also consider symbolic variables with ((isinstance(val, np.ndarray) or (isinstance(val, Variable) and isinstance(val.type, HasShape))) and val.ndim > 0)
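
A small NumPy illustration of the distinction, assuming the guess above is right and a 0-d array is read as a size. `is_coordinate_like` is a hypothetical name, and the symbolic-variable branch is left out because it depends on PyTensor types.

```python
import numpy as np


def is_coordinate_like(val) -> bool:
    """Hypothetical check: True when `val` reads as coordinate values, not a size."""
    if isinstance(val, np.ndarray):
        # np.array(5) has ndim == 0 and is treated as a scalar size here
        return val.ndim > 0
    return isinstance(val, (list, tuple))
```

`is_coordinate_like(np.array(5))` is False while `is_coordinate_like(np.array([1, 2, 3]))` is True, so a plain `isinstance(val, np.ndarray)` test alone would misclassify the 0-d case.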

Comment on lines +472 to +475
elif isinstance(val, int):
    dims_dict[name] = val
else:
    dims_dict[name] = val  # symbolic/int scalar allowed
Member

This ends up being accepted anyway so merge?

Suggested change
- elif isinstance(val, int):
-     dims_dict[name] = val
- else:
-     dims_dict[name] = val  # symbolic/int scalar allowed
+ else:
+     dims_dict[name] = val  # symbolic/int scalar allowed

elif isinstance(dim, dict):
    dims_dict = {}
    for name, val in dim.items():
        if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
Member

Suggested change
- if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
+ if isinstance(val, str):
+     raise ValueError(f"The size of a dimension cannot be a string, got {val}")
+ if isinstance(val, Sequence | np.ndarray):

@ricardoV94
Member

PS it's so annoying there's no type for non-string sequences
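
The usual workaround is a small predicate, sketched here with a hypothetical name (`is_non_str_sequence` is not part of the PR):

```python
from collections.abc import Sequence


def is_non_str_sequence(val) -> bool:
    """True for list/tuple-like values, False for str and bytes."""
    # str and bytes are registered as Sequence, so they must be excluded explicitly
    return isinstance(val, Sequence) and not isinstance(val, (str, bytes))
```

Note that `np.ndarray` is not registered as a `Sequence`, which is presumably why the code under review checks `Sequence | np.ndarray`.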
