diff --git a/pandas_dataclasses/core/parsers.py b/pandas_dataclasses/core/parsers.py index b377449..9afa87a 100644 --- a/pandas_dataclasses/core/parsers.py +++ b/pandas_dataclasses/core/parsers.py @@ -42,7 +42,7 @@ def asdataframe(obj: Any, *, factory: Any = None) -> Any: attrs = get_attrs(spec) data = get_data(spec) index = get_index(spec) - columns = get_columns(spec) + names = get_columns(spec) if factory is None: factory = spec.factory or pd.DataFrame @@ -50,7 +50,8 @@ def asdataframe(obj: Any, *, factory: Any = None) -> Any: if not issubclass(factory, pd.DataFrame): raise TypeError("Factory must be a subclass of DataFrame.") - dataframe = factory(data, index, columns) + dataframe = factory(data, index) + dataframe.columns.names = names dataframe.attrs.update(attrs) return dataframe @@ -109,31 +110,22 @@ def get_attrs(spec: Spec) -> "dict[Hashable, Any]": attrs: "dict[Hashable, Any]" = {} for field in spec.fields.of_attr: - attrs[field.hashable_name] = field.default + attrs[field.name] = field.default return attrs -def get_columns(spec: Spec) -> Optional[pd.Index]: - """Derive columns from a specification.""" - if not spec.fields.of_data: - return - - names_ = [field.name for field in spec.fields.of_data] - - if all(isinstance(name, Hashable) for name in names_): - return - - if not all(isinstance(name, dict) for name in names_): - raise ValueError("All names must be dictionaries.") +def get_columns(spec: Spec) -> "list[Hashable]": + """Derive column names from a specification.""" + objs: "dict[Hashable, Any]" = {} - names = [tuple(name.keys()) for name in names_] # type: ignore - indexes = [tuple(name.values()) for name in names_] # type: ignore + for field in spec.fields.of_column: + objs[field.name] = field.default - if not len(set(names)) == 1: - raise ValueError("All name keys must be same.") - - return pd.MultiIndex.from_tuples(indexes, names=names[0]) + if not objs: + return [None] + else: + return list(objs.keys()) def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]: @@ -145,7 +137,7 @@ def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]: data: List[Any] = [] for field in spec.fields.of_data: - names.append(field.hashable_name) + names.append(field.name) data.append(ensure(field.default, field.dtype)) return dict(zip(names, data)) @@ -160,7 +152,7 @@ def get_index(spec: Spec) -> Optional[pd.Index]: indexes: List[Any] = [] for field in spec.fields.of_index: - names.append(field.hashable_name) + names.append(field.name) indexes.append(ensure(field.default, field.dtype)) indexes = np.broadcast_arrays(*indexes) diff --git a/pandas_dataclasses/core/specs.py b/pandas_dataclasses/core/specs.py index 8438ece..9bc02c2 100644 --- a/pandas_dataclasses/core/specs.py +++ b/pandas_dataclasses/core/specs.py @@ -13,16 +13,7 @@ # submodules -from .typing import ( - P, - AnyName, - AnyPandas, - DataClass, - Role, - get_dtype, - get_name, - get_role, -) +from .typing import P, AnyPandas, DataClass, Role, get_dtype, get_name, get_role # runtime classes @@ -33,10 +24,10 @@ class Field: id: str """Identifier of the field.""" - name: AnyName + name: Hashable """Name of the field.""" - role: Literal["attr", "data", "index"] + role: Literal["attr", "column", "data", "index"] """Role of the field.""" type: Optional[Any] @@ -48,20 +39,12 @@ class Field: default: Any """Default value of the field data.""" - @property - def hashable_name(self) -> Hashable: - """Hashable name of the field.""" - if isinstance(self.name, dict): - return tuple(self.name.values()) - else: - return self.name - def update(self, obj: DataClass[P]) -> "Field": """Update the specification by a dataclass object.""" return replace( self, name=format_name(self.name, obj), - default=getattr(obj, self.id), + default=getattr(obj, self.id, self.default), ) @@ -73,6 +56,11 @@ def of_attr(self) -> "Fields": """Select only attribute field specifications.""" return Fields(field for field in self if field.role == "attr") + @property + def of_column(self) -> "Fields": + """Select only column field specifications.""" + return Fields(field for field in self if field.role == "column") + @property def of_data(self) -> "Fields": """Select only data field specifications.""" @@ -129,6 +117,8 @@ def convert_field(field_: "Field_[Any]") -> Optional[Field]: if role is Role.ATTR: role = "attr" + elif role is Role.COLUMN: + role = "column" elif role is Role.DATA: role = "data" elif role is Role.INDEX: @@ -157,19 +147,12 @@ def eval_types(dataclass: Type[DataClass[P]]) -> Type[DataClass[P]]: return dataclass -def format_name(name: AnyName, obj: DataClass[P]) -> AnyName: +def format_name(name: Hashable, obj: DataClass[P]) -> Hashable: """Format a name by a dataclass object.""" + if isinstance(name, tuple): + return type(name)(format_name(elem, obj) for elem in name) + if isinstance(name, str): return name.format(obj) - if isinstance(name, dict): - formatted: "dict[Hashable, Hashable]" = {} - - for key, val in name.items(): - key = key.format(obj) if isinstance(key, str) else key - val = val.format(obj) if isinstance(val, str) else val - formatted[key] = val - - return formatted - return name diff --git a/pandas_dataclasses/core/typing.py b/pandas_dataclasses/core/typing.py index ab43d1c..d3e9de7 100644 --- a/pandas_dataclasses/core/typing.py +++ b/pandas_dataclasses/core/typing.py @@ -1,4 +1,4 @@ -__all__ = ["Attr", "Data", "Index", "Other"] +__all__ = ["Attr", "Column", "Data", "Index", "Other"] # standard library @@ -12,7 +12,6 @@ Hashable, Iterable, Optional, - Tuple, Type, TypeVar, ) @@ -34,7 +33,6 @@ # type hints (private) -AnyName: TypeAlias = "Hashable | dict[Hashable, Hashable]" AnyPandas: TypeAlias = "pd.DataFrame | pd.Series" P = ParamSpec("P") T = TypeVar("T") @@ -63,6 +61,9 @@ class Role(Enum): ATTR = auto() """Annotation for attribute fields.""" + COLUMN = auto() + """Annotation for column fields.""" + DATA = auto() """Annotation for data fields.""" @@ -82,6 +83,9 @@ def annotates(cls, tp: Any) -> bool: Attr = Annotated[T, Role.ATTR] """Type hint for attribute fields (``Attr[T]``).""" +Column = Annotated[T, Role.COLUMN] +"""Type hint for column fields (``Column[T]``).""" + Data = Annotated[Collection[T], Role.DATA] """Type hint for data fields (``Data[T]``).""" @@ -121,7 +125,7 @@ def get_annotated(tp: Any) -> Any: raise TypeError("Could not find any role-annotated type.") -def get_annotations(tp: Any) -> Tuple[Any, ...]: +def get_annotations(tp: Any) -> "tuple[Any, ...]": """Extract annotations of the first role-annotated type.""" for annotated in filter(Role.annotates, find_annotated(tp)): return get_args(annotated)[1:] @@ -145,24 +149,19 @@ def get_dtype(tp: Any) -> Optional[str]: return pandas_dtype(dtype).name -def get_name(tp: Any, default: AnyName = None) -> AnyName: +def get_name(tp: Any, default: Hashable = None) -> Hashable: """Extract a name if found or return given default.""" try: name = get_annotations(tp)[1] except (IndexError, TypeError): return default - if isinstance(name, Hashable): - return name - - if ( - isinstance(name, dict) - and all(isinstance(key, Hashable) for key in name.keys()) - and all(isinstance(val, Hashable) for val in name.values()) - ): - return dict(name) + try: + hash(name) + except TypeError: + raise ValueError("Could not find any valid name.") - raise ValueError("Could not find any valid name.") + return name def get_role(tp: Any, default: Role = Role.OTHER) -> Role: diff --git a/tests/data.py b/tests/data.py index c03f21f..82861ce 100644 --- a/tests/data.py +++ b/tests/data.py @@ -7,21 +7,16 @@ # standard library -from dataclasses import dataclass +from dataclasses import dataclass, field # dependencies import pandas as pd -from pandas_dataclasses import Attr, Data, Index +from pandas_dataclasses import Attr, Column, Data, Index from typing_extensions import Annotated as Ann # test dataclass and object -def name(stat: str, cat: str) -> "dict[str, str]": - """Shorthand function for data names.""" - return {"Statistic": stat, "Category": cat} - - @dataclass class Weather: """Weather information.""" @@ -32,16 +27,22 @@ class Weather: month: Ann[Index[int], "Month"] """Month of the measured time.""" - temp_avg: Ann[Data[float], name("Temperature ({.temp_unit})", "Average")] + meas: Ann[Column[None], "Measurement"] = field(init=False, repr=False) + """Name of the measurement.""" + + stat: Ann[Column[None], "Statistic"] = field(init=False, repr=False) + """Name of the statistic.""" + + temp_avg: Ann[Data[float], ("Temperature ({.temp_unit})", "Average")] """Monthly average temperature with given units.""" - temp_max: Ann[Data[float], name("Temperature ({.temp_unit})", "Maximum")] + temp_max: Ann[Data[float], ("Temperature ({.temp_unit})", "Maximum")] """Monthly maximum temperature with given units.""" - wind_avg: Ann[Data[float], name("Wind speed ({.wind_unit})", "Average")] + wind_avg: Ann[Data[float], ("Wind speed ({.wind_unit})", "Average")] """Monthly average wind speed with given units.""" - wind_max: Ann[Data[float], name("Wind speed ({.wind_unit})", "Maximum")] + wind_max: Ann[Data[float], ("Wind speed ({.wind_unit})", "Maximum")] """Monthly maximum wind speed with given units.""" loc: Ann[Attr[str], "Location"] = "Tokyo" @@ -98,7 +99,7 @@ class Weather: ("Wind speed (m/s)", "Average"), ("Wind speed (m/s)", "Maximum"), ], - names=("Statistic", "Category"), + names=("Measurement", "Statistic"), ), ) df_weather_true.attrs = { diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 50c36b4..eec4f1a 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,5 +1,5 @@ # standard library -from dataclasses import asdict, dataclass +from dataclasses import dataclass # dependencies @@ -39,28 +39,56 @@ class CustomSeriesWeather(Weather, As[CustomSeries]): # test functions def test_dataframe_weather() -> None: - df_weather = DataFrameWeather.new(**asdict(weather)) + df_weather = DataFrameWeather.new( + year=weather.year, + month=weather.month, + temp_avg=weather.temp_avg, + temp_max=weather.temp_max, + wind_avg=weather.wind_avg, + wind_max=weather.wind_max, + ) assert isinstance(df_weather, pd.DataFrame) assert (df_weather == df_weather_true).all().all() def test_custom_dataframe_weather() -> None: - df_weather = CustomDataFrameWeather.new(**asdict(weather)) + df_weather = CustomDataFrameWeather.new( + year=weather.year, + month=weather.month, + temp_avg=weather.temp_avg, + temp_max=weather.temp_max, + wind_avg=weather.wind_avg, + wind_max=weather.wind_max, + ) assert isinstance(df_weather, CustomDataFrame) assert (df_weather == df_weather_true).all().all() def test_series_weather() -> None: - ser_weather = SeriesWeather.new(**asdict(weather)) + ser_weather = SeriesWeather.new( + year=weather.year, + month=weather.month, + temp_avg=weather.temp_avg, + temp_max=weather.temp_max, + wind_avg=weather.wind_avg, + wind_max=weather.wind_max, + ) assert isinstance(ser_weather, pd.Series) assert (ser_weather == ser_weather_true).all() def test_custom_series_weather() -> None: - ser_weather = CustomSeriesWeather.new(**asdict(weather)) + ser_weather = CustomSeriesWeather.new( + year=weather.year, + month=weather.month, + temp_avg=weather.temp_avg, + temp_max=weather.temp_max, + wind_avg=weather.wind_avg, + wind_max=weather.wind_max, + ) assert isinstance(ser_weather, CustomSeries) assert (ser_weather == ser_weather_true).all() diff --git a/tests/test_parsers.py b/tests/test_parsers.py index ba6dd6a..fb85618 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -58,9 +58,9 @@ def test_get_attrs() -> None: keys = list(attrs.keys()) values = list(attrs.values()) - assert keys[0] == spec.fields.of_attr[0].hashable_name - assert keys[1] == spec.fields.of_attr[1].hashable_name - assert keys[2] == spec.fields.of_attr[2].hashable_name + assert keys[0] == spec.fields.of_attr[0].name + assert keys[1] == spec.fields.of_attr[1].name + assert keys[2] == spec.fields.of_attr[2].name assert values[0] == spec.fields.of_attr[0].default assert values[1] == spec.fields.of_attr[1].default @@ -68,13 +68,10 @@ def test_get_attrs() -> None: def test_get_columns() -> None: - index = cast(pd.Index, get_columns(spec)) + columns = get_columns(spec) - assert index.names == list(spec.fields.of_data[0].name) # type: ignore - assert index[0] == spec.fields.of_data[0].hashable_name - assert index[1] == spec.fields.of_data[1].hashable_name - assert index[2] == spec.fields.of_data[2].hashable_name - assert index[3] == spec.fields.of_data[3].hashable_name + assert columns[0] == spec.fields.of_column[0].name + assert columns[1] == spec.fields.of_column[1].name def test_get_data() -> None: @@ -82,10 +79,10 @@ def test_get_data() -> None: keys = list(data.keys()) values = list(data.values()) - assert keys[0] == spec.fields.of_data[0].hashable_name - assert keys[1] == spec.fields.of_data[1].hashable_name - assert keys[2] == spec.fields.of_data[2].hashable_name - assert keys[3] == spec.fields.of_data[3].hashable_name + assert keys[0] == spec.fields.of_data[0].name + assert keys[1] == spec.fields.of_data[1].name + assert keys[2] == spec.fields.of_data[2].name + assert keys[3] == spec.fields.of_data[3].name assert values[0].dtype.name == spec.fields.of_data[0].dtype assert values[1].dtype.name == spec.fields.of_data[1].dtype @@ -102,8 +99,8 @@ def test_get_index() -> None: index = cast(pd.Index, get_index(spec)) df = cast(pd.DataFrame, index.to_frame()) - assert df.iloc[:, 0].name == spec.fields.of_index[0].hashable_name - assert df.iloc[:, 1].name == spec.fields.of_index[1].hashable_name + assert df.iloc[:, 0].name == spec.fields.of_index[0].name + assert df.iloc[:, 1].name == spec.fields.of_index[1].name assert df.iloc[:, 0].dtype == spec.fields.of_index[0].dtype assert df.iloc[:, 1].dtype == spec.fields.of_index[1].dtype diff --git a/tests/test_specs.py b/tests/test_specs.py index 3e582b8..8d8ee84 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -13,10 +13,6 @@ # test functions -def name(stat: str, cat: str) -> "dict[str, str]": - return {"Statistic": stat, "Category": cat} - - def test_year() -> None: field = spec.fields.of_index[0] @@ -57,11 +53,47 @@ def test_month_updated() -> None: assert field.default == [1, 7, 1, 7, 1] +def test_meas() -> None: + field = spec.fields.of_column[0] + + assert field.id == "meas" + assert field.name == "Measurement" + assert field.role == "column" + assert field.default is MISSING + + +def test_meas_updated() -> None: + field = spec_updated.fields.of_column[0] + + assert field.id == "meas" + assert field.name == "Measurement" + assert field.role == "column" + assert field.default is MISSING + + +def test_stat() -> None: + field = spec.fields.of_column[1] + + assert field.id == "stat" + assert field.name == "Statistic" + assert field.role == "column" + assert field.default is MISSING + + +def test_stat_updated() -> None: + field = spec_updated.fields.of_column[1] + + assert field.id == "stat" + assert field.name == "Statistic" + assert field.role == "column" + assert field.default is MISSING + + def test_temp_avg() -> None: field = spec.fields.of_data[0] assert field.id == "temp_avg" - assert field.name == name("Temperature ({.temp_unit})", "Average") + assert field.name == ("Temperature ({.temp_unit})", "Average") assert field.role == "data" assert field.dtype == "float64" assert field.default is MISSING @@ -71,7 +103,7 @@ def test_temp_avg_updated() -> None: field = spec_updated.fields.of_data[0] assert field.id == "temp_avg" - assert field.name == name("Temperature (deg C)", "Average") + assert field.name == ("Temperature (deg C)", "Average") assert field.role == "data" assert field.dtype == "float64" assert field.default == [7.1, 24.3, 5.4, 25.9, 4.9] @@ -81,7 +113,7 @@ def test_temp_max() -> None: field = spec.fields.of_data[1] assert field.id == "temp_max" - assert field.name == name("Temperature ({.temp_unit})", "Maximum") + assert field.name == ("Temperature ({.temp_unit})", "Maximum") assert field.role == "data" assert field.dtype == "float64" assert field.default is MISSING @@ -91,7 +123,7 @@ def test_temp_max_updated() -> None: field = spec_updated.fields.of_data[1] assert field.id == "temp_max" - assert field.name == name("Temperature (deg C)", "Maximum") + assert field.name == ("Temperature (deg C)", "Maximum") assert field.role == "data" assert field.dtype == "float64" assert field.default == [11.1, 27.7, 10.3, 30.3, 9.4] @@ -101,7 +133,7 @@ def test_wind_avg() -> None: field = spec.fields.of_data[2] assert field.id == "wind_avg" - assert field.name == name("Wind speed ({.wind_unit})", "Average") + assert field.name == ("Wind speed ({.wind_unit})", "Average") assert field.role == "data" assert field.dtype == "float64" assert field.default is MISSING @@ -111,7 +143,7 @@ def test_wind_avg_updated() -> None: field = spec_updated.fields.of_data[2] assert field.id == "wind_avg" - assert field.name == name("Wind speed (m/s)", "Average") + assert field.name == ("Wind speed (m/s)", "Average") assert field.role == "data" assert field.dtype == "float64" assert field.default == [2.4, 3.1, 2.3, 2.4, 2.6] @@ -121,7 +153,7 @@ def test_wind_max() -> None: field = spec.fields.of_data[3] assert field.id == "wind_max" - assert field.name == name("Wind speed ({.wind_unit})", "Maximum") + assert field.name == ("Wind speed ({.wind_unit})", "Maximum") assert field.role == "data" assert field.dtype == "float64" assert field.default is MISSING @@ -131,7 +163,7 @@ def test_wind_max_updated() -> None: field = spec_updated.fields.of_data[3] assert field.id == "wind_max" - assert field.name == name("Wind speed (m/s)", "Maximum") + assert field.name == ("Wind speed (m/s)", "Maximum") assert field.role == "data" assert field.dtype == "float64" assert field.default == [8.8, 10.2, 10.7, 9.0, 8.8] diff --git a/tests/test_typing.py b/tests/test_typing.py index 4079b3c..655ad33 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -46,9 +46,6 @@ (Ann[Attr[Any], "attr"], "attr"), (Ann[Data[Any], "data"], "data"), (Ann[Index[Any], "index"], "index"), - (Ann[Attr[Any], dict(name="attr")], dict(name="attr")), - (Ann[Data[Any], dict(name="data")], dict(name="data")), - (Ann[Index[Any], dict(name="index")], dict(name="index")), (Ann[Any, "other"], None), (Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], "attr"), (Union[Ann[Data[Any], "data"], Ann[Any, "any"]], "data"),