Skip to content

Commit 4be6ca4

Browse files
huardkeewisZeitsperreaulemahal
authored
Fix for scalar detection (#8821)
Co-authored-by: Justus Magin <keewis@users.noreply.github.com> Co-authored-by: Trevor James Smith <10819524+Zeitsperre@users.noreply.github.com> Co-authored-by: Pascal Bourgault <bourgault.pascal@ouranos.ca>
1 parent e42aa6f commit 4be6ca4

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

xarray/core/indexing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,14 @@ def __repr__(self) -> str:
769769

770770
def _wrap_numpy_scalars(array):
771771
"""Wrap NumPy scalars in 0d arrays."""
772-
if np.isscalar(array):
772+
if np.ndim(array) == 0 and (
773+
isinstance(array, np.generic)
774+
or not (is_duck_array(array) or isinstance(array, NDArrayMixin))
775+
):
776+
return np.array(array)
777+
elif hasattr(array, "dtype"):
778+
return array
779+
elif np.ndim(array) == 0:
773780
return np.array(array)
774781
else:
775782
return array

xarray/tests/test_backends.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,6 +2276,36 @@ def test_write_inconsistent_chunks(self) -> None:
22762276
def test_roundtrip_coordinates(self) -> None:
22772277
super().test_roundtrip_coordinates()
22782278

2279+
@requires_cftime
2280+
def test_roundtrip_cftime_bnds(self):
2281+
# Regression test for issue #7794
2282+
import cftime
2283+
2284+
original = xr.Dataset(
2285+
{
2286+
"foo": ("time", [0.0]),
2287+
"time_bnds": (
2288+
("time", "bnds"),
2289+
[
2290+
[
2291+
cftime.Datetime360Day(2005, 12, 1, 0, 0, 0, 0),
2292+
cftime.Datetime360Day(2005, 12, 2, 0, 0, 0, 0),
2293+
]
2294+
],
2295+
),
2296+
},
2297+
{"time": [cftime.Datetime360Day(2005, 12, 1, 12, 0, 0, 0)]},
2298+
)
2299+
2300+
with create_tmp_file() as tmp_file:
2301+
original.to_netcdf(tmp_file)
2302+
with open_dataset(tmp_file) as actual:
2303+
# Operation to load actual time_bnds into memory
2304+
assert_array_equal(actual.time_bnds.values, original.time_bnds.values)
2305+
chunked = actual.chunk(time=1)
2306+
with create_tmp_file() as tmp_file_chunked:
2307+
chunked.to_netcdf(tmp_file_chunked)
2308+
22792309

22802310
@requires_zarr
22812311
@pytest.mark.usefixtures("default_zarr_format")

0 commit comments

Comments
 (0)