Skip to content

Commit

Permalink
Avoid duplicate Zarr array read (#8472)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
  • Loading branch information
dcherian and andersy005 authored Dec 1, 2023
1 parent b313ffc commit 1715ed3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 50 deletions.
15 changes: 6 additions & 9 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,12 @@ def encode_zarr_attr_value(value):


class ZarrArrayWrapper(BackendArray):
__slots__ = ("datastore", "dtype", "shape", "variable_name", "_array")

def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
__slots__ = ("dtype", "shape", "_array")

def __init__(self, zarr_array):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
self._array = self.datastore.zarr_group[self.variable_name]
self._array = zarr_array
self.shape = self._array.shape

# preserve vlen string object dtype (GH 7328)
Expand All @@ -86,10 +83,10 @@ def get_array(self):
return self._array

def _oindex(self, key):
return self.get_array().oindex[key]
return self._array.oindex[key]

def __getitem__(self, key):
array = self.get_array()
array = self._array
if isinstance(key, indexing.BasicIndexer):
return array[key.tuple]
elif isinstance(key, indexing.VectorizedIndexer):
Expand Down Expand Up @@ -506,7 +503,7 @@ def ds(self):
return self.zarr_group

def open_store_variable(self, name, zarr_array):
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self))
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
zarr_array, DIMENSION_KEY, try_nczarr
Expand Down
78 changes: 37 additions & 41 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2836,6 +2836,43 @@ def test_write_empty(
ls = listdir(os.path.join(store, "test"))
assert set(expected) == set([file for file in ls if file[0] != "."])

def test_avoid_excess_metadata_calls(self) -> None:
"""Test that chunk requests do not trigger redundant metadata requests.

This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
to retrieve chunk data after initialization do not trigger additional
metadata requests.

See https://github.com/pydata/xarray/issues/8290
"""

import zarr

# Smallest useful fixture: one variable "test" holding a single value along dimension "Z".
ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})

# The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
# so that we can inspect calls to this method - specifically the count of calls.
# Use of side_effect means that calls are passed through to the original method
# rather than a mocked method.
# NOTE(review): autospec=True is presumably what lets the instance be forwarded to
# the original unbound method - confirm against the unittest.mock documentation.
Group = zarr.hierarchy.Group
with (
self.create_zarr_target() as store,
patch.object(
Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
) as mock,
):
ds.to_zarr(store, mode="w")

# We expect opening the store to request array metadata exactly once,
# so call_count should be == 1.
xrds = xr.open_zarr(store)
call_count = mock.call_count
assert call_count == 1

# compute() requests array data, which should not trigger additional metadata requests.
# We assert that the number of calls has not increased after fetching the array.
xrds.test.compute(scheduler="sync")
assert mock.call_count == call_count


class ZarrBaseV3(ZarrBase):
zarr_version = 3
Expand Down Expand Up @@ -2876,47 +2913,6 @@ def create_zarr_target(self):
yield tmp


@requires_zarr
class TestZarrArrayWrapperCalls(TestZarrKVStoreV3):
def test_avoid_excess_metadata_calls(self) -> None:
"""Test that chunk requests do not trigger redundant metadata requests.

This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls
to retrieve chunk data after initialization do not trigger additional
metadata requests.

See https://github.com/pydata/xarray/issues/8290
"""

import zarr

# Smallest useful fixture: one variable "test" holding a single value along dimension "Z".
ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))})

# The call to retrieve metadata performs a group lookup. We patch Group.__getitem__
# so that we can inspect calls to this method - specifically the count of calls.
# Use of side_effect means that calls are passed through to the original method
# rather than a mocked method.
Group = zarr.hierarchy.Group
with (
self.create_zarr_target() as store,
patch.object(
Group, "__getitem__", side_effect=Group.__getitem__, autospec=True
) as mock,
):
ds.to_zarr(store, mode="w")

# We expect this to request array metadata information, so call_count should be >= 1.
# At time of writing, 2 calls are made.
xrds = xr.open_zarr(store)
call_count = mock.call_count
assert call_count > 0

# compute() requests array data, which should not trigger additional metadata requests.
# We assert that the number of calls has not increased after fetching the array.
xrds.test.compute(scheduler="sync")
assert mock.call_count == call_count


@requires_zarr
@requires_fsspec
def test_zarr_storage_options() -> None:
Expand Down

0 comments on commit 1715ed3

Please sign in to comment.