Skip to content

Commit dd4f549

Browse files
committed
Recreate @gajomi's pydata#2070 to keep attrs when calling astype()
1 parent 7daad4f commit dd4f549

File tree

6 files changed

+91
-1
lines changed

6 files changed

+91
-1
lines changed

xarray/core/common.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,46 @@ def isin(self, test_elements):
13051305
dask="allowed",
13061306
)
13071307

1308+
def astype(self, dtype, casting="unsafe", copy=True):
1309+
"""
1310+
Copy of the xarray object, with data cast to a specified type.
1311+
Leaves coordinate dtype unchanged.
1312+
1313+
Parameters
1314+
----------
1315+
dtype : str or dtype
1316+
Typecode or data-type to which the array is cast.
1317+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1318+
Controls what kind of data casting may occur. Defaults to 'unsafe'
1319+
for backwards compatibility.
1320+
1321+
* 'no' means the data types should not be cast at all.
1322+
* 'equiv' means only byte-order changes are allowed.
1323+
* 'safe' means only casts which can preserve values are allowed.
1324+
* 'same_kind' means only safe casts or casts within a kind,
1325+
like float64 to float32, are allowed.
1326+
* 'unsafe' means any data conversions may be done.
1327+
copy : bool, optional
1328+
By default, astype always returns a newly allocated array. If this
1329+
is set to False and the `dtype` requirement is satisfied, the input
1330+
array is returned instead of a copy.
1331+
1332+
See also
1333+
--------
1334+
np.ndarray.astype
1335+
dask.array.Array.astype
1336+
"""
1337+
from .computation import apply_ufunc
1338+
1339+
return apply_ufunc(
1340+
duck_array_ops.astype,
1341+
self,
1342+
dtype,
1343+
keep_attrs=True,
1344+
kwargs={"casting": casting, "copy": copy},
1345+
dask="allowed",
1346+
)
1347+
13081348
def __enter__(self: T) -> T:
13091349
return self
13101350

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def trapz(y, x, axis):
149149
)
150150

151151

152+
def astype(data, dtype, **kwargs):
153+
return data.astype(dtype, **kwargs)
154+
155+
152156
def asarray(data, xp=np):
153157
return (
154158
data

xarray/core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
NUMPY_SAME_METHODS = ["item", "searchsorted"]
4343
# methods which don't modify the data shape, so the result should still be
4444
# wrapped in an Variable/DataArray
45-
NUMPY_UNARY_METHODS = ["astype", "argsort", "clip", "conj", "conjugate"]
45+
NUMPY_UNARY_METHODS = ["argsort", "clip", "conj", "conjugate"]
4646
PANDAS_UNARY_FUNCTIONS = ["isnull", "notnull"]
4747
# methods which remove an axis
4848
REDUCE_METHODS = ["all", "any"]

xarray/core/variable.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,37 @@ def data(self, data):
360360
)
361361
self._data = data
362362

363+
def astype(self, dtype, casting="unsafe", copy=True):
    """
    Copy of the Variable object, with data cast to a specified type.

    The variable this method is called on is left untouched; the cast data
    is returned wrapped in a new object, with attributes preserved.

    Parameters
    ----------
    dtype : str or dtype
        Typecode or data-type to which the array is cast.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Defaults to 'unsafe'
        for backwards compatibility.

        * 'no' means the data types should not be cast at all.
        * 'equiv' means only byte-order changes are allowed.
        * 'safe' means only casts which can preserve values are allowed.
        * 'same_kind' means only safe casts or casts within a kind,
          like float64 to float32, are allowed.
        * 'unsafe' means any data conversions may be done.
    copy : bool, optional
        By default, astype always returns a newly allocated array. If this
        is set to False and the `dtype` requirement is satisfied, the input
        array is returned instead of a copy.

    Returns
    -------
    out : same type as caller
        New object with data cast to the specified type.

    See Also
    --------
    np.ndarray.astype
    dask.array.Array.astype
    """
    # Imported locally, mirroring the astype defined in common.py.
    from .computation import apply_ufunc

    # Previously the cast result was assigned to ``self.data`` and ``self``
    # was returned, silently mutating the caller's variable despite the
    # documented "copy" contract. Going through apply_ufunc builds a new
    # object instead and, with keep_attrs=True, carries the attrs across.
    return apply_ufunc(
        duck_array_ops.astype,
        self,
        dtype,
        kwargs={"casting": casting, "copy": copy},
        keep_attrs=True,
        dask="allowed",
    )
393+
363394
def load(self, **kwargs):
364395
"""Manually trigger loading of this variable's data from disk or a
365396
remote source into memory and return this variable.

xarray/tests/test_dataarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,6 +1874,13 @@ def test_array_interface(self):
18741874
bar = Variable(["x", "y"], np.zeros((10, 20)))
18751875
assert_equal(self.dv, np.maximum(self.dv, bar))
18761876

1877+
def test_astype_attrs(self):
1878+
mda1 = self.mda.copy()
1879+
mda1.attrs["foo"] = "bar"
1880+
mda2 = mda1.astype(bool)
1881+
1882+
assert list(mda1.attrs.items()) == list(mda2.attrs.items())
1883+
18771884
def test_is_null(self):
18781885
x = np.random.RandomState(42).randn(5, 6)
18791886
x[x < 0] = np.nan

xarray/tests/test_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5607,6 +5607,14 @@ def test_pad(self):
56075607
np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
56085608
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
56095609

5610+
def test_astype_attrs(self):
5611+
data = create_test_data(seed=123)
5612+
data.attrs["foo"] = "bar"
5613+
databool = data.astype(bool)
5614+
5615+
assert list(data.attrs.items()) == list(databool.attrs.items())
5616+
assert list(data.var1.attrs.items()) == list(databool.var1.attrs.items())
5617+
56105618

56115619
# Py.test tests
56125620

0 commit comments

Comments
 (0)