Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid duplicate Zarr array read #8472

Merged
merged 5 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,12 @@ def encode_zarr_attr_value(value):


class ZarrArrayWrapper(BackendArray):
__slots__ = ("datastore", "dtype", "shape", "variable_name", "_array")

def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
__slots__ = ("dtype", "shape", "_array")

def __init__(self, zarr_array):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
self._array = self.datastore.zarr_group[self.variable_name]
self._array = zarr_array
self.shape = self._array.shape

# preserve vlen string object dtype (GH 7328)
Expand All @@ -86,10 +83,10 @@ def get_array(self):
return self._array

def _oindex(self, key):
return self.get_array().oindex[key]
return self._array.oindex[key]

def __getitem__(self, key):
array = self.get_array()
array = self._array
if isinstance(key, indexing.BasicIndexer):
return array[key.tuple]
elif isinstance(key, indexing.VectorizedIndexer):
Expand Down Expand Up @@ -506,7 +503,7 @@ def ds(self):
return self.zarr_group

def open_store_variable(self, name, zarr_array):
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self))
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
zarr_array, DIMENSION_KEY, try_nczarr
Expand Down
78 changes: 37 additions & 41 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2836,6 +2836,43 @@ def test_write_empty(
ls = listdir(os.path.join(store, "test"))
assert set(expected) == set([file for file in ls if file[0] != "."])

def test_avoid_excess_metadata_calls(self) -> None:
dcherian marked this conversation as resolved.
Show resolved Hide resolved
"""Test that chunk requests do not trigger redundant metadata requests.

This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
to retrieve chunk data after initialization do not trigger additional
metadata requests.

https://github.com/pydata/xarray/issues/8290
"""

import zarr

ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})

# The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
# so that we can inspect calls to this method - specifically count of calls.
# Use of side_effect means that calls are passed through to the original method
# rather than a mocked method.
Group = zarr.hierarchy.Group
with (
self.create_zarr_target() as store,
patch.object(
Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
) as mock,
):
ds.to_zarr(store, mode="w")

# We expect this to request array metadata information, so call_count should be == 1,
xrds = xr.open_zarr(store)
call_count = mock.call_count
assert call_count == 1
dcherian marked this conversation as resolved.
Show resolved Hide resolved

# compute() requests array data, which should not trigger additional metadata requests
# we assert that the number of calls has not increased after fetchhing the array
xrds.test.compute(scheduler="sync")
assert mock.call_count == call_count


class ZarrBaseV3(ZarrBase):
zarr_version = 3
Expand Down Expand Up @@ -2876,47 +2913,6 @@ def create_zarr_target(self):
yield tmp


@requires_zarr
class TestZarrArrayWrapperCalls(TestZarrKVStoreV3):
def test_avoid_excess_metadata_calls(self) -> None:
"""Test that chunk requests do not trigger redundant metadata requests.

This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
to retrieve chunk data after initialization do not trigger additional
metadata requests.

https://github.com/pydata/xarray/issues/8290
"""

import zarr

ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})

# The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
# so that we can inspect calls to this method - specifically count of calls.
# Use of side_effect means that calls are passed through to the original method
# rather than a mocked method.
Group = zarr.hierarchy.Group
with (
self.create_zarr_target() as store,
patch.object(
Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
) as mock,
):
ds.to_zarr(store, mode="w")

# We expect this to request array metadata information, so call_count should be >= 1,
# At time of writing, 2 calls are made
xrds = xr.open_zarr(store)
call_count = mock.call_count
assert call_count > 0

# compute() requests array data, which should not trigger additional metadata requests
# we assert that the number of calls has not increased after fetchhing the array
xrds.test.compute(scheduler="sync")
assert mock.call_count == call_count


@requires_zarr
@requires_fsspec
def test_zarr_storage_options() -> None:
Expand Down
Loading