Skip to content

Commit 6c96e8e

Browse files
authored
Merge pull request #107 from astropenguin/astropenguin/issue39
Update NumPy-like zeros, ones, ... to be compatible with sizes
2 parents ac6b3d2 + a8494c0 commit 6c96e8e

File tree

5 files changed

+130
-94
lines changed

5 files changed

+130
-94
lines changed

tests/test_datamodel.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,44 +57,51 @@ class ColorImage:
5757
def test_xaxis_attr() -> None:
5858
assert xaxis_model.attr[0].name == "units"
5959
assert xaxis_model.attr[0].value == "pixel"
60-
assert xaxis_model.attr[0].type == "str"
60+
assert xaxis_model.attr[0].type == "builtins.str"
6161

6262

6363
def test_xaxis_data() -> None:
6464
assert xaxis_model.data[0].name == "data"
6565
assert xaxis_model.data[0].type == {"dims": ("x",), "dtype": "int"}
66+
assert xaxis_model.data[0].factory is None
6667

6768

6869
def test_xaxis_name() -> None:
6970
assert xaxis_model.name[0].name == "name"
7071
assert xaxis_model.name[0].value == "x axis"
71-
assert xaxis_model.name[0].type == "str"
72+
assert xaxis_model.name[0].type == "builtins.str"
7273

7374

7475
def test_yaxis_attr() -> None:
7576
assert yaxis_model.attr[0].name == "units"
7677
assert yaxis_model.attr[0].value == "pixel"
77-
assert yaxis_model.attr[0].type == "str"
78+
assert yaxis_model.attr[0].type == "builtins.str"
7879

7980

8081
def test_yaxis_data() -> None:
8182
assert yaxis_model.data[0].name == "data"
8283
assert yaxis_model.data[0].type == {"dims": ("y",), "dtype": "int"}
84+
assert yaxis_model.data[0].factory is None
8385

8486

8587
def test_yaxis_name() -> None:
8688
assert yaxis_model.name[0].name == "name"
8789
assert yaxis_model.name[0].value == "y axis"
88-
assert yaxis_model.name[0].type == "str"
90+
assert yaxis_model.name[0].type == "builtins.str"
8991

9092

9193
def test_matrix_coord() -> None:
9294
assert image_model.coord[0].name == "mask"
9395
assert image_model.coord[0].type == {"dims": ("x", "y"), "dtype": "bool"}
96+
assert image_model.coord[0].factory is None
97+
9498
assert image_model.coord[1].name == "x"
95-
assert image_model.coord[1].type == "test_datamodel.XAxis"
99+
assert image_model.coord[1].type == {"dims": ("x",), "dtype": "int"}
100+
assert image_model.coord[1].factory is XAxis
101+
96102
assert image_model.coord[2].name == "y"
97-
assert image_model.coord[2].type == "test_datamodel.YAxis"
103+
assert image_model.coord[2].type == {"dims": ("y",), "dtype": "int"}
104+
assert image_model.coord[2].factory is YAxis
98105

99106

100107
def test_matrix_data() -> None:
@@ -104,8 +111,13 @@ def test_matrix_data() -> None:
104111

105112
def test_image_data() -> None:
106113
assert color_model.data[0].name == "red"
107-
assert color_model.data[0].type == "test_datamodel.Image"
114+
assert color_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
115+
assert color_model.data[0].factory is Image
116+
108117
assert color_model.data[1].name == "green"
109-
assert color_model.data[1].type == "test_datamodel.Image"
118+
assert color_model.data[1].type == {"dims": ("x", "y"), "dtype": "float"}
119+
assert color_model.data[1].factory is Image
120+
110121
assert color_model.data[2].name == "blue"
111-
assert color_model.data[2].type == "test_datamodel.Image"
122+
assert color_model.data[2].type == {"dims": ("x", "y"), "dtype": "float"}
123+
assert color_model.data[2].factory is Image

tests/test_typing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
# submodules
11-
from xarray_dataclasses.typing import Data, get_dims, get_dtype
11+
from xarray_dataclasses.typing import Data, get_dims, get_dtype, unannotate
1212

1313

1414
# type hints
@@ -36,9 +36,9 @@
3636
# test functions
3737
@mark.parametrize("hint, dims", testdata_dims)
3838
def test_get_dims(hint: Any, dims: Any) -> None:
39-
assert get_dims(hint) == dims
39+
assert get_dims(unannotate(hint)) == dims
4040

4141

4242
@mark.parametrize("hint, dtype", testdata_dtype)
4343
def test_get_dtype(hint: Any, dtype: Any) -> None:
44-
assert get_dtype(hint) == dtype
44+
assert get_dtype(unannotate(hint)) == dtype

xarray_dataclasses/dataarray.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
2525
TDataArray_ = TypeVar("TDataArray_", bound=xr.DataArray, contravariant=True)
2626
Order = Literal["C", "F"]
27-
Shape = Union[Sequence[int], int]
27+
Shape = Union[Dict[str, int], Sequence[int], int]
2828

2929

3030
class DataClass(Protocol[P]):
@@ -161,7 +161,13 @@ def empty(
161161
DataArray object filled without initializing data.
162162
163163
"""
164-
name = DataModel.from_dataclass(cls).data[0].name
164+
model = DataModel.from_dataclass(cls)
165+
name = model.data[0].name
166+
dims = model.data[0].type["dims"]
167+
168+
if isinstance(shape, dict):
169+
shape = tuple(shape[dim] for dim in dims)
170+
165171
data = np.empty(shape, order=order)
166172
return asdataarray(cls(**{name: data}, **kwargs))
167173

@@ -184,7 +190,13 @@ def zeros(
184190
DataArray object filled with zeros.
185191
186192
"""
187-
name = DataModel.from_dataclass(cls).data[0].name
193+
model = DataModel.from_dataclass(cls)
194+
name = model.data[0].name
195+
dims = model.data[0].type["dims"]
196+
197+
if isinstance(shape, dict):
198+
shape = tuple(shape[dim] for dim in dims)
199+
188200
data = np.zeros(shape, order=order)
189201
return asdataarray(cls(**{name: data}, **kwargs))
190202

@@ -207,7 +219,13 @@ def ones(
207219
DataArray object filled with ones.
208220
209221
"""
210-
name = DataModel.from_dataclass(cls).data[0].name
222+
model = DataModel.from_dataclass(cls)
223+
name = model.data[0].name
224+
dims = model.data[0].type["dims"]
225+
226+
if isinstance(shape, dict):
227+
shape = tuple(shape[dim] for dim in dims)
228+
211229
data = np.ones(shape, order=order)
212230
return asdataarray(cls(**{name: data}, **kwargs))
213231

@@ -232,6 +250,12 @@ def full(
232250
DataArray object filled with given value.
233251
234252
"""
235-
name = DataModel.from_dataclass(cls).data[0].name
253+
model = DataModel.from_dataclass(cls)
254+
name = model.data[0].name
255+
dims = model.data[0].type["dims"]
256+
257+
if isinstance(shape, dict):
258+
shape = tuple(shape[dim] for dim in dims)
259+
236260
data = np.full(shape, fill_value, order=order)
237261
return asdataarray(cls(**{name: data}, **kwargs))

xarray_dataclasses/datamodel.py

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33

44
# standard library
5-
from dataclasses import Field, InitVar, dataclass, field, is_dataclass
6-
from typing import Any, List, Type, Union, cast
5+
from dataclasses import Field, dataclass, field, is_dataclass
6+
from typing import Any, List, Optional, Type, Union, cast
77

88

99
# dependencies
@@ -22,91 +22,99 @@
2222
FieldType,
2323
get_dims,
2424
get_dtype,
25-
get_first,
26-
get_repr,
25+
get_inner,
26+
unannotate,
2727
)
2828

2929

3030
# type hints
31+
DataType = TypedDict("DataType", dims=Dims, dtype=Dtype)
3132
Reference = Union[xr.DataArray, xr.Dataset, None]
32-
DataTypes = TypedDict("DataTypes", dims=Dims, dtype=Dtype)
3333

3434

3535
# field models
36-
@dataclass
36+
@dataclass(frozen=True)
3737
class Data:
38-
"""Model for the coord or data fields."""
38+
"""Field model for data-related fields."""
3939

4040
name: str
41-
type: DataTypes
42-
value: Any
41+
"""Name of the field."""
4342

44-
def __call__(self, reference: Reference = None) -> xr.DataArray:
45-
"""Create a DataArray object from the value and a reference."""
46-
return typedarray(
47-
self.value,
48-
self.type["dims"],
49-
self.type["dtype"],
50-
reference,
51-
)
52-
53-
@classmethod
54-
def from_field(cls, field: Field[Any], value: Any) -> "Data":
55-
"""Create a field model from a dataclass field and a value."""
56-
return cls(
57-
field.name,
58-
{
59-
"dims": get_dims(field.type),
60-
"dtype": get_dtype(field.type),
61-
},
62-
value,
63-
)
64-
65-
66-
@dataclass
67-
class Dataof:
68-
"""Model for the coordof or dataof fields."""
69-
70-
name: str
71-
type: str
7243
value: Any
73-
dataclass: InitVar[Type[DataClass]]
44+
"""Value assigned to the field."""
45+
46+
type: DataType
47+
"""Type (dims and dtype) of the field."""
7448

75-
def __post_init__(self, dataclass: Type[DataClass]) -> None:
76-
self._dataclass = dataclass
49+
factory: Optional[Type[DataClass]] = None
50+
"""Factory dataclass to create a DataArray object."""
7751

7852
def __call__(self, reference: Reference = None) -> xr.DataArray:
7953
"""Create a DataArray object from the value and a reference."""
8054
from .dataarray import asdataarray
8155

56+
if self.factory is None:
57+
return typedarray(
58+
self.value,
59+
self.type["dims"],
60+
self.type["dtype"],
61+
reference,
62+
)
63+
8264
if is_dataclass(self.value):
8365
return asdataarray(self.value, reference)
8466
else:
85-
return asdataarray(self._dataclass(self.value), reference)
67+
return asdataarray(self.factory(self.value), reference)
8668

8769
@classmethod
88-
def from_field(cls, field: Field[Any], value: Any) -> "Dataof":
70+
def from_field(cls, field: Field[Any], value: Any, of: bool) -> "Data":
8971
"""Create a field model from a dataclass field and a value."""
90-
dataclass = get_first(field.type)
91-
return cls(field.name, get_repr(dataclass), value, dataclass)
72+
hint = unannotate(field.type)
9273

74+
if of:
75+
dataclass = get_inner(hint, 0)
76+
data = DataModel.from_dataclass(dataclass).data[0]
77+
return cls(field.name, value, data.type, dataclass)
78+
else:
79+
return cls(
80+
field.name,
81+
value,
82+
{"dims": get_dims(hint), "dtype": get_dtype(hint)},
83+
)
9384

94-
@dataclass
85+
86+
@dataclass(frozen=True)
9587
class General:
96-
"""Model for the attribute or name fields."""
88+
"""Field model for general fields."""
9789

9890
name: str
99-
type: str
91+
"""Name of the field."""
92+
10093
value: Any
94+
"""Value assigned to the field."""
95+
96+
type: str
97+
"""Type of the field."""
98+
99+
factory: Optional[Type[Any]] = None
100+
"""Factory function to create an object."""
101101

102102
def __call__(self) -> Any:
103-
"""Just return the value."""
104-
return self.value
103+
"""Create an object from the value."""
104+
if self.factory is None:
105+
return self.value
106+
else:
107+
return self.factory(self.value)
105108

106109
@classmethod
107110
def from_field(cls, field: Field[Any], value: Any) -> "General":
108111
"""Create a field model from a dataclass field and a value."""
109-
return cls(field.name, get_repr(field.type), value)
112+
hint = unannotate(field.type)
113+
114+
try:
115+
return cls(field.name, value, f"{hint.__module__}.{hint.__qualname__}")
116+
except AttributeError:
117+
return cls(field.name, value, repr(hint))
110118

111119

112120
# data models
@@ -117,10 +125,10 @@ class DataModel:
117125
attr: List[General] = field(default_factory=list)
118126
"""Model of the attribute fields."""
119127

120-
coord: List[Union[Data, Dataof]] = field(default_factory=list)
128+
coord: List[Data] = field(default_factory=list)
121129
"""Model of the coordinate fields."""
122130

123-
data: List[Union[Data, Dataof]] = field(default_factory=list)
131+
data: List[Data] = field(default_factory=list)
124132
"""Model of the data fields."""
125133

126134
name: List[General] = field(default_factory=list)
@@ -138,13 +146,13 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
138146
if FieldType.ATTR.annotates(field_.type):
139147
model.attr.append(General.from_field(field_, value))
140148
elif FieldType.COORD.annotates(field_.type):
141-
model.coord.append(Data.from_field(field_, value))
149+
model.coord.append(Data.from_field(field_, value, False))
142150
elif FieldType.COORDOF.annotates(field_.type):
143-
model.coord.append(Dataof.from_field(field_, value))
151+
model.coord.append(Data.from_field(field_, value, True))
144152
elif FieldType.DATA.annotates(field_.type):
145-
model.data.append(Data.from_field(field_, value))
153+
model.data.append(Data.from_field(field_, value, False))
146154
elif FieldType.DATAOF.annotates(field_.type):
147-
model.data.append(Dataof.from_field(field_, value))
155+
model.data.append(Data.from_field(field_, value, True))
148156
elif FieldType.NAME.annotates(field_.type):
149157
model.name.append(General.from_field(field_, value))
150158

0 commit comments

Comments
 (0)