Skip to content

Make create_array signatures consistent #2819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
4 changes: 4 additions & 0 deletions changes/2819.chore.rst
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since "chore"-type changelog entries don't get rendered (e.g., see https://zarr.readthedocs.io/en/stable/release-notes.html#misc), I'd recommend splitting this up into a "feature" entry for the updated signatures, and a separate "feature" entry for the change in default fill_value arguments.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Ensure that invocations of ``create_array`` use consistent keyword arguments, with consistent defaults.
Specifically, ``zarr.api.synchronous.create_array`` now takes a ``write_data`` keyword argument; the
``create_array`` method on ``zarr.Group`` takes ``data`` and ``write_data`` keyword arguments. The ``fill_value``
keyword argument of the various invocations of ``create_array`` has been consistently set to ``None``, where previously it was either ``None`` or ``0``.
2 changes: 1 addition & 1 deletion src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ async def open_group(
async def create(
shape: ChunkCoords | int,
*, # Note: this is a change from v2
chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True
chunks: ChunkCoords | int | bool | None = None,
dtype: npt.DTypeLike | None = None,
compressor: CompressorLike = "auto",
fill_value: Any | None = 0, # TODO: need type
Expand Down
47 changes: 34 additions & 13 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,22 +996,24 @@ async def create_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike,
shape: ShapeLike | None = None,
dtype: npt.DTypeLike | None = None,
data: np.ndarray[Any, np.dtype[Any]] | None = None,
chunks: ChunkCoords | Literal["auto"] = "auto",
shards: ShardsLike | None = None,
filters: FiltersLike = "auto",
compressors: CompressorsLike = "auto",
compressor: CompressorLike = "auto",
serializer: SerializerLike = "auto",
fill_value: Any | None = 0,
fill_value: Any | None = None,
order: MemoryOrder | None = None,
attributes: dict[str, JSON] | None = None,
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
dimension_names: DimensionNames = None,
storage_options: dict[str, Any] | None = None,
overwrite: bool = False,
config: ArrayConfig | ArrayConfigLike | None = None,
config: ArrayConfigLike | None = None,
write_data: bool = True,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an array within this group.

Expand Down Expand Up @@ -1099,6 +1101,11 @@ async def create_array(
Whether to overwrite an array with the same name in the store, if one exists.
config : ArrayConfig or ArrayConfigLike, optional
Runtime configuration for the array.
write_data : bool
If a pre-existing array-like object was provided to this function via the ``data`` parameter
then ``write_data`` determines whether the values in that array-like object should be
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
array will be left empty.

Returns
-------
Expand All @@ -1113,6 +1120,7 @@ async def create_array(
name=name,
shape=shape,
dtype=dtype,
data=data,
chunks=chunks,
shards=shards,
filters=filters,
Expand All @@ -1127,6 +1135,7 @@ async def create_array(
storage_options=storage_options,
overwrite=overwrite,
config=config,
write_data=write_data,
)

@deprecated("Use AsyncGroup.create_array instead.")
Expand Down Expand Up @@ -2371,22 +2380,24 @@ def create_array(
self,
name: str,
*,
shape: ShapeLike,
dtype: npt.DTypeLike,
shape: ShapeLike | None = None,
dtype: npt.DTypeLike | None = None,
data: np.ndarray[Any, np.dtype[Any]] | None = None,
chunks: ChunkCoords | Literal["auto"] = "auto",
shards: ShardsLike | None = None,
filters: FiltersLike = "auto",
compressors: CompressorsLike = "auto",
compressor: CompressorLike = "auto",
serializer: SerializerLike = "auto",
fill_value: Any | None = 0,
order: MemoryOrder | None = "C",
fill_value: Any | None = None,
order: MemoryOrder | None = None,
attributes: dict[str, JSON] | None = None,
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
dimension_names: DimensionNames = None,
storage_options: dict[str, Any] | None = None,
overwrite: bool = False,
config: ArrayConfig | ArrayConfigLike | None = None,
config: ArrayConfigLike | None = None,
write_data: bool = True,
) -> Array:
"""Create an array within this group.

Expand All @@ -2397,10 +2408,13 @@ def create_array(
name : str
The name of the array relative to the group. If ``path`` is ``None``, the array will be located
at the root of the store.
shape : ChunkCoords
Shape of the array.
dtype : npt.DTypeLike
Data type of the array.
shape : ChunkCoords, optional
Shape of the array. Can be ``None`` if ``data`` is provided.
dtype : npt.DTypeLike | None
Data type of the array. Can be ``None`` if ``data`` is provided.
data : np.ndarray, optional
Array-like data to use for initializing the array. If this parameter is provided, the
``shape`` and ``dtype`` parameters must be identical to ``data.shape`` and ``data.dtype``,
or ``None``.
chunks : ChunkCoords, optional
Chunk shape of the array.
If not specified, default are guessed based on the shape and dtype.
Expand Down Expand Up @@ -2474,6 +2488,11 @@ def create_array(
Whether to overwrite an array with the same name in the store, if one exists.
config : ArrayConfig or ArrayConfigLike, optional
Runtime configuration for the array.
write_data : bool
If a pre-existing array-like object was provided to this function via the ``data`` parameter
then ``write_data`` determines whether the values in that array-like object should be
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
array will be left empty.

Returns
-------
Expand All @@ -2488,6 +2507,7 @@ def create_array(
name=name,
shape=shape,
dtype=dtype,
data=data,
chunks=chunks,
shards=shards,
fill_value=fill_value,
Expand All @@ -2501,6 +2521,7 @@ def create_array(
overwrite=overwrite,
storage_options=storage_options,
config=config,
write_data=write_data,
)
)
)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect
import pathlib
import re
from typing import TYPE_CHECKING

Expand All @@ -8,6 +10,7 @@

if TYPE_CHECKING:
import pathlib
from collections.abc import Callable

from zarr.abc.store import Store
from zarr.core.common import JSON, MemoryOrder, ZarrFormat
Expand Down Expand Up @@ -1183,6 +1186,28 @@ def test_open_array_with_mode_r_plus(store: Store, zarr_format: ZarrFormat) -> N
z2[:] = 3


@pytest.mark.parametrize(
    ("a_func", "b_func"),
    [
        (zarr.api.asynchronous.create_array, zarr.api.synchronous.create_array),
        (zarr.api.asynchronous.save, zarr.api.synchronous.save),
        (zarr.api.asynchronous.save_array, zarr.api.synchronous.save_array),
        (zarr.api.asynchronous.save_group, zarr.api.synchronous.save_group),
        (zarr.api.asynchronous.open_group, zarr.api.synchronous.open_group),
        (zarr.api.asynchronous.create, zarr.api.synchronous.create),
    ],
)
def test_consistent_signatures(
    a_func: Callable[[object], object], b_func: Callable[[object], object]
) -> None:
    """
    Check that both functions in each pair declare identical parameters.

    Each pair couples a function from ``zarr.api.asynchronous`` with the function of the
    same name from ``zarr.api.synchronous``. The comparison covers the full parameter
    mapping reported by ``inspect.signature`` for each function.
    """
    reference_params = inspect.signature(a_func).parameters
    candidate_params = inspect.signature(b_func).parameters
    assert candidate_params == reference_params


def test_api_exports() -> None:
"""
Test that the sync API and the async API export the same objects
Expand Down
56 changes: 56 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import inspect
import json
import math
import multiprocessing as mp
Expand Down Expand Up @@ -987,6 +988,42 @@ def test_auto_partition_auto_shards(
assert auto_shards == expected_shards


def test_chunks_and_shards() -> None:
    """
    Check the ``chunks`` and ``shards`` attributes reported by newly created arrays.

    Three arrays are created in one in-memory store: one with default format arguments and no
    ``shards``, one with ``shards`` given, and one with ``zarr_format=2``. ``shards`` is
    expected to be ``None`` unless it was passed explicitly.
    """
    root = StorePath(MemoryStore())
    array_shape = (100, 100)
    chunk_shape = (5, 5)
    shard_shape = (10, 10)

    # no ``shards`` argument: requested chunks are reported, shards is None
    unsharded = zarr.create_array(
        store=root / "v3", shape=array_shape, chunks=chunk_shape, dtype="i4"
    )
    assert unsharded.chunks == chunk_shape
    assert unsharded.shards is None

    # explicit ``shards`` argument: both values are reported back unchanged
    sharded = zarr.create_array(
        store=root / "v3_sharding",
        shape=array_shape,
        chunks=chunk_shape,
        shards=shard_shape,
        dtype="i4",
    )
    assert sharded.chunks == chunk_shape
    assert sharded.shards == shard_shape

    # zarr_format=2 without ``shards``: shards is None
    legacy = zarr.create_array(
        store=root / "v2",
        shape=array_shape,
        chunks=chunk_shape,
        zarr_format=2,
        dtype="i4",
    )
    assert legacy.chunks == chunk_shape
    assert legacy.shards is None


@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.parametrize(
    ("dtype", "fill_value_expected"),
    [
        ("<U4", ""),
        ("<S4", b""),
        ("i", 0),
        ("f", 0.0),
    ],
)
def test_default_fill_value(dtype: str, fill_value_expected: object, store: Store) -> None:
    """
    Check the ``fill_value`` of an array created without an explicit ``fill_value``.

    For each parametrized dtype, the created array's ``fill_value`` must equal the paired
    expected value.
    """
    created = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype)
    assert created.fill_value == fill_value_expected


@pytest.mark.parametrize("store", ["memory"], indirect=True)
class TestCreateArray:
@staticmethod
Expand Down Expand Up @@ -1708,6 +1745,25 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser
assert all(np.array_equal(r, data) for r in results)


def test_create_array_method_signature() -> None:
    """
    Check that ``AsyncGroup.create_array`` declares the parameters of the ``create_array`` function.

    Every parameter of ``create_array``, other than ``zarr_format``, ``store`` and ``name``,
    must appear in the method signature as an equal ``inspect.Parameter``.
    """

    function_params = inspect.signature(create_array).parameters
    method_params = inspect.signature(AsyncGroup.create_array).parameters
    # skip keyword arguments that are either missing or have different semantics when
    # create_array is invoked as a group method
    skipped = {"zarr_format", "store", "name"}
    # TODO: make this test stronger. right now, it only checks that all the parameters in the
    # function signature are used in the method signature. we can be more strict and check that
    # the method signature uses no extra parameters.
    required = {(key, param) for key, param in function_params.items() if key not in skipped}
    absent = required - set(method_params.items())
    assert absent == set()


async def test_sharding_coordinate_selection() -> None:
store = MemoryStore()
g = zarr.open_group(store, mode="w")
Expand Down
1 change: 1 addition & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,7 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None:
@pytest.mark.parametrize(
("a_func", "b_func"),
[
(zarr.core.group.AsyncGroup.create_array, zarr.core.group.Group.create_array),
(zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy),
(zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy),
(zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metadata/test_consolidated.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ async def test_consolidated_metadata_v2(self):
dtype=dtype,
attributes={"key": "a"},
chunks=(1,),
fill_value=0,
fill_value=None,
compressor=Blosc(),
order="C",
),
Expand Down
Loading