Skip to content

Commit

Permalink
#84 Merge pull request from astropenguin/astropenguin/issue82
Browse files Browse the repository at this point in the history
Add type hint for column fields
  • Loading branch information
astropenguin authored Sep 6, 2022
2 parents 6466bec + 7d18d25 commit dcf281d
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 117 deletions.
38 changes: 15 additions & 23 deletions pandas_dataclasses/core/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ def asdataframe(obj: Any, *, factory: Any = None) -> Any:
attrs = get_attrs(spec)
data = get_data(spec)
index = get_index(spec)
columns = get_columns(spec)
names = get_columns(spec)

if factory is None:
factory = spec.factory or pd.DataFrame

if not issubclass(factory, pd.DataFrame):
raise TypeError("Factory must be a subclass of DataFrame.")

dataframe = factory(data, index, columns)
dataframe = factory(data, index)
dataframe.columns.names = names
dataframe.attrs.update(attrs)
return dataframe

Expand Down Expand Up @@ -109,31 +110,22 @@ def get_attrs(spec: Spec) -> "dict[Hashable, Any]":
attrs: "dict[Hashable, Any]" = {}

for field in spec.fields.of_attr:
attrs[field.hashable_name] = field.default
attrs[field.name] = field.default

return attrs


def get_columns(spec: Spec) -> Optional[pd.Index]:
"""Derive columns from a specification."""
if not spec.fields.of_data:
return

names_ = [field.name for field in spec.fields.of_data]

if all(isinstance(name, Hashable) for name in names_):
return

if not all(isinstance(name, dict) for name in names_):
raise ValueError("All names must be dictionaries.")
def get_columns(spec: Spec) -> "list[Hashable]":
"""Derive column names from a specification."""
objs: "dict[Hashable, Any]" = {}

names = [tuple(name.keys()) for name in names_] # type: ignore
indexes = [tuple(name.values()) for name in names_] # type: ignore
for field in spec.fields.of_column:
objs[field.name] = field.default

if not len(set(names)) == 1:
raise ValueError("All name keys must be same.")

return pd.MultiIndex.from_tuples(indexes, names=names[0])
if not objs:
return [None]
else:
return list(objs.keys())


def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]:
Expand All @@ -145,7 +137,7 @@ def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]:
data: List[Any] = []

for field in spec.fields.of_data:
names.append(field.hashable_name)
names.append(field.name)
data.append(ensure(field.default, field.dtype))

return dict(zip(names, data))
Expand All @@ -160,7 +152,7 @@ def get_index(spec: Spec) -> Optional[pd.Index]:
indexes: List[Any] = []

for field in spec.fields.of_index:
names.append(field.hashable_name)
names.append(field.name)
indexes.append(ensure(field.default, field.dtype))

indexes = np.broadcast_arrays(*indexes)
Expand Down
47 changes: 15 additions & 32 deletions pandas_dataclasses/core/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,7 @@


# submodules
from .typing import (
P,
AnyName,
AnyPandas,
DataClass,
Role,
get_dtype,
get_name,
get_role,
)
from .typing import P, AnyPandas, DataClass, Role, get_dtype, get_name, get_role


# runtime classes
Expand All @@ -33,10 +24,10 @@ class Field:
id: str
"""Identifier of the field."""

name: AnyName
name: Hashable
"""Name of the field."""

role: Literal["attr", "data", "index"]
role: Literal["attr", "column", "data", "index"]
"""Role of the field."""

type: Optional[Any]
Expand All @@ -48,20 +39,12 @@ class Field:
default: Any
"""Default value of the field data."""

@property
def hashable_name(self) -> Hashable:
"""Hashable name of the field."""
if isinstance(self.name, dict):
return tuple(self.name.values())
else:
return self.name

def update(self, obj: DataClass[P]) -> "Field":
"""Update the specification by a dataclass object."""
return replace(
self,
name=format_name(self.name, obj),
default=getattr(obj, self.id),
default=getattr(obj, self.id, self.default),
)


Expand All @@ -73,6 +56,11 @@ def of_attr(self) -> "Fields":
"""Select only attribute field specifications."""
return Fields(field for field in self if field.role == "attr")

@property
def of_column(self) -> "Fields":
"""Select only column field specifications."""
return Fields(field for field in self if field.role == "column")

@property
def of_data(self) -> "Fields":
"""Select only data field specifications."""
Expand Down Expand Up @@ -129,6 +117,8 @@ def convert_field(field_: "Field_[Any]") -> Optional[Field]:

if role is Role.ATTR:
role = "attr"
elif role is Role.COLUMN:
role = "column"
elif role is Role.DATA:
role = "data"
elif role is Role.INDEX:
Expand Down Expand Up @@ -157,19 +147,12 @@ def eval_types(dataclass: Type[DataClass[P]]) -> Type[DataClass[P]]:
return dataclass


def format_name(name: AnyName, obj: DataClass[P]) -> AnyName:
def format_name(name: Hashable, obj: DataClass[P]) -> Hashable:
"""Format a name by a dataclass object."""
if isinstance(name, tuple):
return type(name)(format_name(elem, obj) for elem in name)

if isinstance(name, str):
return name.format(obj)

if isinstance(name, dict):
formatted: "dict[Hashable, Hashable]" = {}

for key, val in name.items():
key = key.format(obj) if isinstance(key, str) else key
val = val.format(obj) if isinstance(val, str) else val
formatted[key] = val

return formatted

return name
29 changes: 14 additions & 15 deletions pandas_dataclasses/core/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["Attr", "Data", "Index", "Other"]
__all__ = ["Attr", "Column", "Data", "Index", "Other"]


# standard library
Expand All @@ -12,7 +12,6 @@
Hashable,
Iterable,
Optional,
Tuple,
Type,
TypeVar,
)
Expand All @@ -34,7 +33,6 @@


# type hints (private)
AnyName: TypeAlias = "Hashable | dict[Hashable, Hashable]"
AnyPandas: TypeAlias = "pd.DataFrame | pd.Series"
P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -63,6 +61,9 @@ class Role(Enum):
ATTR = auto()
"""Annotation for attribute fields."""

COLUMN = auto()
"""Annotation for column fields."""

DATA = auto()
"""Annotation for data fields."""

Expand All @@ -82,6 +83,9 @@ def annotates(cls, tp: Any) -> bool:
Attr = Annotated[T, Role.ATTR]
"""Type hint for attribute fields (``Attr[T]``)."""

Column = Annotated[T, Role.COLUMN]
"""Type hint for column fields (``Column[T]``)."""

Data = Annotated[Collection[T], Role.DATA]
"""Type hint for data fields (``Data[T]``)."""

Expand Down Expand Up @@ -121,7 +125,7 @@ def get_annotated(tp: Any) -> Any:
raise TypeError("Could not find any role-annotated type.")


def get_annotations(tp: Any) -> Tuple[Any, ...]:
def get_annotations(tp: Any) -> "tuple[Any, ...]":
"""Extract annotations of the first role-annotated type."""
for annotated in filter(Role.annotates, find_annotated(tp)):
return get_args(annotated)[1:]
Expand All @@ -145,24 +149,19 @@ def get_dtype(tp: Any) -> Optional[str]:
return pandas_dtype(dtype).name


def get_name(tp: Any, default: AnyName = None) -> AnyName:
def get_name(tp: Any, default: Hashable = None) -> Hashable:
"""Extract a name if found or return given default."""
try:
name = get_annotations(tp)[1]
except (IndexError, TypeError):
return default

if isinstance(name, Hashable):
return name

if (
isinstance(name, dict)
and all(isinstance(key, Hashable) for key in name.keys())
and all(isinstance(val, Hashable) for val in name.values())
):
return dict(name)
try:
hash(name)
except TypeError:
raise ValueError("Could not find any valid name.")

raise ValueError("Could not find any valid name.")
return name


def get_role(tp: Any, default: Role = Role.OTHER) -> Role:
Expand Down
25 changes: 13 additions & 12 deletions tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,16 @@


# standard library
from dataclasses import dataclass
from dataclasses import dataclass, field


# dependencies
import pandas as pd
from pandas_dataclasses import Attr, Data, Index
from pandas_dataclasses import Attr, Column, Data, Index
from typing_extensions import Annotated as Ann


# test dataclass and object
def name(stat: str, cat: str) -> "dict[str, str]":
"""Shorthand function for data names."""
return {"Statistic": stat, "Category": cat}


@dataclass
class Weather:
"""Weather information."""
Expand All @@ -32,16 +27,22 @@ class Weather:
month: Ann[Index[int], "Month"]
"""Month of the measured time."""

temp_avg: Ann[Data[float], name("Temperature ({.temp_unit})", "Average")]
meas: Ann[Column[None], "Measurement"] = field(init=False, repr=False)
"""Name of the measurement."""

stat: Ann[Column[None], "Statistic"] = field(init=False, repr=False)
"""Name of the statistic."""

temp_avg: Ann[Data[float], ("Temperature ({.temp_unit})", "Average")]
"""Monthly average temperature with given units."""

temp_max: Ann[Data[float], name("Temperature ({.temp_unit})", "Maximum")]
temp_max: Ann[Data[float], ("Temperature ({.temp_unit})", "Maximum")]
"""Monthly maximum temperature with given units."""

wind_avg: Ann[Data[float], name("Wind speed ({.wind_unit})", "Average")]
wind_avg: Ann[Data[float], ("Wind speed ({.wind_unit})", "Average")]
"""Monthly average wind speed with given units."""

wind_max: Ann[Data[float], name("Wind speed ({.wind_unit})", "Maximum")]
wind_max: Ann[Data[float], ("Wind speed ({.wind_unit})", "Maximum")]
"""Monthly maximum wind speed with given units."""

loc: Ann[Attr[str], "Location"] = "Tokyo"
Expand Down Expand Up @@ -98,7 +99,7 @@ class Weather:
("Wind speed (m/s)", "Average"),
("Wind speed (m/s)", "Maximum"),
],
names=("Statistic", "Category"),
names=("Measurement", "Statistic"),
),
)
df_weather_true.attrs = {
Expand Down
38 changes: 33 additions & 5 deletions tests/test_mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# standard library
from dataclasses import asdict, dataclass
from dataclasses import dataclass


# dependencies
Expand Down Expand Up @@ -39,28 +39,56 @@ class CustomSeriesWeather(Weather, As[CustomSeries]):

# test functions
def test_dataframe_weather() -> None:
df_weather = DataFrameWeather.new(**asdict(weather))
df_weather = DataFrameWeather.new(
year=weather.year,
month=weather.month,
temp_avg=weather.temp_avg,
temp_max=weather.temp_max,
wind_avg=weather.wind_avg,
wind_max=weather.wind_max,
)

assert isinstance(df_weather, pd.DataFrame)
assert (df_weather == df_weather_true).all().all()


def test_custom_dataframe_weather() -> None:
df_weather = CustomDataFrameWeather.new(**asdict(weather))
df_weather = CustomDataFrameWeather.new(
year=weather.year,
month=weather.month,
temp_avg=weather.temp_avg,
temp_max=weather.temp_max,
wind_avg=weather.wind_avg,
wind_max=weather.wind_max,
)

assert isinstance(df_weather, CustomDataFrame)
assert (df_weather == df_weather_true).all().all()


def test_series_weather() -> None:
ser_weather = SeriesWeather.new(**asdict(weather))
ser_weather = SeriesWeather.new(
year=weather.year,
month=weather.month,
temp_avg=weather.temp_avg,
temp_max=weather.temp_max,
wind_avg=weather.wind_avg,
wind_max=weather.wind_max,
)

assert isinstance(ser_weather, pd.Series)
assert (ser_weather == ser_weather_true).all()


def test_custom_series_weather() -> None:
ser_weather = CustomSeriesWeather.new(**asdict(weather))
ser_weather = CustomSeriesWeather.new(
year=weather.year,
month=weather.month,
temp_avg=weather.temp_avg,
temp_max=weather.temp_max,
wind_avg=weather.wind_avg,
wind_max=weather.wind_max,
)

assert isinstance(ser_weather, CustomSeries)
assert (ser_weather == ser_weather_true).all()
Loading

0 comments on commit dcf281d

Please sign in to comment.