Skip to content

Commit fb691b3

Browse files
ScalarArray to xr.DataArray Conversions (#113)
Co-authored-by: Roy Smart <roytsmart@gmail.com>
1 parent ca315c5 commit fb691b3

File tree

4 files changed

+27
-0
lines changed

4 files changed

+27
-0
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
'ndfilters': ('https://ndfilters.readthedocs.io/en/stable', None),
123123
'colorsynth': ('https://colorsynth.readthedocs.io/en/stable', None),
124124
'regridding': ('https://regridding.readthedocs.io/en/stable', None),
125+
'xarray': ('https://docs.xarray.dev/en/stable/', None),
125126
}
126127

127128
# plt.Axes.__module__ = matplotlib.axes.__name__

named_arrays/_scalars/scalars.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import scipy.ndimage
99
import astropy.units as u
1010
import named_arrays as na
11+
import xarray as xr
1112

1213
__all__ = [
1314
'ScalarStartT',
@@ -166,6 +167,12 @@ def ndarray(self: Self) -> bool | int | float | complex | str | np.ndarray | u.Q
166167
This is usually an instance of :class:`numpy.ndarray` or :class:`astropy.units.Quantity`, but it can also be a
167168
built-in python type such as a :class:`int`, :class:`float`, or :class:`bool`
168169
"""
170+
@property
171+
def to_xarray(self: Self) -> xr.DataArray:
172+
"""
173+
Cast :class:`ScalarArray` to a :class:`xarray.DataArray`. Useful for saving to netcdf/zarr or using Dask.
174+
"""
175+
return xr.DataArray(self.ndarray, dims=self.axes)
169176

170177
@property
171178
def value(self: Self) -> ScalarArray:
@@ -819,6 +826,13 @@ def from_scalar_array(
819826

820827
return self
821828

829+
@classmethod
830+
def from_xarray(cls, dataarray: xr.DataArray) -> Self:
831+
"""
832+
Convert :class:`xarray.DataArray` into a :class:`ScalarArray`
833+
"""
834+
return cls(dataarray.data, axes=dataarray.dims)
835+
822836
@classmethod
823837
def empty(cls: Type[Self], shape: dict[str, int], dtype: Type | np.dtype = float) -> Self:
824838
"""

named_arrays/_scalars/tests/test_scalars.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,17 @@ class AbstractTestAbstractScalarArray(
342342
def test_ndarray(self, array: na.AbstractScalarArray):
343343
assert isinstance(array.ndarray, (int, float, complex, str, np.ndarray))
344344

345+
def test_to_xarray(self, array: na.AbstractScalarArray):
346+
xarray = array.to_xarray
347+
assert np.all(xarray.data == array.ndarray)
348+
assert set(xarray.dims) == set(array.axes)
349+
350+
def test_from_xarray(self, array: na.AbstractScalarArray):
351+
xarray = array.to_xarray
352+
scalar_array = na.ScalarArray.from_xarray(xarray)
353+
assert np.all(scalar_array.ndarray == array.ndarray)
354+
assert set(scalar_array.axes) == set(array.axes)
355+
345356
def test_axes(self, array: na.AbstractScalarArray):
346357
super().test_axes(array)
347358
assert len(array.axes) == np.ndim(array.ndarray)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ dependencies = [
2424
"ndfilters==0.3.0",
2525
"colorsynth==0.1.5",
2626
"regridding==0.1.1",
27+
"xarray",
2728
]
2829
dynamic = ["version"]
2930

0 commit comments

Comments
 (0)