Skip to content

Commit

Permalink
Buffer Prototype Argument (#1910)
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk authored Jun 4, 2024
1 parent b431cf7 commit 661acb3
Show file tree
Hide file tree
Showing 25 changed files with 438 additions and 252 deletions.
3 changes: 1 addition & 2 deletions src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
if TYPE_CHECKING:
from typing_extensions import Self

from zarr.common import ArraySpec
from zarr.array_spec import ArraySpec
from zarr.indexing import SelectorTuple
from zarr.metadata import ArrayMetadata


CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)

Expand Down
19 changes: 14 additions & 5 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import AsyncGenerator
from typing import Protocol, runtime_checkable

from zarr.buffer import Buffer
from zarr.buffer import Buffer, BufferPrototype
from zarr.common import BytesLike, OpenMode


Expand Down Expand Up @@ -30,7 +30,10 @@ def _check_writable(self) -> None:

@abstractmethod
async def get(
self, key: str, byte_range: tuple[int | None, int | None] | None = None
self,
key: str,
prototype: BufferPrototype,
byte_range: tuple[int | None, int | None] | None = None,
) -> Buffer | None:
"""Retrieve the value associated with a given key.
Expand All @@ -47,7 +50,9 @@ async def get(

@abstractmethod
async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
self,
prototype: BufferPrototype,
key_ranges: list[tuple[str, tuple[int | None, int | None]]],
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.
Expand Down Expand Up @@ -175,12 +180,16 @@ def close(self) -> None: # noqa: B027

@runtime_checkable
class ByteGetter(Protocol):
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
async def get(
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None: ...


@runtime_checkable
class ByteSetter(Protocol):
async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
async def get(
self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
) -> Buffer | None: ...

async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ...

Expand Down
104 changes: 79 additions & 25 deletions src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from zarr.abc.codec import Codec
from zarr.abc.store import set_or_delete
from zarr.attributes import Attributes
from zarr.buffer import Factory, NDArrayLike, NDBuffer
from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
from zarr.chunk_grids import RegularChunkGrid
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.codecs import BytesCodec
Expand Down Expand Up @@ -414,8 +414,8 @@ async def _get_selection(
self,
indexer: Indexer,
*,
prototype: BufferPrototype,
out: NDBuffer | None = None,
factory: Factory.Create = NDBuffer.create,
fields: Fields | None = None,
) -> NDArrayLike:
# check fields are sensible
Expand All @@ -432,7 +432,7 @@ async def _get_selection(
f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}"
)
else:
out_buffer = factory(
out_buffer = prototype.nd_buffer.create(
shape=indexer.shape,
dtype=out_dtype,
order=self.order,
Expand All @@ -444,7 +444,7 @@ async def _get_selection(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
self.metadata.get_chunk_spec(chunk_coords, self.order),
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype=prototype),
chunk_selection,
out_selection,
)
Expand All @@ -456,14 +456,14 @@ async def _get_selection(
return out_buffer.as_ndarray_like()

async def getitem(
self, selection: Selection, *, factory: Factory.Create = NDBuffer.create
self, selection: Selection, *, prototype: BufferPrototype = default_buffer_prototype
) -> NDArrayLike:
indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
chunk_grid=self.metadata.chunk_grid,
)
return await self._get_selection(indexer, factory=factory)
return await self._get_selection(indexer, prototype=prototype)

async def _save_metadata(self, metadata: ArrayMetadata) -> None:
to_save = metadata.to_buffer_dict()
Expand All @@ -475,7 +475,7 @@ async def _set_selection(
indexer: Indexer,
value: NDArrayLike,
*,
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
prototype: BufferPrototype,
fields: Fields | None = None,
) -> None:
# check fields are sensible
Expand All @@ -497,14 +497,14 @@ async def _set_selection(
# We accept any ndarray like object from the user and convert it
# to a NDBuffer (or subclass). From this point onwards, we only pass
# Buffer and NDBuffer between components.
value_buffer = factory(value)
value_buffer = prototype.nd_buffer.from_ndarray_like(value)

# merging with existing data and encoding chunks
await self.metadata.codec_pipeline.write(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
self.metadata.get_chunk_spec(chunk_coords, self.order),
self.metadata.get_chunk_spec(chunk_coords, self.order, prototype),
chunk_selection,
out_selection,
)
Expand All @@ -518,14 +518,14 @@ async def setitem(
self,
selection: Selection,
value: NDArrayLike,
factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
chunk_grid=self.metadata.chunk_grid,
)
return await self._set_selection(indexer, value, factory=factory)
return await self._set_selection(indexer, value, prototype=prototype)

async def resize(
self, new_shape: ChunkCoords, delete_outside_chunks: bool = True
Expand Down Expand Up @@ -714,7 +714,9 @@ def __setitem__(self, selection: Selection, value: NDArrayLike) -> None:
def get_basic_selection(
self,
selection: BasicSelection = Ellipsis,
*,
out: NDBuffer | None = None,
prototype: BufferPrototype = default_buffer_prototype,
fields: Fields | None = None,
) -> NDArrayLike:
if self.shape == ():
Expand All @@ -725,57 +727,101 @@ def get_basic_selection(
BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
out=out,
fields=fields,
prototype=prototype,
)
)

def set_basic_selection(
self, selection: BasicSelection, value: NDArrayLike, fields: Fields | None = None
self,
selection: BasicSelection,
value: NDArrayLike,
*,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields))
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

def get_orthogonal_selection(
self,
selection: OrthogonalSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> NDArrayLike:
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
return sync(
self._async_array._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)
)

def set_orthogonal_selection(
self, selection: OrthogonalSelection, value: NDArrayLike, fields: Fields | None = None
self,
selection: OrthogonalSelection,
value: NDArrayLike,
*,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(self._async_array._set_selection(indexer, value, fields=fields))
return sync(
self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
)

def get_mask_selection(
self, mask: MaskSelection, out: NDBuffer | None = None, fields: Fields | None = None
self,
mask: MaskSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> NDArrayLike:
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
return sync(
self._async_array._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)
)

def set_mask_selection(
self, mask: MaskSelection, value: NDArrayLike, fields: Fields | None = None
self,
mask: MaskSelection,
value: NDArrayLike,
*,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields))
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

def get_coordinate_selection(
self,
selection: CoordinateSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> NDArrayLike:
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
out_array = sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
out_array = sync(
self._async_array._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)
)

# restore shape
out_array = out_array.reshape(indexer.sel_shape)
return out_array

def set_coordinate_selection(
self, selection: CoordinateSelection, value: NDArrayLike, fields: Fields | None = None
self,
selection: CoordinateSelection,
value: NDArrayLike,
*,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
# setup indexer
indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
Expand All @@ -792,25 +838,33 @@ def set_coordinate_selection(
if hasattr(value, "shape") and len(value.shape) > 1:
value = value.reshape(-1)

sync(self._async_array._set_selection(indexer, value, fields=fields))
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

def get_block_selection(
self,
selection: BlockSelection,
*,
out: NDBuffer | None = None,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> NDArrayLike:
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
return sync(
self._async_array._get_selection(
indexer=indexer, out=out, fields=fields, prototype=prototype
)
)

def set_block_selection(
self,
selection: BlockSelection,
value: NDArrayLike,
*,
fields: Fields | None = None,
prototype: BufferPrototype = default_buffer_prototype,
) -> None:
indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
sync(self._async_array._set_selection(indexer, value, fields=fields))
sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

@property
def vindex(self) -> VIndex:
Expand Down
41 changes: 41 additions & 0 deletions src/zarr/array_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal

import numpy as np

from zarr.buffer import BufferPrototype
from zarr.common import ChunkCoords, parse_dtype, parse_fill_value, parse_order, parse_shapelike


@dataclass(frozen=True)
class ArraySpec:
    """Immutable bundle of the properties describing an array.

    The constructor passes ``shape``, ``dtype``, ``fill_value`` and ``order``
    through the corresponding ``parse_*`` helpers from ``zarr.common`` and
    stores the results; ``prototype`` is stored exactly as given.

    Parameters
    ----------
    shape
        Shape of the array, run through ``parse_shapelike``.
    dtype
        Data type of the array, run through ``parse_dtype``.
    fill_value
        Fill value, run through ``parse_fill_value``.
    order
        Memory order, run through ``parse_order``.
    prototype
        Buffer prototype associated with this spec (not validated here).
    """

    shape: ChunkCoords
    dtype: np.dtype[Any]
    fill_value: Any
    order: Literal["C", "F"]
    prototype: BufferPrototype

    def __init__(
        self,
        shape: ChunkCoords,
        dtype: np.dtype[Any],
        fill_value: Any,
        order: Literal["C", "F"],
        prototype: BufferPrototype,
    ) -> None:
        # Run every parser before touching the instance, so a failing parser
        # leaves no attribute assigned.
        validated = {
            "shape": parse_shapelike(shape),
            "dtype": parse_dtype(dtype),
            "fill_value": parse_fill_value(fill_value),
            "order": parse_order(order),
            "prototype": prototype,
        }

        # The dataclass is frozen, so regular attribute assignment raises;
        # object.__setattr__ bypasses the frozen __setattr__ hook.
        for attr_name, attr_value in validated.items():
            object.__setattr__(self, attr_name, attr_value)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e. the length of ``shape``."""
        return len(self.shape)
Loading

0 comments on commit 661acb3

Please sign in to comment.