Skip to content

Commit

Permalink
#82 Update runtime features for columns
Browse files Browse the repository at this point in the history
  • Loading branch information
astropenguin committed Aug 31, 2022
1 parent d7bc128 commit 191d244
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 65 deletions.
38 changes: 15 additions & 23 deletions pandas_dataclasses/core/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ def asdataframe(obj: Any, *, factory: Any = None) -> Any:
attrs = get_attrs(spec)
data = get_data(spec)
index = get_index(spec)
columns = get_columns(spec)
names = get_columns(spec)

if factory is None:
factory = spec.factory or pd.DataFrame

if not issubclass(factory, pd.DataFrame):
raise TypeError("Factory must be a subclass of DataFrame.")

dataframe = factory(data, index, columns)
dataframe = factory(data, index)
dataframe.columns.names = names
dataframe.attrs.update(attrs)
return dataframe

Expand Down Expand Up @@ -109,31 +110,22 @@ def get_attrs(spec: Spec) -> "dict[Hashable, Any]":
attrs: "dict[Hashable, Any]" = {}

for field in spec.fields.of_attr:
attrs[field.hashable_name] = field.default
attrs[field.name] = field.default

return attrs


def get_columns(spec: Spec) -> Optional[pd.Index]:
"""Derive columns from a specification."""
if not spec.fields.of_data:
return

names_ = [field.name for field in spec.fields.of_data]

if all(isinstance(name, Hashable) for name in names_):
return

if not all(isinstance(name, dict) for name in names_):
raise ValueError("All names must be dictionaries.")
def get_columns(spec: Spec) -> "list[Hashable]":
"""Derive column names from a specification."""
objs: "dict[Hashable, Any]" = {}

names = [tuple(name.keys()) for name in names_] # type: ignore
indexes = [tuple(name.values()) for name in names_] # type: ignore
for field in spec.fields.of_column:
objs[field.name] = field.default

if not len(set(names)) == 1:
raise ValueError("All name keys must be same.")

return pd.MultiIndex.from_tuples(indexes, names=names[0])
if not objs:
return [None]
else:
return list(objs.keys())


def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]:
Expand All @@ -145,7 +137,7 @@ def get_data(spec: Spec) -> Optional["dict[Hashable, Any]"]:
data: List[Any] = []

for field in spec.fields.of_data:
names.append(field.hashable_name)
names.append(field.name)
data.append(ensure(field.default, field.dtype))

return dict(zip(names, data))
Expand All @@ -160,7 +152,7 @@ def get_index(spec: Spec) -> Optional[pd.Index]:
indexes: List[Any] = []

for field in spec.fields.of_index:
names.append(field.hashable_name)
names.append(field.name)
indexes.append(ensure(field.default, field.dtype))

indexes = np.broadcast_arrays(*indexes)
Expand Down
36 changes: 6 additions & 30 deletions pandas_dataclasses/core/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,7 @@


# submodules
from .typing import (
P,
AnyName,
AnyPandas,
DataClass,
Role,
get_dtype,
get_name,
get_role,
)
from .typing import P, AnyPandas, DataClass, Role, get_dtype, get_name, get_role


# runtime classes
Expand All @@ -33,7 +24,7 @@ class Field:
id: str
"""Identifier of the field."""

name: AnyName
name: Hashable
"""Name of the field."""

role: Literal["attr", "column", "data", "index"]
Expand All @@ -48,14 +39,6 @@ class Field:
default: Any
"""Default value of the field data."""

@property
def hashable_name(self) -> Hashable:
"""Hashable name of the field."""
if isinstance(self.name, dict):
return tuple(self.name.values())
else:
return self.name

def update(self, obj: DataClass[P]) -> "Field":
"""Update the specification by a dataclass object."""
return replace(
Expand Down Expand Up @@ -164,19 +147,12 @@ def eval_types(dataclass: Type[DataClass[P]]) -> Type[DataClass[P]]:
return dataclass


def format_name(name: AnyName, obj: DataClass[P]) -> AnyName:
def format_name(name: Hashable, obj: DataClass[P]) -> Hashable:
"""Format a name by a dataclass object."""
if isinstance(name, tuple):
return type(name)(format_name(elem, obj) for elem in name)

if isinstance(name, str):
return name.format(obj)

if isinstance(name, dict):
formatted: "dict[Hashable, Hashable]" = {}

for key, val in name.items():
key = key.format(obj) if isinstance(key, str) else key
val = val.format(obj) if isinstance(val, str) else val
formatted[key] = val

return formatted

return name
18 changes: 6 additions & 12 deletions pandas_dataclasses/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@


# type hints (private)
AnyName: TypeAlias = "Hashable | dict[Hashable, Hashable]"
AnyPandas: TypeAlias = "pd.DataFrame | pd.Series"
P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -151,24 +150,19 @@ def get_dtype(tp: Any) -> Optional[str]:
return pandas_dtype(dtype).name


def get_name(tp: Any, default: AnyName = None) -> AnyName:
def get_name(tp: Any, default: Hashable = None) -> Hashable:
"""Extract a name if found or return given default."""
try:
name = get_annotations(tp)[1]
except (IndexError, TypeError):
return default

if isinstance(name, Hashable):
return name

if (
isinstance(name, dict)
and all(isinstance(key, Hashable) for key in name.keys())
and all(isinstance(val, Hashable) for val in name.values())
):
return dict(name)
try:
hash(name)
except TypeError:
raise ValueError("Could not find any valid name.")

raise ValueError("Could not find any valid name.")
return name


def get_role(tp: Any, default: Role = Role.OTHER) -> Role:
Expand Down

0 comments on commit 191d244

Please sign in to comment.