Skip to content

Making Codec classes self-contained #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zarr/v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zarr.v3.array_v2 import ArrayV2
from zarr.v3.common import RuntimeConfiguration # noqa: F401
from zarr.v3.group import Group # noqa: F401
from zarr.v3.metadata import RuntimeConfiguration, runtime_configuration # noqa: F401
from zarr.v3.metadata import runtime_configuration # noqa: F401
from zarr.v3.store import ( # noqa: F401
StoreLike,
make_store_path,
Expand Down
22 changes: 7 additions & 15 deletions src/zarr/v3/abc/codec.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,35 @@
from __future__ import annotations

from abc import abstractmethod, ABC
from typing import TYPE_CHECKING, Optional, Type
from abc import abstractmethod
from typing import TYPE_CHECKING, Optional

import numpy as np
from zarr.v3.abc.metadata import Metadata

from zarr.v3.common import ArraySpec
from zarr.v3.store import StorePath


if TYPE_CHECKING:
from zarr.v3.common import BytesLike, SliceSelection, NamedConfig
from typing_extensions import Self
from zarr.v3.common import BytesLike, SliceSelection
from zarr.v3.metadata import (
ArrayMetadata,
RuntimeConfiguration,
)


class Codec(ABC):
metadata: NamedConfig
class Codec(Metadata):
is_fixed_size: bool

@classmethod
@abstractmethod
def from_metadata(cls, codec_metadata: "NamedConfig") -> Codec:
pass

@classmethod
def get_metadata_class(cls) -> Type[NamedConfig]:
pass

@abstractmethod
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
pass

def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
return chunk_spec

def evolve(self, *, ndim: int, data_type: np.dtype) -> Codec:
def evolve(self, array_spec: ArraySpec) -> Self:
return self

def validate(self, array_metadata: ArrayMetadata) -> None:
Expand Down
10 changes: 7 additions & 3 deletions src/zarr/v3/abc/metadata.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
from typing import Dict, Any, Sequence
from typing import Dict, Any
from typing_extensions import Self

from dataclasses import fields

from zarr.v3.common import JSON


class Metadata:
def to_dict(self) -> Dict[str, Any]:
Expand All @@ -23,6 +25,8 @@ def to_dict(self) -> Dict[str, Any]:
value = getattr(self, key)
if isinstance(value, Metadata):
out_dict[field.name] = getattr(self, field.name).to_dict()
elif isinstance(value, str):
out_dict[key] = value
elif isinstance(value, Sequence):
out_dict[key] = [v.to_dict() if isinstance(v, Metadata) else v for v in value]
else:
Expand All @@ -31,7 +35,7 @@ def to_dict(self) -> Dict[str, Any]:
return out_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Self:
def from_dict(cls, data: Dict[str, JSON]) -> Self:
"""
Create an instance of the model from a dictionary
"""
Expand Down
91 changes: 30 additions & 61 deletions src/zarr/v3/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@
from dataclasses import dataclass, replace

import json
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union

import numpy as np

from zarr.v3.abc.codec import Codec

# from zarr.v3.array_v2 import ArrayV2
from zarr.v3.codecs.common import decode, encode

# from zarr.v3.array_v2 import ArrayV2
from zarr.v3.codecs import bytes_codec
from zarr.v3.codecs import BytesCodec
from zarr.v3.common import (
ZARR_JSON,
ArraySpec,
Expand All @@ -35,15 +31,9 @@
concurrent_map,
)
from zarr.v3.indexing import BasicIndexer, all_chunk_coords, is_total_slice
from zarr.v3.metadata import (
ArrayMetadata,
DefaultChunkKeyEncodingConfigurationMetadata,
DefaultChunkKeyEncodingMetadata,
RegularChunkGridConfigurationMetadata,
RegularChunkGridMetadata,
V2ChunkKeyEncodingConfigurationMetadata,
V2ChunkKeyEncodingMetadata,
)
from zarr.v3.chunk_grids import RegularChunkGrid
from zarr.v3.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.v3.metadata import ArrayMetadata
from zarr.v3.store import StoreLike, StorePath, make_store_path
from zarr.v3.sync import sync

Expand All @@ -57,38 +47,27 @@ def parse_array_metadata(data: Any):
raise TypeError


@dataclass(frozen=True)
class AsyncArray:
metadata: ArrayMetadata
store_path: StorePath
runtime_configuration: RuntimeConfiguration
codecs: List[Codec]

@property
def codecs(self):
return self.metadata.codecs

@property
def store_path(self):
return self._store_path

def __init__(
self,
metadata: ArrayMetadata,
store_path: StorePath,
runtime_configuration: RuntimeConfiguration,
):
self.metadata = parse_array_metadata(metadata)
self._store_path = store_path
self.runtime_configuration = runtime_configuration
metadata_parsed = parse_array_metadata(metadata)

async def encode_chunk(self, data: np.ndarray):
"""
Encode a numpy array using the codec pipeline
"""
return await encode(self.codecs, data, self.runtime_configuration)

async def decode_chunk(self, data: bytes):
return await decode(self.codecs, data, self.runtime_configuration)
object.__setattr__(self, "metadata", metadata_parsed)
object.__setattr__(self, "store_path", store_path)
object.__setattr__(self, "runtime_configuration", runtime_configuration)

@classmethod
async def create(
Expand All @@ -113,7 +92,7 @@ async def create(
if not exists_ok:
assert not await (store_path / ZARR_JSON).exists()

codecs = list(codecs) if codecs is not None else [bytes_codec()]
codecs = list(codecs) if codecs is not None else [BytesCodec()]

if fill_value is None:
if dtype == np.dtype("bool"):
Expand All @@ -124,21 +103,11 @@ async def create(
metadata = ArrayMetadata(
shape=shape,
data_type=dtype,
chunk_grid=RegularChunkGridMetadata(
configuration=RegularChunkGridConfigurationMetadata(chunk_shape=chunk_shape)
),
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
chunk_key_encoding=(
V2ChunkKeyEncodingMetadata(
configuration=V2ChunkKeyEncodingConfigurationMetadata(
separator=chunk_key_encoding[1]
)
)
V2ChunkKeyEncoding(separator=chunk_key_encoding[1])
if chunk_key_encoding[0] == "v2"
else DefaultChunkKeyEncodingMetadata(
configuration=DefaultChunkKeyEncodingConfigurationMetadata(
separator=chunk_key_encoding[1]
)
)
else DefaultChunkKeyEncoding(separator=chunk_key_encoding[1])
),
fill_value=fill_value,
codecs=codecs,
Expand Down Expand Up @@ -228,7 +197,7 @@ async def getitem(self, selection: Selection):
indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
chunk_shape=self.metadata.chunk_grid.configuration.chunk_shape,
chunk_shape=self.metadata.chunk_grid.chunk_shape,
)

# setup output array
Expand Down Expand Up @@ -260,13 +229,13 @@ async def _save_metadata(self) -> None:

def _validate_metadata(self) -> None:
assert len(self.metadata.shape) == len(
self.metadata.chunk_grid.configuration.chunk_shape
self.metadata.chunk_grid.chunk_shape
), "`chunk_shape` and `shape` need to have the same number of dimensions."
assert self.metadata.dimension_names is None or len(self.metadata.shape) == len(
self.metadata.dimension_names
), "`dimension_names` and `shape` need to have the same number of dimensions."
assert self.metadata.fill_value is not None, "`fill_value` is required."
# self.codecs.validate(self.metadata)
self.codecs.validate(self.metadata)

async def _read_chunk(
self,
Expand All @@ -280,8 +249,8 @@ async def _read_chunk(
chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
store_path = self.store_path / chunk_key

if self.codec_pipeline.supports_partial_decode:
chunk_array = await self.codec_pipeline.decode_partial(
if self.codecs.supports_partial_decode:
chunk_array = await self.codecs.decode_partial(
store_path, chunk_selection, chunk_spec, self.runtime_configuration
)
if chunk_array is not None:
Expand All @@ -291,7 +260,7 @@ async def _read_chunk(
else:
chunk_bytes = await store_path.get()
if chunk_bytes is not None:
chunk_array = await self.codec_pipeline.decode(
chunk_array = await self.codecs.decode(
chunk_bytes, chunk_spec, self.runtime_configuration
)
tmp = chunk_array[chunk_selection]
Expand All @@ -300,7 +269,7 @@ async def _read_chunk(
out[out_selection] = self.metadata.fill_value

async def setitem(self, selection: Selection, value: np.ndarray) -> None:
chunk_shape = self.metadata.chunk_grid.configuration.chunk_shape
chunk_shape = self.metadata.chunk_grid.chunk_shape
indexer = BasicIndexer(
selection,
shape=self.metadata.shape,
Expand Down Expand Up @@ -361,9 +330,9 @@ async def _write_chunk(
chunk_array = value[out_selection]
await self._write_chunk_to_store(store_path, chunk_array, chunk_spec)

elif self.codec_pipeline.supports_partial_encode:
elif self.codecs.supports_partial_encode:
# print("encode_partial", chunk_coords, chunk_selection, repr(self))
await self.codec_pipeline.encode_partial(
await self.codecs.encode_partial(
store_path,
value[out_selection],
chunk_selection,
Expand All @@ -384,9 +353,7 @@ async def _write_chunk(
chunk_array.fill(self.metadata.fill_value)
else:
chunk_array = (
await self.codec_pipeline.decode(
chunk_bytes, chunk_spec, self.runtime_configuration
)
await self.codecs.decode(chunk_bytes, chunk_spec, self.runtime_configuration)
).copy() # make a writable copy
chunk_array[chunk_selection] = value[out_selection]

Expand All @@ -399,7 +366,9 @@ async def _write_chunk_to_store(
# chunks that only contain fill_value will be removed
await store_path.delete()
else:
chunk_bytes = await encode(self.codecs, chunk_array, self.runtime_configuration)
chunk_bytes = await self.codecs.encode(
chunk_array, chunk_spec, self.runtime_configuration
)
if chunk_bytes is None:
await store_path.delete()
else:
Expand All @@ -410,7 +379,7 @@ async def resize(self, new_shape: ChunkCoords) -> AsyncArray:
new_metadata = replace(self.metadata, shape=new_shape)

# Remove all chunks outside of the new shape
chunk_shape = self.metadata.chunk_grid.configuration.chunk_shape
chunk_shape = self.metadata.chunk_grid.chunk_shape
chunk_key_encoding = self.metadata.chunk_key_encoding
old_chunk_coords = set(all_chunk_coords(self.metadata.shape, chunk_shape))
new_chunk_coords = set(all_chunk_coords(new_shape, chunk_shape))
Expand All @@ -429,14 +398,14 @@ async def _delete_key(key: str) -> None:

# Write new metadata
await (self.store_path / ZARR_JSON).set(new_metadata.to_bytes())
return evolve(self, metadata=new_metadata)
return replace(self, metadata=new_metadata)

async def update_attributes(self, new_attributes: Dict[str, Any]) -> Array:
new_metadata = replace(self.metadata, attributes=new_attributes)

# Write new metadata
await (self.store_path / ZARR_JSON).set(new_metadata.to_bytes())
return evolve(self, metadata=new_metadata)
return replace(self, metadata=new_metadata)

def __repr__(self):
return f"<AsyncArray {self.store_path} shape={self.shape} dtype={self.dtype}>"
Expand Down
Loading