Skip to content

Add xarray-specific encoding convention for pd.IntervalArray #10483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,53 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
    # Decoding is not implemented for this coder: the method unconditionally
    # raises NotImplementedError regardless of the arguments passed.
    raise NotImplementedError()


class IntervalCoder(VariableCoder):
    """Xarray-specific coder to round-trip 1D ``pd.IntervalArray`` objects.

    Encoding stacks the left and right interval edges along a new leading
    dimension named ``encoded_bounds_dim`` and records the ``closed``,
    ``dtype`` and ``bounds_dim`` attributes on the result. Decoding rebuilds
    the ``pd.arrays.IntervalArray`` from those two rows and attributes.

    Variables that are not interval-typed (on encode) or that do not carry
    the encoding markers (on decode) are passed through unchanged.
    """

    # Marker written to ``attrs["dtype"]`` to flag an encoded interval variable.
    encoded_dtype = "pandas_interval"
    # Name of the extra dimension that holds the (left, right) edges.
    encoded_bounds_dim = "__xarray_bounds__"

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Encode an interval-typed variable as a stacked edges array.

        Parameters
        ----------
        variable : Variable
            Variable to encode.
        name : optional
            Variable name, forwarded to ``safe_setitem`` when setting attrs.

        Returns
        -------
        Variable
            A new variable with ``encoded_bounds_dim`` prepended to the
            dimensions when ``variable.dtype`` is a ``pd.IntervalDtype``;
            otherwise the input ``variable`` itself.
        """
        dtype = variable.dtype
        if not isinstance(dtype, pd.IntervalDtype):
            # Not an interval variable: hand back the input instance
            # (previously this branch returned the ``Variable`` class).
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)

        # Row 0 holds the left edges, row 1 the right edges.
        new_data = np.stack([data.left, data.right], axis=0)
        dims = (self.encoded_bounds_dim, *dims)
        safe_setitem(attrs, "closed", dtype.closed, name=name)
        safe_setitem(attrs, "dtype", self.encoded_dtype, name=name)
        safe_setitem(attrs, "bounds_dim", self.encoded_bounds_dim, name=name)
        return Variable(dims, new_data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Decode a stacked edges array back into an interval variable.

        Parameters
        ----------
        variable : Variable
            Variable to decode.
        name : optional
            Variable name, used in the error message and forwarded to
            ``pop_to``.

        Returns
        -------
        Variable
            A variable backed by a ``pd.arrays.IntervalArray`` when
            ``variable`` has ``attrs["dtype"] == encoded_dtype`` and contains
            ``encoded_bounds_dim``; otherwise the input ``variable`` itself.

        Raises
        ------
        ValueError
            If the variable is marked as an encoded interval but does not
            have exactly two dimensions.
        """
        bounds_dim = self.encoded_bounds_dim
        if (
            variable.attrs.get("dtype", None) != self.encoded_dtype
            or bounds_dim not in variable.dims
        ):
            # Not an encoded interval: hand back the input instance
            # (previously this branch returned the ``Variable`` class).
            return variable

        if variable.ndim != 2:
            raise ValueError(
                f"Cannot decode intervals for variable named {name!r} with more than two dimensions."
            )

        _dims, _data, attrs, encoding = unpack_for_decoding(variable)
        # Move the bookkeeping attributes from attrs into encoding.
        pop_to(attrs, encoding, "dtype", name=name)
        pop_to(attrs, encoding, "bounds_dim", name=name)
        closed = pop_to(attrs, encoding, "closed", name=name)

        # Identify the remaining dimension by name rather than by position,
        # consistent with the name-based ``isel`` selection below.
        new_dims = tuple(dim for dim in variable.dims if dim != bounds_dim)

        variable = variable.load()
        new_data = pd.arrays.IntervalArray.from_arrays(
            variable.isel({bounds_dim: 0}).data,
            variable.isel({bounds_dim: 1}).data,
            closed=closed,
        )
        return Variable(dims=new_dims, data=new_data, attrs=attrs, encoding=encoding)
5 changes: 5 additions & 0 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def encode_cf_variable(
ensure_not_multiindex(var, name=name)

for coder in [
# IntervalCoder must be before CFDatetimeCoder,
# so we can first encode the interval, then datetimes if necessary
variables.IntervalCoder(),
CFDatetimeCoder(),
CFTimedeltaCoder(),
variables.CFScaleOffsetCoder(),
Expand Down Expand Up @@ -238,6 +241,8 @@ def decode_cf_variable(
)
var = decode_times.decode(var, name=name)

var = variables.IntervalCoder().decode(var)

if decode_endianness and not var.dtype.isnative:
var = variables.EndianCoder().decode(var)
original_dtype = var.dtype
Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding) -> None:
original = xr.Variable(("x",), data, encoding=encoding)
encoded = encode_cf_variable(original)

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

test_CFMaskCoder_encode_missing_fill_values_conflict[times-with-dtype] AttributeError: 'property' object has no attribute 'dtype'

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

test_CFMaskCoder_encode_missing_fill_values_conflict[numeric-without-dtype] AttributeError: 'property' object has no attribute 'dtype'

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

test_CFMaskCoder_encode_missing_fill_values_conflict[numeric-with-dtype] AttributeError: 'property' object has no attribute 'dtype'

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-min-and-scipy

test_CFMaskCoder_encode_missing_fill_values_conflict[times-with-dtype] AttributeError: 'property' object has no attribute 'dtype'

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-min-and-scipy

test_CFMaskCoder_encode_missing_fill_values_conflict[numeric-without-dtype] AttributeError: 'property' object has no attribute 'dtype'

Check failure on line 49 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-min-and-scipy

test_CFMaskCoder_encode_missing_fill_values_conflict[numeric-with-dtype] AttributeError: 'property' object has no attribute 'dtype'

assert encoded.dtype == encoded.attrs["missing_value"].dtype
assert encoded.dtype == encoded.attrs["_FillValue"].dtype
Expand All @@ -63,7 +63,7 @@
)
expected.attrs["missing_value"] = -9999

decoded = xr.decode_cf(expected.to_dataset())

Check failure on line 66 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-minimum

test_CFMaskCoder_missing_value AttributeError: Failed to decode variable 'tmpk': 'property' object has no attribute 'isnative'

Check failure on line 66 in xarray/tests/test_coding.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 bare-min-and-scipy

test_CFMaskCoder_missing_value AttributeError: Failed to decode variable 'tmpk': 'property' object has no attribute 'isnative'
encoded, _ = xr.conventions.cf_encoder(decoded.variables, decoded.attrs)

assert_equal(encoded["tmpk"], expected.variable)
Expand Down Expand Up @@ -147,3 +147,31 @@
decoded = coder.decode(encoded)
assert decoded.dtype == signed_dtype
assert decoded.values == original_values


@pytest.mark.parametrize(
    "data",
    [
        [1, 2, 3, 4],
        np.array([1, 2, 3, 4], dtype=float),
        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
    ],
)
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_roundtrip_pandas_interval(data, closed) -> None:
    """Check IntervalCoder's encoded layout and that decoding restores the input."""
    intervals = pd.IntervalIndex.from_breaks(data, closed=closed)
    original = xr.Variable("time", intervals)
    coder = variables.IntervalCoder()
    encoded = coder.encode(original)

    # Expected layout: left edges in row 0, right edges in row 1.
    # Review note (PR author): could make this trailing dimension.
    left_edges = data[:-1]
    right_edges = data[1:]
    expected_attrs = {
        "dtype": "pandas_interval",
        "bounds_dim": "__xarray_bounds__",
        "closed": closed,
    }
    expected = xr.Variable(
        dims=("__xarray_bounds__", "time"),
        data=np.stack([left_edges, right_edges], axis=0),
        attrs=expected_attrs,
    )
    assert_identical(encoded, expected)

    # Decoding the encoded variable must give back the original one.
    roundtripped = coder.decode(encoded)
    assert_identical(roundtripped, original)
20 changes: 20 additions & 0 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,23 @@ def test_decode_cf_variables_decode_timedelta_warning() -> None:

with pytest.warns(FutureWarning, match="decode_timedelta"):
conventions.decode_cf_variables(variables, {})


@pytest.mark.parametrize(
    "data",
    [
        [1, 2, 3, 4],
        np.array([1, 2, 3, 4], dtype=float),
        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
    ],
)
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_roundtrip_pandas_interval(data, closed) -> None:
    """Round-trip an interval variable through encode_cf_variable / decode_cf_variable."""
    intervals = pd.IntervalIndex.from_breaks(data, closed=closed)
    original = Variable("time", intervals)

    encoded = conventions.encode_cf_variable(original)
    is_datetime_input = isinstance(data, pd.DatetimeIndex)
    if is_datetime_input:
        # make sure we've encoded datetimes.
        for time_attr in ("units", "calendar"):
            assert time_attr in encoded.attrs

    decoded = conventions.decode_cf_variable("foo", encoded)
    assert_identical(decoded, original)
Loading