Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,44 +57,51 @@ class ColorImage:
def test_xaxis_attr() -> None:
assert xaxis_model.attr[0].name == "units"
assert xaxis_model.attr[0].value == "pixel"
assert xaxis_model.attr[0].type == "str"
assert xaxis_model.attr[0].type == "builtins.str"


def test_xaxis_data() -> None:
assert xaxis_model.data[0].name == "data"
assert xaxis_model.data[0].type == {"dims": ("x",), "dtype": "int"}
assert xaxis_model.data[0].factory is None


def test_xaxis_name() -> None:
assert xaxis_model.name[0].name == "name"
assert xaxis_model.name[0].value == "x axis"
assert xaxis_model.name[0].type == "str"
assert xaxis_model.name[0].type == "builtins.str"


def test_yaxis_attr() -> None:
assert yaxis_model.attr[0].name == "units"
assert yaxis_model.attr[0].value == "pixel"
assert yaxis_model.attr[0].type == "str"
assert yaxis_model.attr[0].type == "builtins.str"


def test_yaxis_data() -> None:
assert yaxis_model.data[0].name == "data"
assert yaxis_model.data[0].type == {"dims": ("y",), "dtype": "int"}
assert yaxis_model.data[0].factory is None


def test_yaxis_name() -> None:
assert yaxis_model.name[0].name == "name"
assert yaxis_model.name[0].value == "y axis"
assert yaxis_model.name[0].type == "str"
assert yaxis_model.name[0].type == "builtins.str"


def test_matrix_coord() -> None:
assert image_model.coord[0].name == "mask"
assert image_model.coord[0].type == {"dims": ("x", "y"), "dtype": "bool"}
assert image_model.coord[0].factory is None

assert image_model.coord[1].name == "x"
assert image_model.coord[1].type == "test_datamodel.XAxis"
assert image_model.coord[1].type == {"dims": ("x",), "dtype": "int"}
assert image_model.coord[1].factory is XAxis

assert image_model.coord[2].name == "y"
assert image_model.coord[2].type == "test_datamodel.YAxis"
assert image_model.coord[2].type == {"dims": ("y",), "dtype": "int"}
assert image_model.coord[2].factory is YAxis


def test_matrix_data() -> None:
Expand All @@ -104,8 +111,13 @@ def test_matrix_data() -> None:

def test_image_data() -> None:
assert color_model.data[0].name == "red"
assert color_model.data[0].type == "test_datamodel.Image"
assert color_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
assert color_model.data[0].factory is Image

assert color_model.data[1].name == "green"
assert color_model.data[1].type == "test_datamodel.Image"
assert color_model.data[1].type == {"dims": ("x", "y"), "dtype": "float"}
assert color_model.data[1].factory is Image

assert color_model.data[2].name == "blue"
assert color_model.data[2].type == "test_datamodel.Image"
assert color_model.data[2].type == {"dims": ("x", "y"), "dtype": "float"}
assert color_model.data[2].factory is Image
6 changes: 3 additions & 3 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


# submodules
from xarray_dataclasses.typing import Data, get_dims, get_dtype
from xarray_dataclasses.typing import Data, get_dims, get_dtype, unannotate


# type hints
Expand Down Expand Up @@ -36,9 +36,9 @@
# test functions
@mark.parametrize("hint, dims", testdata_dims)
def test_get_dims(hint: Any, dims: Any) -> None:
assert get_dims(hint) == dims
assert get_dims(unannotate(hint)) == dims


@mark.parametrize("hint, dtype", testdata_dtype)
def test_get_dtype(hint: Any, dtype: Any) -> None:
assert get_dtype(hint) == dtype
assert get_dtype(unannotate(hint)) == dtype
34 changes: 29 additions & 5 deletions xarray_dataclasses/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
TDataArray_ = TypeVar("TDataArray_", bound=xr.DataArray, contravariant=True)
Order = Literal["C", "F"]
Shape = Union[Sequence[int], int]
Shape = Union[Dict[str, int], Sequence[int], int]


class DataClass(Protocol[P]):
Expand Down Expand Up @@ -161,7 +161,13 @@ def empty(
DataArray object filled without initializing data.

"""
name = DataModel.from_dataclass(cls).data[0].name
model = DataModel.from_dataclass(cls)
name = model.data[0].name
dims = model.data[0].type["dims"]

if isinstance(shape, dict):
shape = tuple(shape[dim] for dim in dims)

data = np.empty(shape, order=order)
return asdataarray(cls(**{name: data}, **kwargs))

Expand All @@ -184,7 +190,13 @@ def zeros(
DataArray object filled with zeros.

"""
name = DataModel.from_dataclass(cls).data[0].name
model = DataModel.from_dataclass(cls)
name = model.data[0].name
dims = model.data[0].type["dims"]

if isinstance(shape, dict):
shape = tuple(shape[dim] for dim in dims)

data = np.zeros(shape, order=order)
return asdataarray(cls(**{name: data}, **kwargs))

Expand All @@ -207,7 +219,13 @@ def ones(
DataArray object filled with ones.

"""
name = DataModel.from_dataclass(cls).data[0].name
model = DataModel.from_dataclass(cls)
name = model.data[0].name
dims = model.data[0].type["dims"]

if isinstance(shape, dict):
shape = tuple(shape[dim] for dim in dims)

data = np.ones(shape, order=order)
return asdataarray(cls(**{name: data}, **kwargs))

Expand All @@ -232,6 +250,12 @@ def full(
DataArray object filled with given value.

"""
name = DataModel.from_dataclass(cls).data[0].name
model = DataModel.from_dataclass(cls)
name = model.data[0].name
dims = model.data[0].type["dims"]

if isinstance(shape, dict):
shape = tuple(shape[dim] for dim in dims)

data = np.full(shape, fill_value, order=order)
return asdataarray(cls(**{name: data}, **kwargs))
120 changes: 64 additions & 56 deletions xarray_dataclasses/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


# standard library
from dataclasses import Field, InitVar, dataclass, field, is_dataclass
from typing import Any, List, Type, Union, cast
from dataclasses import Field, dataclass, field, is_dataclass
from typing import Any, List, Optional, Type, Union, cast


# dependencies
Expand All @@ -22,91 +22,99 @@
FieldType,
get_dims,
get_dtype,
get_first,
get_repr,
get_inner,
unannotate,
)


# type hints
DataType = TypedDict("DataType", dims=Dims, dtype=Dtype)
Reference = Union[xr.DataArray, xr.Dataset, None]
DataTypes = TypedDict("DataTypes", dims=Dims, dtype=Dtype)


# field models
@dataclass
@dataclass(frozen=True)
class Data:
"""Model for the coord or data fields."""
"""Field model for data-related fields."""

name: str
type: DataTypes
value: Any
"""Name of the field."""

def __call__(self, reference: Reference = None) -> xr.DataArray:
"""Create a DataArray object from the value and a reference."""
return typedarray(
self.value,
self.type["dims"],
self.type["dtype"],
reference,
)

@classmethod
def from_field(cls, field: Field[Any], value: Any) -> "Data":
"""Create a field model from a dataclass field and a value."""
return cls(
field.name,
{
"dims": get_dims(field.type),
"dtype": get_dtype(field.type),
},
value,
)


@dataclass
class Dataof:
"""Model for the coordof or dataof fields."""

name: str
type: str
value: Any
dataclass: InitVar[Type[DataClass]]
"""Value assigned to the field."""

type: DataType
"""Type (dims and dtype) of the field."""

def __post_init__(self, dataclass: Type[DataClass]) -> None:
self._dataclass = dataclass
factory: Optional[Type[DataClass]] = None
"""Factory dataclass to create a DataArray object."""

def __call__(self, reference: Reference = None) -> xr.DataArray:
"""Create a DataArray object from the value and a reference."""
from .dataarray import asdataarray

if self.factory is None:
return typedarray(
self.value,
self.type["dims"],
self.type["dtype"],
reference,
)

if is_dataclass(self.value):
return asdataarray(self.value, reference)
else:
return asdataarray(self._dataclass(self.value), reference)
return asdataarray(self.factory(self.value), reference)

@classmethod
def from_field(cls, field: Field[Any], value: Any) -> "Dataof":
def from_field(cls, field: Field[Any], value: Any, of: bool) -> "Data":
"""Create a field model from a dataclass field and a value."""
dataclass = get_first(field.type)
return cls(field.name, get_repr(dataclass), value, dataclass)
hint = unannotate(field.type)

if of:
dataclass = get_inner(hint, 0)
data = DataModel.from_dataclass(dataclass).data[0]
return cls(field.name, value, data.type, dataclass)
else:
return cls(
field.name,
value,
{"dims": get_dims(hint), "dtype": get_dtype(hint)},
)

@dataclass

@dataclass(frozen=True)
class General:
"""Model for the attribute or name fields."""
"""Field model for general fields."""

name: str
type: str
"""Name of the field."""

value: Any
"""Value assigned to the field."""

type: str
"""Type of the field."""

factory: Optional[Type[Any]] = None
"""Factory function to create an object."""

def __call__(self) -> Any:
"""Just return the value."""
return self.value
"""Create an object from the value."""
if self.factory is None:
return self.value
else:
return self.factory(self.value)

@classmethod
def from_field(cls, field: Field[Any], value: Any) -> "General":
"""Create a field model from a dataclass field and a value."""
return cls(field.name, get_repr(field.type), value)
hint = unannotate(field.type)

try:
return cls(field.name, value, f"{hint.__module__}.{hint.__qualname__}")
except AttributeError:
return cls(field.name, value, repr(hint))


# data models
Expand All @@ -117,10 +125,10 @@ class DataModel:
attr: List[General] = field(default_factory=list)
"""Model of the attribute fields."""

coord: List[Union[Data, Dataof]] = field(default_factory=list)
coord: List[Data] = field(default_factory=list)
"""Model of the coordinate fields."""

data: List[Union[Data, Dataof]] = field(default_factory=list)
data: List[Data] = field(default_factory=list)
"""Model of the data fields."""

name: List[General] = field(default_factory=list)
Expand All @@ -138,13 +146,13 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
if FieldType.ATTR.annotates(field_.type):
model.attr.append(General.from_field(field_, value))
elif FieldType.COORD.annotates(field_.type):
model.coord.append(Data.from_field(field_, value))
model.coord.append(Data.from_field(field_, value, False))
elif FieldType.COORDOF.annotates(field_.type):
model.coord.append(Dataof.from_field(field_, value))
model.coord.append(Data.from_field(field_, value, True))
elif FieldType.DATA.annotates(field_.type):
model.data.append(Data.from_field(field_, value))
model.data.append(Data.from_field(field_, value, False))
elif FieldType.DATAOF.annotates(field_.type):
model.data.append(Dataof.from_field(field_, value))
model.data.append(Data.from_field(field_, value, True))
elif FieldType.NAME.annotates(field_.type):
model.name.append(General.from_field(field_, value))

Expand Down
Loading