
optimize align for scalars at least #8350

Open
@dcherian


What happened?

Here's a simple rescaling calculation:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": (("x", "y"), np.ones((300, 400))), "b": (("x", "y"), np.ones((300, 400)))}
)
mean = ds.mean() # scalar
std = ds.std() # scalar
rescaled = (ds - mean) / std

The profile for the last line shows 30% (!!!) of the time spent in align (really reindex_like), even though there is nothing to reindex when only scalars are involved!
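A quick way to check the "nothing to reindex" part, as a rough sketch: the reductions have no dimensions and no indexes, so aligning them against ds should return objects identical to the inputs. The timing compares the xarray expression with the same arithmetic on raw NumPy arrays using the standard-library timeit. Random data replaces np.ones here only to avoid 0/0 warnings from a zero std. Absolute timings depend on the machine and xarray version, so none are quoted.

```python
import timeit

import numpy as np
import xarray as xr

# Same shapes as the example above; random values so std is non-zero.
rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"a": (("x", "y"), rng.random((300, 400))), "b": (("x", "y"), rng.random((300, 400)))}
)
mean = ds.mean()
std = ds.std()

# The reduced Datasets are dimensionless: no sizes and no indexes to align on.
print(dict(mean.sizes), dict(mean.indexes))

# Aligning the scalar Dataset against ds leaves both objects unchanged.
aligned_ds, aligned_mean = xr.align(ds, mean)
print(aligned_ds.identical(ds), aligned_mean.identical(mean))

# Time the xarray expression against the equivalent per-variable NumPy arithmetic.
t_xr = timeit.timeit(lambda: (ds - mean) / std, number=20)
t_np = timeit.timeit(
    lambda: {k: (ds[k].values - mean[k].values) / std[k].values for k in ds.data_vars},
    number=20,
)
print(f"xarray: {t_xr:.4f}s  numpy: {t_np:.4f}s")
```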


This is a small example inspired by an ML pipeline where this normalization happens many, many times in a tight loop.

cc @benbovy

What did you expect to happen?

A fast path for when no reindexing needs to happen.
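One possible shape for such a fast path, sketched at user level and assuming the simplest case from the title: if at most one operand has any dimensions, every other operand is a scalar with no indexes, so align has nothing to do and can be skipped. align_is_noop and maybe_align are hypothetical helpers, not part of xarray, and the check ignores align's exclude and indexes arguments, so it only covers the plain binary-arithmetic case.

```python
import numpy as np
import xarray as xr


def align_is_noop(*objects):
    """Hypothetical helper (not part of xarray): True when at most one
    object has any dimensions, i.e. every other operand is a scalar."""
    n_with_dims = sum(1 for obj in objects if len(obj.sizes) > 0)
    return n_with_dims <= 1


def maybe_align(*objects, join="inner"):
    """Hypothetical wrapper: skip xr.align entirely on the scalar fast path."""
    if align_is_noop(*objects):
        # Dimensionless operands carry no indexes, so there is nothing to reindex.
        return objects
    # Otherwise fall back to the full alignment machinery.
    return xr.align(*objects, join=join)


rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"a": (("x", "y"), rng.random((300, 400))), "b": (("x", "y"), rng.random((300, 400)))}
)
mean = ds.mean()

print(align_is_noop(ds, mean))  # → True
print(align_is_noop(ds, ds))    # → False
```

The check only looks at sizes, which is cheap; anything with two or more dimensioned operands still goes through xr.align, so index comparison and conflicting-size errors are handled as before.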
