Skip to content
forked from pydata/xarray

Commit

Permalink
Faster encoding functions. (pydata#8565)
Browse files Browse the repository at this point in the history
* Faster ensure_not_multiindex

* Better check

* Fix test and add typing

* Optimize string encoding a bit.
  • Loading branch information
dcherian authored Jan 4, 2024
1 parent 693f0b9 commit 5f1f78f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
20 changes: 12 additions & 8 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,11 @@ class EncodedStringCoder(VariableCoder):
def __init__(self, allows_unicode=True):
self.allows_unicode = allows_unicode

def encode(self, variable, name=None):
def encode(self, variable: Variable, name=None) -> Variable:
dims, data, attrs, encoding = unpack_for_encoding(variable)

contains_unicode = is_unicode_dtype(data.dtype)
encode_as_char = encoding.get("dtype") == "S1"

if encode_as_char:
del encoding["dtype"] # no longer relevant

Expand All @@ -69,9 +68,12 @@ def encode(self, variable, name=None):
# TODO: figure out how to handle this in a lazy way with dask
data = encode_string_array(data, string_encoding)

return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding)
else:
variable.encoding = encoding
return variable

def decode(self, variable, name=None):
def decode(self, variable: Variable, name=None) -> Variable:
dims, data, attrs, encoding = unpack_for_decoding(variable)

if "_Encoding" in attrs:
Expand All @@ -95,13 +97,15 @@ def encode_string_array(string_array, encoding="utf-8"):
return np.array(encoded, dtype=bytes).reshape(string_array.shape)


def ensure_fixed_length_bytes(var):
def ensure_fixed_length_bytes(var: Variable) -> Variable:
"""Ensure that a variable with vlen bytes is converted to fixed width."""
dims, data, attrs, encoding = unpack_for_encoding(var)
if check_vlen_dtype(data.dtype) == bytes:
if check_vlen_dtype(var.dtype) == bytes:
dims, data, attrs, encoding = unpack_for_encoding(var)
# TODO: figure out how to handle this with dask
data = np.asarray(data, dtype=np.bytes_)
return Variable(dims, data, attrs, encoding)
return Variable(dims, data, attrs, encoding)
else:
return var


class CharacterArrayCoder(VariableCoder):
Expand Down
16 changes: 9 additions & 7 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import IndexVariable, Variable
from xarray.core.variable import Variable

CF_RELATED_DATA = (
"bounds",
Expand Down Expand Up @@ -97,10 +97,10 @@ def _infer_dtype(array, name=None):


def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
raise NotImplementedError(
f"variable {name!r} is a MultiIndex, which cannot yet be "
"serialized to netCDF files. Instead, either use reset_index() "
"serialized. Instead, either use reset_index() "
"to convert MultiIndex levels into coordinate variables instead "
"or use https://cf-xarray.readthedocs.io/en/latest/coding.html."
)
Expand Down Expand Up @@ -647,7 +647,9 @@ def cf_decoder(
return variables, attributes


def _encode_coordinates(variables, attributes, non_dim_coord_names):
def _encode_coordinates(
variables: T_Variables, attributes: T_Attrs, non_dim_coord_names
):
# calculate global and variable specific coordinates
non_dim_coord_names = set(non_dim_coord_names)

Expand Down Expand Up @@ -675,7 +677,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
variable_coordinates[k].add(coord_name)

if any(
attr_name in v.encoding and coord_name in v.encoding.get(attr_name)
coord_name in v.encoding.get(attr_name, tuple())
for attr_name in CF_RELATED_DATA
):
not_technically_coordinates.add(coord_name)
Expand Down Expand Up @@ -742,7 +744,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
return variables, attributes


def encode_dataset_coordinates(dataset):
def encode_dataset_coordinates(dataset: Dataset):
"""Encode coordinates on the given dataset object into variable specific
and global attributes.
Expand All @@ -764,7 +766,7 @@ def encode_dataset_coordinates(dataset):
)


def cf_encoder(variables, attributes):
def cf_encoder(variables: T_Variables, attributes: T_Attrs):
"""
Encode a set of CF encoded variables and attributes.
Takes a dicts of variables and attributes and encodes them
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,15 +733,15 @@ def test_encode_time_bounds() -> None:

# if time_bounds attrs are same as time attrs, it doesn't matter
ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"}
encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
assert_equal(encoded["time_bounds"], expected["time_bounds"])
assert "calendar" not in encoded["time_bounds"].attrs
assert "units" not in encoded["time_bounds"].attrs

# for CF-noncompliant case of time_bounds attrs being different from
# time attrs; preserve them for faithful roundtrip
ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"}
encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
with pytest.raises(AssertionError):
assert_equal(encoded["time_bounds"], expected["time_bounds"])
assert "calendar" not in encoded["time_bounds"].attrs
Expand Down

0 comments on commit 5f1f78f

Please sign in to comment.