Skip to content

Fix some typing #9988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
NestedSequence,
ReadBuffer,
T_Chunks,
ZarrStoreLike,
)

T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"]
Expand Down Expand Up @@ -2100,7 +2101,7 @@ def save_mfdataset(
@overload
def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
store: ZarrStoreLike | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
mode: ZarrWriteModes | None = None,
synchronizer=None,
Expand All @@ -2123,7 +2124,7 @@ def to_zarr(
@overload
def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
store: ZarrStoreLike | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
mode: ZarrWriteModes | None = None,
synchronizer=None,
Expand All @@ -2144,7 +2145,7 @@ def to_zarr(

def to_zarr(
dataset: Dataset,
store: MutableMapping | str | os.PathLike[str] | None = None,
store: ZarrStoreLike | None = None,
chunk_store: MutableMapping | str | os.PathLike | None = None,
mode: ZarrWriteModes | None = None,
synchronizer=None,
Expand Down
17 changes: 7 additions & 10 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import struct
from collections.abc import Hashable, Iterable, Mapping
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -38,13 +38,10 @@
from xarray.namedarray.utils import module_available

if TYPE_CHECKING:
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup

from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import ReadBuffer
from xarray.core.types import ReadBuffer, ZarrArray, ZarrGroup


def _get_mappers(*, storage_options, store, chunk_store):
Expand Down Expand Up @@ -108,8 +105,7 @@ def _choose_default_mode(


def _zarr_v3() -> bool:
# TODO: switch to "3" once Zarr V3 is released
return module_available("zarr", minversion="2.99")
return module_available("zarr", minversion="3")


# need some special secret attributes to tell us the dimensions
Expand Down Expand Up @@ -768,7 +764,7 @@ def __init__(
self._members = self._fetch_members()

@property
def members(self) -> dict[str, ZarrArray]:
def members(self) -> dict[str, ZarrArray | ZarrGroup]:
"""
Model the arrays and groups contained in self.zarr_group as a dict. If `self._cache_members`
is true, the dict is cached. Otherwise, it is retrieved from storage.
Expand All @@ -778,7 +774,7 @@ def members(self) -> dict[str, ZarrArray]:
else:
return self._members

def _fetch_members(self) -> dict[str, ZarrArray]:
def _fetch_members(self) -> dict[str, ZarrArray | ZarrGroup]:
"""
Get the arrays and groups defined in the zarr group modelled by this Store
"""
Expand Down Expand Up @@ -1035,6 +1031,7 @@ def sync(self):

def _open_existing_array(self, *, name) -> ZarrArray:
import zarr
from zarr import Array as ZarrArray

# TODO: if mode="a", consider overriding the existing variable
# metadata. This would need some case work properly with region
Expand Down Expand Up @@ -1066,7 +1063,7 @@ def _open_existing_array(self, *, name) -> ZarrArray:
else:
zarr_array = self.zarr_group[name]

return zarr_array
return cast(ZarrArray, zarr_array)

def _create_new_array(
self, *, name, shape, dtype, fill_value, encoding, attrs
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

from collections.abc import Mapping, MutableMapping
from collections.abc import Mapping
from os import PathLike
from typing import Any, Literal, get_args
from typing import TYPE_CHECKING, Any, Literal, get_args

from xarray.core.datatree import DataTree
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes

T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
T_DataTreeNetcdfTypes = Literal["NETCDF4"]

if TYPE_CHECKING:
from xarray.core.types import ZarrStoreLike


def _datatree_to_netcdf(
dt: DataTree,
Expand Down Expand Up @@ -78,7 +81,7 @@ def _datatree_to_netcdf(

def _datatree_to_zarr(
dt: DataTree,
store: MutableMapping | str | PathLike[str],
store: ZarrStoreLike,
mode: ZarrWriteModes = "w-",
encoding: Mapping[str, Any] | None = None,
consolidated: bool = True,
Expand Down
9 changes: 8 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,15 @@

try:
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup
except ImportError:
ZarrArray = np.ndarray
ZarrArray = np.ndarray # type: ignore[misc, assignment, unused-ignore]
ZarrGroup = Any # type: ignore[misc, assignment, unused-ignore]
try:
# this is V3 only
from zarr.storage import StoreLike as ZarrStoreLike
except ImportError:
ZarrStoreLike = Any # type: ignore[misc, assignment, unused-ignore]

# Anything that can be coerced to a shape tuple
_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
store = ZipStore(filepath)
original_dt.to_zarr(store)

with open_datatree(store, engine="zarr") as roundtrip_dt:
with open_datatree(store, engine="zarr") as roundtrip_dt: # type: ignore[arg-type, unused-ignore]
assert_equal(original_dt, roundtrip_dt)

def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
Expand Down
Loading