Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@


# type hints
X = Literal[DIMS[0]]
Y = Literal[DIMS[1]]
X = Literal["x"]
Y = Literal["y"]


# dataclasses
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@


# type hints
X = Literal[DIMS[0]]
Y = Literal[DIMS[1]]
X = Literal["x"]
Y = Literal["y"]


# dataclasses
Expand Down
1 change: 0 additions & 1 deletion xarray_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _make_field_generic():
from . import deprecated
from . import datamodel
from . import typing
from . import utils


# aliases
Expand Down
77 changes: 41 additions & 36 deletions xarray_dataclasses/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@


# submodules
from .datamodel import DataModel
from .typing import Reference
from .datamodel import DataModel, Reference


# type hints
P = ParamSpec("P")
R = TypeVar("R", bound=xr.DataArray)
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
TDataArray_ = TypeVar("TDataArray_", bound=xr.DataArray, contravariant=True)
Order = Literal["C", "F"]
Shape = Union[Sequence[int], int]

Expand All @@ -34,30 +34,50 @@ class DataClass(Protocol[P]):
__dataclass_fields__: Dict[str, Field[Any]]


class DataArrayClass(Protocol[P, R]):
class DataArrayClass(Protocol[P, TDataArray_]):
"""Type hint for a dataclass object with a DataArray factory."""

__init__: Callable[P, None]
__dataclass_fields__: Dict[str, Field[Any]]
__dataarray_factory__: Callable[..., R]
__dataarray_factory__: Callable[..., TDataArray_]


# runtime classes
class classproperty:
"""Class property only for AsDataArray.new().

As a classmethod and a property can be chained together since Python 3.9,
this class will be removed when the support for Python 3.7 and 3.8 ends.

"""

def __init__(self, func: Callable[..., Callable[P, TDataArray]]) -> None:
self.__func__ = func

def __get__(
self,
obj: Any,
cls: Type[DataArrayClass[P, TDataArray]],
) -> Callable[P, TDataArray]:
return self.__func__(cls)


# runtime functions and classes
@overload
def asdataarray(
dataclass: DataArrayClass[Any, R],
dataclass: DataArrayClass[Any, TDataArray],
reference: Reference = None,
dataarray_factory: Any = xr.DataArray,
) -> R:
) -> TDataArray:
...


@overload
def asdataarray(
dataclass: DataClass[Any],
reference: Reference = None,
dataarray_factory: Callable[..., R] = xr.DataArray,
) -> R:
dataarray_factory: Callable[..., TDataArray] = xr.DataArray,
) -> TDataArray:
...


Expand Down Expand Up @@ -97,21 +117,6 @@ def asdataarray(
return dataarray


class classproperty:
"""Class property only for AsDataArray.new().

As a classmethod and a property can be chained together since Python 3.9,
this class will be removed when the support for Python 3.7 and 3.8 ends.

"""

def __init__(self, func: Callable[..., Callable[P, R]]) -> None:
self.__func__ = func

def __get__(self, obj: Any, cls: Type[DataArrayClass[P, R]]) -> Callable[P, R]:
return self.__func__(cls)


class AsDataArray:
"""Mix-in class that provides shorthand methods."""

Expand All @@ -120,30 +125,30 @@ def __dataarray_factory__(self, data: Any = None) -> xr.DataArray:
return xr.DataArray(data)

@classproperty
def new(cls: Type[DataArrayClass[P, R]]) -> Callable[P, R]:
def new(cls: Type[DataArrayClass[P, TDataArray]]) -> Callable[P, TDataArray]:
"""Create a DataArray object from dataclass parameters."""

init = copy(cls.__init__)
init.__annotations__["return"] = R
init.__annotations__["return"] = TDataArray
init.__doc__ = cls.__init__.__doc__

@wraps(init)
def new(
cls: Type[DataArrayClass[P, R]],
cls: Type[DataArrayClass[P, TDataArray]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
) -> TDataArray:
return asdataarray(cls(*args, **kwargs))

return MethodType(new, cls)

@classmethod
def empty(
cls: Type[DataArrayClass[P, R]],
cls: Type[DataArrayClass[P, TDataArray]],
shape: Shape,
order: Order = "C",
**kwargs: Any,
) -> R:
) -> TDataArray:
"""Create a DataArray object without initializing data.

Args:
Expand All @@ -162,11 +167,11 @@ def empty(

@classmethod
def zeros(
cls: Type[DataArrayClass[P, R]],
cls: Type[DataArrayClass[P, TDataArray]],
shape: Shape,
order: Order = "C",
**kwargs: Any,
) -> R:
) -> TDataArray:
"""Create a DataArray object filled with zeros.

Args:
Expand All @@ -185,11 +190,11 @@ def zeros(

@classmethod
def ones(
cls: Type[DataArrayClass[P, R]],
cls: Type[DataArrayClass[P, TDataArray]],
shape: Shape,
order: Order = "C",
**kwargs: Any,
) -> R:
) -> TDataArray:
"""Create a DataArray object filled with ones.

Args:
Expand All @@ -208,12 +213,12 @@ def ones(

@classmethod
def full(
cls: Type[DataArrayClass[P, R]],
cls: Type[DataArrayClass[P, TDataArray]],
shape: Shape,
fill_value: Any,
order: Order = "C",
**kwargs: Any,
) -> R:
) -> TDataArray:
"""Create a DataArray object filled with given value.

Args:
Expand Down
Loading