Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identify non dimension coords #156

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/releases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Bug fixes
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Ensure that `.attrs` on coordinate variables are preserved during round-tripping. (:issue:`155`, :pull:`154`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Ensure that non-dimension coordinate variables described via the CF conventions are preserved during round-tripping. (:issue:`105`, :pull:`156`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
9 changes: 8 additions & 1 deletion virtualizarr/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,18 @@ def dataset_to_kerchunk_refs(ds: xr.Dataset) -> KerchunkStoreRefs:

all_arr_refs.update(prepended_with_var_name)

zattrs = ds.attrs
if ds.coords:
coord_names = list(ds.coords)
# this weird concatenated string instead of a list of strings is inconsistent with how other features in the kerchunk references format are stored
# see https://github.com/zarr-developers/VirtualiZarr/issues/105#issuecomment-2187266739
zattrs["coordinates"] = " ".join(coord_names)

ds_refs = {
"version": 1,
"refs": {
".zgroup": '{"zarr_format":2}',
".zattrs": ujson.dumps(ds.attrs),
".zattrs": ujson.dumps(zattrs),
**all_arr_refs,
},
}
Expand Down
26 changes: 26 additions & 0 deletions virtualizarr/tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import xarray as xr
import xarray.testing as xrt
Expand Down Expand Up @@ -95,6 +96,31 @@ def test_kerchunk_roundtrip_concat(self, tmpdir, format):
# assert identical to original dataset
xrt.assert_identical(roundtrip, ds)

def test_non_dimension_coordinates(self, tmpdir, format):
# regression test for GH issue #105

# set up example xarray dataset containing non-dimension coordinate variables
ds = xr.Dataset(coords={"lat": (["x", "y"], np.arange(6).reshape(2, 3))})

# save it to disk as netCDF (in temporary directory)
ds.to_netcdf(f"{tmpdir}/non_dim_coords.nc")

vds = open_virtual_dataset(f"{tmpdir}/non_dim_coords.nc", indexes={})

assert "lat" in vds.coords
assert "coordinates" not in vds.attrs

# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = xr.open_dataset(
f"{tmpdir}/refs.{format}", engine="kerchunk", decode_times=False
)

# assert equal to original dataset
xrt.assert_identical(roundtrip, ds)


def test_open_scalar_variable(tmpdir):
# regression test for GH issue #100
Expand Down
21 changes: 13 additions & 8 deletions virtualizarr/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def open_virtual_dataset(
virtual_array_class=virtual_array_class,
)
ds_attrs = kerchunk.fully_decode_arr_refs(vds_refs["refs"]).get(".zattrs", {})
coord_names = ds_attrs.pop("coordinates", [])

if indexes is None or len(loadable_variables) > 0:
# TODO we are reading a bunch of stuff we know we won't need here, e.g. all of the data variables...
Expand Down Expand Up @@ -152,7 +153,7 @@ def open_virtual_dataset(

vars = {**virtual_vars, **loadable_vars}

data_vars, coords = separate_coords(vars, indexes)
data_vars, coords = separate_coords(vars, indexes, coord_names)

vds = xr.Dataset(
data_vars,
Expand All @@ -177,6 +178,7 @@ def open_virtual_dataset_from_v3_store(
_storepath = Path(storepath)

ds_attrs = attrs_from_zarr_group_json(_storepath / "zarr.json")
coord_names = ds_attrs.pop("coordinates", [])

# TODO recursive glob to create a datatree
# Note: this .is_file() check should not be necessary according to the pathlib docs, but tests fail on github CI without it
Expand Down Expand Up @@ -205,7 +207,7 @@ def open_virtual_dataset_from_v3_store(
else:
indexes = dict(**indexes) # for type hinting: to allow mutation

data_vars, coords = separate_coords(vars, indexes)
data_vars, coords = separate_coords(vars, indexes, coord_names)

ds = xr.Dataset(
data_vars,
Expand All @@ -223,8 +225,10 @@ def virtual_vars_from_kerchunk_refs(
virtual_array_class=ManifestArray,
) -> Mapping[str, xr.Variable]:
"""
Translate a store-level kerchunk reference dict into aa set of xarray Variables containing virtualized arrays.
Translate a store-level kerchunk reference dict into aaset of xarray Variables containing virtualized arrays.

Parameters
----------
drop_variables: list[str], default is None
Variables in the file to drop before returning.
virtual_array_class
Expand Down Expand Up @@ -263,12 +267,12 @@ def dataset_from_kerchunk_refs(
"""

vars = virtual_vars_from_kerchunk_refs(refs, drop_variables, virtual_array_class)
ds_attrs = kerchunk.fully_decode_arr_refs(refs["refs"]).get(".zattrs", {})
coord_names = ds_attrs.pop("coordinates", [])

if indexes is None:
indexes = {}
data_vars, coords = separate_coords(vars, indexes)

ds_attrs = kerchunk.fully_decode_arr_refs(refs["refs"]).get(".zattrs", {})
data_vars, coords = separate_coords(vars, indexes, coord_names)

vds = xr.Dataset(
data_vars,
Expand Down Expand Up @@ -301,6 +305,7 @@ def variable_from_kerchunk_refs(
def separate_coords(
vars: Mapping[str, xr.Variable],
indexes: MutableMapping[str, Index],
coord_names: Iterable[str] | None = None,
) -> tuple[Mapping[str, xr.Variable], xr.Coordinates]:
"""
Try to generate a set of coordinates that won't cause xarray to automatically build a pandas.Index for the 1D coordinates.
Expand All @@ -310,8 +315,8 @@ def separate_coords(
Will also preserve any loaded variables and indexes it is passed.
"""

# this would normally come from CF decoding, let's hope the fact we're skipping that doesn't cause any problems...
coord_names: list[str] = []
if coord_names is None:
coord_names = []

# split data and coordinate variables (promote dimension coordinates)
data_vars = {}
Expand Down
Loading