Skip to content

Commit

Permalink
refactor: clean up reader (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
msto authored Apr 13, 2024
1 parent 7b3a578 commit c8c3e3a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
22 changes: 22 additions & 0 deletions dataclass_io/_lib/dataclass_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,25 @@ def fieldnames(dataclass_type: type[DataclassInstance]) -> list[str]:
raise TypeError(f"The provided type must be a dataclass: {dataclass_type.__name__}")

return [f.name for f in fields(dataclass_type)]


def row_to_dataclass(
row: dict[str, str],
dataclass_type: type[DataclassInstance],
) -> DataclassInstance:
"""
Convert a row of a CSV file into a dataclass instance.
Args:
row: A dictionary mapping each fieldname to its (string) value.
dataclass_type: The dataclass to which the row will be casted.
"""

coerced_values: dict[str, Any] = {}

# Coerce each value in the row to the type of the corresponding field
for field in fields(dataclass_type):
value = row[field.name]
coerced_values[field.name] = field.type(value)

return dataclass_type(**coerced_values)
46 changes: 19 additions & 27 deletions dataclass_io/reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from csv import DictReader
from dataclasses import fields
from pathlib import Path
from types import TracebackType
from typing import Any
Expand All @@ -9,11 +8,17 @@
from dataclass_io._lib.assertions import assert_file_is_readable
from dataclass_io._lib.dataclass_extensions import DataclassInstance
from dataclass_io._lib.dataclass_extensions import fieldnames
from dataclass_io._lib.dataclass_extensions import row_to_dataclass
from dataclass_io._lib.file import FileHeader
from dataclass_io._lib.file import ReadableFileHandle
from dataclass_io._lib.file import get_header


class DataclassReader:
_dataclass_type: type[DataclassInstance]
_fin: ReadableFileHandle
_reader: DictReader

def __init__(
self,
path: Path,
Expand All @@ -37,34 +42,31 @@ def __init__(
assert_file_is_readable(path)
assert_dataclass_is_valid(dataclass_type)

self.dataclass_type = dataclass_type
self.delimiter = delimiter
self.header_comment_char = header_comment_char

self._dataclass_type = dataclass_type
self._fin = path.open("r")

self._header: FileHeader = get_header(
header: FileHeader = get_header(
self._fin,
delimiter=delimiter,
header_comment_char=header_comment_char,
)

if self._header is None:
if header is None:
raise ValueError(f"Could not find a header in the provided file: {path}")

if self._header.fieldnames != fieldnames(dataclass_type):
if header.fieldnames != fieldnames(dataclass_type):
raise ValueError(
"The provided file does not have the same field names as the provided dataclass:\n"
f"\tDataclass: {dataclass_type.__name__}\n"
f"\tFile: {path}\n"
f"\tDataclass fields: {dataclass_type.__name__}\n"
f"\tFile: {path}\n"
f"\tDataclass fields: {', '.join(fieldnames(dataclass_type))}\n"
f"\tFile: {', '.join(header.fieldnames)}\n"
)

self._reader = DictReader(
self._fin,
fieldnames=self._header.fieldnames,
delimiter=self.delimiter,
fieldnames=header.fieldnames,
delimiter=delimiter,
)

def __enter__(self) -> "DataclassReader":
Expand All @@ -76,6 +78,10 @@ def __exit__(
exc_value: BaseException,
traceback: TracebackType,
) -> None:
self.close()

def close(self) -> None:
"""Close the reader."""
self._fin.close()

def __iter__(self) -> "DataclassReader":
Expand All @@ -84,18 +90,4 @@ def __iter__(self) -> "DataclassReader":
def __next__(self) -> DataclassInstance:
row = next(self._reader)

return self._row_to_dataclass(row)

def _row_to_dataclass(self, row: dict[str, str]) -> DataclassInstance:
"""
Convert a row of a CSV file into a dataclass instance.
"""

coerced_values: dict[str, Any] = {}

# Coerce each value in the row to the type of the corresponding field
for field in fields(self.dataclass_type):
value = row[field.name]
coerced_values[field.name] = field.type(value)

return self.dataclass_type(**coerced_values)
return row_to_dataclass(row, self._dataclass_type)

0 comments on commit c8c3e3a

Please sign in to comment.