Skip to content

Fix critical np.timedelta64 encoding bugs #10469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ Bug fixes
(:pull:`10352`). By `Spencer Clark <https://github.com/spencerkclark>`_.
- Avoid unsafe casts from float to unsigned int in CFMaskCoder (:issue:`9815`, :pull:`9964`).
By ` Elliott Sales de Andrade <https://github.com/QuLogic>`_.
- Fix attribute overwriting bug when decoding encoded
:py:class:`numpy.timedelta64` values from disk with a dtype attribute
(:issue:`10468`, :pull:`10469`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
- Fix default ``"_FillValue"`` dtype coercion bug when encoding
:py:class:`numpy.timedelta64` values to an on-disk format that only supports
32-bit integers (:issue:`10466`, :pull:`10469`). By `Spencer Clark
<https://github.com/spencerkclark>`_.


Performance
~~~~~~~~~~~
Expand Down
154 changes: 67 additions & 87 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,43 @@ def has_timedelta64_encoding_dtype(attrs_or_encoding: dict) -> bool:
return isinstance(dtype, str) and dtype.startswith("timedelta64")


def resolve_time_unit_from_attrs_dtype(
attrs_dtype: str, name: T_Name
) -> PDDatetimeUnitOptions:
dtype = np.dtype(attrs_dtype)
resolution, _ = np.datetime_data(dtype)
resolution = cast(NPDatetimeUnitOptions, resolution)
if np.timedelta64(1, resolution) > np.timedelta64(1, "s"):
time_unit = cast(PDDatetimeUnitOptions, "s")
message = (
f"Following pandas, xarray only supports decoding to timedelta64 "
f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded "
f"values for variable {name!r} have a resolution of "
f"{resolution!r}. Attempting to decode to a resolution of 's'. "
f"Note, depending on the encoded values, this may lead to an "
f"OverflowError. Additionally, data will not be identically round "
f"tripped; xarray will choose an encoding dtype of "
f"'timedelta64[s]' when re-encoding."
)
emit_user_level_warning(message)
elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"):
time_unit = cast(PDDatetimeUnitOptions, "ns")
message = (
f"Following pandas, xarray only supports decoding to timedelta64 "
f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded "
f"values for variable {name!r} have a resolution of "
f"{resolution!r}. Attempting to decode to a resolution of 'ns'. "
f"Note, depending on the encoded values, this may lead to loss of "
f"precision. Additionally, data will not be identically round "
f"tripped; xarray will choose an encoding dtype of "
f"'timedelta64[ns]' when re-encoding."
)
emit_user_level_warning(message)
else:
time_unit = cast(PDDatetimeUnitOptions, resolution)
return time_unit


class CFTimedeltaCoder(VariableCoder):
"""Coder for CF Timedelta coding.
Expand All @@ -1430,7 +1467,7 @@ class CFTimedeltaCoder(VariableCoder):

def __init__(
self,
time_unit: PDDatetimeUnitOptions = "ns",
time_unit: PDDatetimeUnitOptions | None = None,
decode_via_units: bool = True,
decode_via_dtype: bool = True,
) -> None:
Expand All @@ -1442,45 +1479,18 @@ def __init__(
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)
has_timedelta_dtype = has_timedelta64_encoding_dtype(encoding)
if ("units" in encoding or "dtype" in encoding) and not has_timedelta_dtype:
dtype = encoding.get("dtype", None)
units = encoding.pop("units", None)
dtype = encoding.get("dtype", None)
units = encoding.pop("units", None)

# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
if "add_offset" in encoding or "scale_factor" in encoding:
dtype = data.dtype if data.dtype.kind == "f" else "float64"
# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
if "add_offset" in encoding or "scale_factor" in encoding:
dtype = data.dtype if data.dtype.kind == "f" else "float64"

else:
resolution, _ = np.datetime_data(variable.dtype)
dtype = np.int64
attrs_dtype = f"timedelta64[{resolution}]"
units = _numpy_dtype_to_netcdf_timeunit(variable.dtype)
safe_setitem(attrs, "dtype", attrs_dtype, name=name)
# Remove dtype encoding if it exists to prevent it from
# interfering downstream in NonStringCoder.
encoding.pop("dtype", None)

if any(
k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS
):
raise ValueError(
f"Specifying 'add_offset' or 'scale_factor' is not "
f"supported when encoding the timedelta64 values of "
f"variable {name!r} with xarray's new default "
f"timedelta64 encoding approach. To encode {name!r} "
f"with xarray's previous timedelta64 encoding "
f"approach, which supports the 'add_offset' and "
f"'scale_factor' parameters, additionally set "
f"encoding['units'] to a unit of time, e.g. "
f"'seconds'. To proceed with encoding of {name!r} "
f"via xarray's new approach, remove any encoding "
f"entries for 'add_offset' or 'scale_factor'."
)
if "_FillValue" not in encoding and "missing_value" not in encoding:
encoding["_FillValue"] = np.iinfo(np.int64).min
resolution, _ = np.datetime_data(variable.dtype)
attrs_dtype = f"timedelta64[{resolution}]"
safe_setitem(attrs, "dtype", attrs_dtype, name=name)
Comment on lines 1479 to +1493
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Spencer! This is much cleaner and easier to follow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the help discussing this. Early on I was enamored with the idea of strict "literal" encoding of timedelta64 values (i.e. simply converting to int64 and using units based solely on the resolution), but clearly we need to support more flexibility, which our existing approach of course allows!


data, units = encode_cf_timedelta(data, units, dtype)
safe_setitem(attrs, "units", units, name=name)
Expand All @@ -1499,54 +1509,13 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
):
dims, data, attrs, encoding = unpack_for_decoding(variable)
units = pop_to(attrs, encoding, "units")
if is_dtype_decodable and self.decode_via_dtype:
if any(
k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS
):
raise ValueError(
f"Decoding timedelta64 values via dtype is not "
f"supported when 'add_offset', or 'scale_factor' are "
f"present in encoding. Check the encoding parameters "
f"of variable {name!r}."
)
dtype = pop_to(attrs, encoding, "dtype", name=name)
dtype = np.dtype(dtype)
resolution, _ = np.datetime_data(dtype)
resolution = cast(NPDatetimeUnitOptions, resolution)
if np.timedelta64(1, resolution) > np.timedelta64(1, "s"):
time_unit = cast(PDDatetimeUnitOptions, "s")
dtype = np.dtype("timedelta64[s]")
message = (
f"Following pandas, xarray only supports decoding to "
f"timedelta64 values with a resolution of 's', 'ms', "
f"'us', or 'ns'. Encoded values for variable {name!r} "
f"have a resolution of {resolution!r}. Attempting to "
f"decode to a resolution of 's'. Note, depending on "
f"the encoded values, this may lead to an "
f"OverflowError. Additionally, data will not be "
f"identically round tripped; xarray will choose an "
f"encoding dtype of 'timedelta64[s]' when re-encoding."
)
emit_user_level_warning(message)
elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"):
time_unit = cast(PDDatetimeUnitOptions, "ns")
dtype = np.dtype("timedelta64[ns]")
message = (
f"Following pandas, xarray only supports decoding to "
f"timedelta64 values with a resolution of 's', 'ms', "
f"'us', or 'ns'. Encoded values for variable {name!r} "
f"have a resolution of {resolution!r}. Attempting to "
f"decode to a resolution of 'ns'. Note, depending on "
f"the encoded values, this may lead to loss of "
f"precision. Additionally, data will not be "
f"identically round tripped; xarray will choose an "
f"encoding dtype of 'timedelta64[ns]' "
f"when re-encoding."
)
emit_user_level_warning(message)
if is_dtype_decodable:
attrs_dtype = attrs.pop("dtype")
if self.time_unit is None:
time_unit = resolve_time_unit_from_attrs_dtype(attrs_dtype, name)
else:
time_unit = cast(PDDatetimeUnitOptions, resolution)
elif self.decode_via_units:
time_unit = self.time_unit
else:
if self._emit_decode_timedelta_future_warning:
emit_user_level_warning(
"In a future version, xarray will not decode "
Expand All @@ -1564,8 +1533,19 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
"'CFTimedeltaCoder' instance.",
FutureWarning,
)
dtype = np.dtype(f"timedelta64[{self.time_unit}]")
time_unit = self.time_unit
if self.time_unit is None:
time_unit = cast(PDDatetimeUnitOptions, "ns")
else:
time_unit = self.time_unit

# Handle edge case that decode_via_dtype=False and
# decode_via_units=True, and timedeltas were encoded with a
# dtype attribute. We need to remove the dtype attribute
# to prevent an error during round tripping.
if has_timedelta_dtype:
attrs.pop("dtype")

dtype = np.dtype(f"timedelta64[{time_unit}]")
transform = partial(decode_cf_timedelta, units=units, time_unit=time_unit)
data = lazy_elemwise_func(data, transform, dtype=dtype)
return Variable(dims, data, attrs, encoding, fastpath=True)
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from xarray.conventions import encode_dataset_coordinates
from xarray.core import indexing
from xarray.core.options import set_options
from xarray.core.types import PDDatetimeUnitOptions
from xarray.core.utils import module_available
from xarray.namedarray.pycompat import array_type
from xarray.tests import (
Expand Down Expand Up @@ -642,6 +643,16 @@ def test_roundtrip_timedelta_data(self) -> None:
) as actual:
assert_identical(expected, actual)

def test_roundtrip_timedelta_data_via_dtype(
self, time_unit: PDDatetimeUnitOptions
) -> None:
time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]).as_unit(time_unit) # type: ignore[arg-type, unused-ignore]
expected = Dataset(
{"td": ("td", time_deltas), "td0": time_deltas[0].to_numpy()}
)
with self.roundtrip(expected) as actual:
assert_identical(expected, actual)

def test_roundtrip_float64_data(self) -> None:
expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))})
with self.roundtrip(expected) as actual:
Expand Down
Loading
Loading