Skip to content

Commit

Permalink
Only use necessary dims when creating temporary dataarray (#9206)
Browse files Browse the repository at this point in the history
* Only use necessary dims when creating temporary dataarray

* Update dataset_plot.py

* Can't check only data_vars all corrds are no longer added by default

* Update dataset_plot.py

* Add tests

* Update whats-new.rst

* Update dataset_plot.py
  • Loading branch information
Illviljan authored Jul 9, 2024
1 parent 179c670 commit 3024655
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`)
By `Jimmy Westling <https://github.com/illviljan>`_.
- Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`)
By `Justus Magin <https://github.com/keewis>`_.
- Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`).
Expand Down
15 changes: 10 additions & 5 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,8 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr
"""Create a temporary datarray with extra coords."""
from xarray.core.dataarray import DataArray

# Base coords:
coords = dict(ds.coords)
coords = dict(ds[y].coords)
dims = set(ds[y].dims)

# Add extra coords to the DataArray from valid kwargs, if using all
# kwargs there is a risk that we add unnecessary dataarrays as
Expand All @@ -732,12 +732,17 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr
coord_kwargs = locals_.keys() & valid_coord_kwargs
for k in coord_kwargs:
key = locals_[k]
if ds.data_vars.get(key) is not None:
coords[key] = ds[key]
darray = ds.get(key)
if darray is not None:
coords[key] = darray
dims.update(darray.dims)

# Trim dataset from unneccessary dims:
ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future

# The dataarray has to include all the dims. Broadcast to that shape
# and add the additional coords:
_y = ds[y].broadcast_like(ds)
_y = ds[y].broadcast_like(ds_trimmed)

return DataArray(_y, coords=coords)

Expand Down
40 changes: 40 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3416,3 +3416,43 @@ def test_9155() -> None:
data = xr.DataArray([1, 2, 3], dims=["x"])
fig, ax = plt.subplots(ncols=1, nrows=1)
data.plot(ax=ax)


@requires_matplotlib
def test_temp_dataarray() -> None:
from xarray.plot.dataset_plot import _temp_dataarray

x = np.arange(1, 4)
y = np.arange(4, 6)
var1 = np.arange(x.size * y.size).reshape((x.size, y.size))
var2 = np.arange(x.size * y.size).reshape((x.size, y.size))
ds = xr.Dataset(
{
"var1": (["x", "y"], var1),
"var2": (["x", "y"], 2 * var2),
"var3": (["x"], 3 * x),
},
coords={
"x": x,
"y": y,
"model": np.arange(7),
},
)

# No broadcasting:
y_ = "var1"
locals_ = {"x": "var2"}
da = _temp_dataarray(ds, y_, locals_)
assert da.shape == (3, 2)

# Broadcast from 1 to 2dim:
y_ = "var3"
locals_ = {"x": "var1"}
da = _temp_dataarray(ds, y_, locals_)
assert da.shape == (3, 2)

# Ignore non-valid coord kwargs:
y_ = "var3"
locals_ = dict(x="x", extend="var2")
da = _temp_dataarray(ds, y_, locals_)
assert da.shape == (3,)

0 comments on commit 3024655

Please sign in to comment.