Skip to content

Commit b14f15f

Browse files
madsbkjakirkhamjoshmoore
authored
Getitems: support meta_array (#1131)
* Use _chunk_getitems() always * Implement getitems() always * FSStore.getitems(): accept meta_array and on_error * getitems(): handle on_error="omit" * Removed the `on_error argument` * remove redundant check * getitems(): use Sequence instead of Iterable * Typo Co-authored-by: Josh Moore <josh@openmicroscopy.org> * Introduce a contexts argument * CountingDict: impl. getitems() * added test_getitems() * Introduce Context * doc * support the new get_partial_values() method * Resolve conflict with get_partial_values() * make contexts keyword-only * Introduce ConstantMap * use typing.Mapping * test_constant_map --------- Co-authored-by: jakirkham <jakirkham@gmail.com> Co-authored-by: Josh Moore <josh@openmicroscopy.org>
1 parent 4b0705c commit b14f15f

File tree

9 files changed

+190
-70
lines changed

9 files changed

+190
-70
lines changed

zarr/_storage/store.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from zarr.meta import Metadata2, Metadata3
1010
from zarr.util import normalize_storage_path
11+
from zarr.context import Context
1112

1213
# v2 store keys
1314
array_meta_key = '.zarray'
@@ -131,6 +132,33 @@ def _ensure_store(store: Any):
131132
f"wrap it in Zarr.storage.KVStore. Got {store}"
132133
)
133134

135+
def getitems(
136+
self, keys: Sequence[str], *, contexts: Mapping[str, Context]
137+
) -> Mapping[str, Any]:
138+
"""Retrieve data from multiple keys.
139+
140+
Parameters
141+
----------
142+
keys : Iterable[str]
143+
The keys to retrieve
144+
contexts: Mapping[str, Context]
145+
A mapping of keys to their context. Each context is a mapping of store
146+
specific information. E.g. a context could be a dict telling the store
147+
the preferred output array type: `{"meta_array": cupy.empty(())}`
148+
149+
Returns
150+
-------
151+
Mapping
152+
A collection mapping the input keys to their results.
153+
154+
Notes
155+
-----
156+
This default implementation uses __getitem__() to read each key sequentially and
157+
ignores contexts. Overwrite this method to implement concurrent reads of multiple
158+
keys and/or to utilize the contexts.
159+
"""
160+
return {k: self[k] for k in keys if k in self}
161+
134162

135163
class Store(BaseStore):
136164
"""Abstract store class used by implementations following the Zarr v2 spec.

zarr/context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
from typing import TypedDict
3+
4+
from numcodecs.compat import NDArrayLike
5+
6+
7+
class Context(TypedDict, total=False):
8+
""" A context for component specific information
9+
10+
All keys are optional. Any component reading the context must provide
11+
a default implementation in the case a key cannot be found.
12+
13+
Items
14+
-----
15+
meta_array : array-like, optional
16+
An array-like instance to use for determining the preferred output
17+
array type.
18+
"""
19+
meta_array: NDArrayLike

zarr/core.py

Lines changed: 28 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
1414
from zarr.attrs import Attributes
1515
from zarr.codecs import AsType, get_codec
16+
from zarr.context import Context
1617
from zarr.errors import ArrayNotFoundError, ReadOnlyError, ArrayIndexError
1718
from zarr.indexing import (
1819
BasicIndexer,
@@ -41,6 +42,7 @@
4142
normalize_store_arg,
4243
)
4344
from zarr.util import (
45+
ConstantMap,
4446
all_equal,
4547
InfoReporter,
4648
check_array_shape,
@@ -1275,24 +1277,14 @@ def _get_selection(self, indexer, out=None, fields=None):
12751277
check_array_shape('out', out, out_shape)
12761278

12771279
# iterate over chunks
1278-
if (
1279-
not hasattr(self.chunk_store, "getitems") and not (
1280-
hasattr(self.chunk_store, "get_partial_values") and
1281-
self.chunk_store.supports_efficient_get_partial_values
1282-
)
1283-
) or any(map(lambda x: x == 0, self.shape)):
1284-
# sequentially get one key at a time from storage
1285-
for chunk_coords, chunk_selection, out_selection in indexer:
12861280

1287-
# load chunk selection into output array
1288-
self._chunk_getitem(chunk_coords, chunk_selection, out, out_selection,
1289-
drop_axes=indexer.drop_axes, fields=fields)
1290-
else:
1281+
if math.prod(out_shape) > 0:
12911282
# allow storage to get multiple items at once
12921283
lchunk_coords, lchunk_selection, lout_selection = zip(*indexer)
1293-
self._chunk_getitems(lchunk_coords, lchunk_selection, out, lout_selection,
1294-
drop_axes=indexer.drop_axes, fields=fields)
1295-
1284+
self._chunk_getitems(
1285+
lchunk_coords, lchunk_selection, out, lout_selection,
1286+
drop_axes=indexer.drop_axes, fields=fields
1287+
)
12961288
if out.shape:
12971289
return out
12981290
else:
@@ -1963,68 +1955,36 @@ def _process_chunk(
19631955
# store selected data in output
19641956
out[out_selection] = tmp
19651957

1966-
def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection,
1967-
drop_axes=None, fields=None):
1968-
"""Obtain part or whole of a chunk.
1958+
def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
1959+
drop_axes=None, fields=None):
1960+
"""Obtain part or whole of chunks.
19691961
19701962
Parameters
19711963
----------
1972-
chunk_coords : tuple of ints
1973-
Indices of the chunk.
1974-
chunk_selection : selection
1975-
Location of region within the chunk to extract.
1964+
chunk_coords : list of tuple of ints
1965+
Indices of the chunks.
1966+
chunk_selection : list of selections
1967+
Location of region within the chunks to extract.
19761968
out : ndarray
19771969
Array to store result in.
1978-
out_selection : selection
1979-
Location of region within output array to store results in.
1970+
out_selection : list of selections
1971+
Location of regions within output array to store results in.
19801972
drop_axes : tuple of ints
19811973
Axes to squeeze out of the chunk.
19821974
fields
19831975
TODO
1984-
19851976
"""
1986-
out_is_ndarray = True
1987-
try:
1988-
out = ensure_ndarray_like(out)
1989-
except TypeError:
1990-
out_is_ndarray = False
1991-
1992-
assert len(chunk_coords) == len(self._cdata_shape)
1993-
1994-
# obtain key for chunk
1995-
ckey = self._chunk_key(chunk_coords)
19961977

1997-
try:
1998-
# obtain compressed data for chunk
1999-
cdata = self.chunk_store[ckey]
2000-
2001-
except KeyError:
2002-
# chunk not initialized
2003-
if self._fill_value is not None:
2004-
if fields:
2005-
fill_value = self._fill_value[fields]
2006-
else:
2007-
fill_value = self._fill_value
2008-
out[out_selection] = fill_value
2009-
2010-
else:
2011-
self._process_chunk(out, cdata, chunk_selection, drop_axes,
2012-
out_is_ndarray, fields, out_selection)
2013-
2014-
def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
2015-
drop_axes=None, fields=None):
2016-
"""As _chunk_getitem, but for lists of chunks
2017-
2018-
This gets called where the storage supports ``getitems``, so that
2019-
it can decide how to fetch the keys, allowing concurrency.
2020-
"""
20211978
out_is_ndarray = True
20221979
try:
20231980
out = ensure_ndarray_like(out)
20241981
except TypeError: # pragma: no cover
20251982
out_is_ndarray = False
20261983

1984+
# Keys to retrieve
20271985
ckeys = [self._chunk_key(ch) for ch in lchunk_coords]
1986+
1987+
# Check if we can do a partial read
20281988
if (
20291989
self._partial_decompress
20301990
and self._compressor
@@ -2056,13 +2016,17 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
20562016
for ckey in ckeys
20572017
if ckey in self.chunk_store
20582018
}
2019+
elif hasattr(self.chunk_store, "get_partial_values"):
2020+
partial_read_decode = False
2021+
values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
2022+
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
20592023
else:
20602024
partial_read_decode = False
2061-
if not hasattr(self.chunk_store, "getitems"):
2062-
values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
2063-
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
2064-
else:
2065-
cdatas = self.chunk_store.getitems(ckeys, on_error="omit")
2025+
contexts = {}
2026+
if not isinstance(self._meta_array, np.ndarray):
2027+
contexts = ConstantMap(ckeys, constant=Context(meta_array=self._meta_array))
2028+
cdatas = self.chunk_store.getitems(ckeys, contexts=contexts)
2029+
20662030
for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
20672031
if ckey in cdatas:
20682032
self._process_chunk(

zarr/storage.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from os import scandir
3232
from pickle import PicklingError
3333
from threading import Lock, RLock
34-
from typing import Optional, Union, List, Tuple, Dict, Any
34+
from typing import Sequence, Mapping, Optional, Union, List, Tuple, Dict, Any
3535
import uuid
3636
import time
3737

@@ -42,6 +42,7 @@
4242
ensure_contiguous_ndarray_like
4343
)
4444
from numcodecs.registry import codec_registry
45+
from zarr.context import Context
4546

4647
from zarr.errors import (
4748
MetadataError,
@@ -1380,7 +1381,10 @@ def _normalize_key(self, key):
13801381

13811382
return key.lower() if self.normalize_keys else key
13821383

1383-
def getitems(self, keys, **kwargs):
1384+
def getitems(
1385+
self, keys: Sequence[str], *, contexts: Mapping[str, Context]
1386+
) -> Mapping[str, Any]:
1387+
13841388
keys_transformed = [self._normalize_key(key) for key in keys]
13851389
results = self.map.getitems(keys_transformed, on_error="omit")
13861390
# The function calling this method may not recognize the transformed keys

zarr/tests/test_storage.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import zarr
2121
from zarr._storage.store import _get_hierarchy_metadata
2222
from zarr.codecs import BZ2, AsType, Blosc, Zlib
23+
from zarr.context import Context
2324
from zarr.convenience import consolidate_metadata
2425
from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataError
2526
from zarr.hierarchy import group
@@ -37,7 +38,7 @@
3738
from zarr.storage import FSStore, rename, listdir
3839
from zarr._storage.v3 import KVStoreV3
3940
from zarr.tests.util import CountingDict, have_fsspec, skip_test_env_var, abs_container, mktemp
40-
from zarr.util import json_dumps
41+
from zarr.util import ConstantMap, json_dumps
4142

4243

4344
@contextmanager
@@ -2584,3 +2585,35 @@ def test_meta_prefix_6853():
25842585

25852586
fixtures = group(store=DirectoryStore(str(fixture)))
25862587
assert list(fixtures.arrays())
2588+
2589+
2590+
def test_getitems_contexts():
2591+
2592+
class MyStore(CountingDict):
2593+
def __init__(self):
2594+
super().__init__()
2595+
self.last_contexts = None
2596+
2597+
def getitems(self, keys, *, contexts):
2598+
self.last_contexts = contexts
2599+
return super().getitems(keys, contexts=contexts)
2600+
2601+
store = MyStore()
2602+
z = zarr.create(shape=(10,), chunks=1, store=store)
2603+
2604+
# By default, not contexts are given to the store's getitems()
2605+
z[0]
2606+
assert len(store.last_contexts) == 0
2607+
2608+
# Setting a non-default meta_array, will create contexts for the store's getitems()
2609+
z._meta_array = "my_meta_array"
2610+
z[0]
2611+
assert store.last_contexts == {'0': {'meta_array': 'my_meta_array'}}
2612+
assert isinstance(store.last_contexts, ConstantMap)
2613+
# Accseeing different chunks should trigger different key request
2614+
z[1]
2615+
assert store.last_contexts == {'1': {'meta_array': 'my_meta_array'}}
2616+
assert isinstance(store.last_contexts, ConstantMap)
2617+
z[2:4]
2618+
assert store.last_contexts == ConstantMap(['2', '3'], Context({'meta_array': 'my_meta_array'}))
2619+
assert isinstance(store.last_contexts, ConstantMap)

zarr/tests/test_storage_v3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,8 @@ def _get_public_and_dunder_methods(some_class):
666666
def test_storage_transformer_interface():
667667
store_v3_methods = _get_public_and_dunder_methods(StoreV3)
668668
store_v3_methods.discard("__init__")
669+
# Note, getitems() isn't mandatory when get_partial_values() is available
670+
store_v3_methods.discard("getitems")
669671
storage_transformer_methods = _get_public_and_dunder_methods(StorageTransformer)
670672
storage_transformer_methods.discard("__init__")
671673
storage_transformer_methods.discard("get_config")

zarr/tests/test_util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from zarr.core import Array
8-
from zarr.util import (all_equal, flatten, guess_chunks, human_readable_size,
8+
from zarr.util import (ConstantMap, all_equal, flatten, guess_chunks, human_readable_size,
99
info_html_report, info_text_report, is_total_slice,
1010
json_dumps, normalize_chunks,
1111
normalize_dimension_separator,
@@ -248,3 +248,16 @@ def test_json_dumps_numpy_dtype():
248248
# Check that we raise the error of the superclass for unsupported object
249249
with pytest.raises(TypeError):
250250
json_dumps(Array)
251+
252+
253+
def test_constant_map():
254+
val = object()
255+
m = ConstantMap(keys=[1, 2], constant=val)
256+
assert len(m) == 2
257+
assert m[1] is val
258+
assert m[2] is val
259+
assert 1 in m
260+
assert 0 not in m
261+
with pytest.raises(KeyError):
262+
m[0]
263+
assert repr(m) == repr({1: val, 2: val})

zarr/tests/util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import collections
22
import os
33
import tempfile
4+
from typing import Any, Mapping, Sequence
5+
from zarr.context import Context
46

57
from zarr.storage import Store
68
from zarr._storage.v3 import StoreV3
@@ -42,6 +44,13 @@ def __delitem__(self, key):
4244
self.counter['__delitem__', key] += 1
4345
del self.wrapped[key]
4446

47+
def getitems(
48+
self, keys: Sequence[str], *, contexts: Mapping[str, Context]
49+
) -> Mapping[str, Any]:
50+
for key in keys:
51+
self.counter['__getitem__', key] += 1
52+
return {k: self.wrapped[k] for k in keys if k in self.wrapped}
53+
4554

4655
class CountingDictV3(CountingDict, StoreV3):
4756
pass

0 commit comments

Comments
 (0)