From 9d6a06ffdd0769e8414a1ba2accfc6496a61c386 Mon Sep 17 00:00:00 2001 From: Matt Stone Date: Sat, 13 Apr 2024 10:32:30 -0400 Subject: [PATCH] feat: Add initial DataclassWriter implementation (#3) * chore: update gitignore * wip: add assertions for writable file * refactor: extract get_header() * feat: add fieldnames() * wip: add assert_file_is_appendable() * wip: initial DataclassWriter * wip: include/exclude fields * refactor: clean up attributes and docs * refactor: long-form mode * doc: update docstring --- .gitignore | 2 + dataclass_io/_lib/assertions.py | 97 +++++++++++ dataclass_io/_lib/dataclass_extensions.py | 14 +- dataclass_io/_lib/file.py | 67 ++++++++ dataclass_io/reader.py | 77 ++------- dataclass_io/writer.py | 187 ++++++++++++++++++++++ tests/_lib/test_assertions.py | 52 ++++++ tests/_lib/test_dataclass_extensions.py | 27 ++++ tests/test_writer.py | 118 ++++++++++++++ 9 files changed, 573 insertions(+), 68 deletions(-) create mode 100644 dataclass_io/_lib/file.py create mode 100644 dataclass_io/writer.py create mode 100644 tests/_lib/test_dataclass_extensions.py create mode 100644 tests/test_writer.py diff --git a/.gitignore b/.gitignore index 68bc17f..b745cbd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.vscode/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/dataclass_io/_lib/assertions.py b/dataclass_io/_lib/assertions.py index cdb67b7..1d7770d 100644 --- a/dataclass_io/_lib/assertions.py +++ b/dataclass_io/_lib/assertions.py @@ -1,9 +1,13 @@ from dataclasses import is_dataclass from os import R_OK +from os import W_OK from os import access +from os import stat from pathlib import Path from dataclass_io._lib.dataclass_extensions import DataclassInstance +from dataclass_io._lib.dataclass_extensions import fieldnames +from dataclass_io._lib.file import get_header def assert_file_is_readable(path: Path) -> None: @@ -26,6 +30,78 @@ def assert_file_is_readable(path: Path) -> None: raise PermissionError(f"The 
input file is not readable: {path}") +def assert_file_is_writable(path: Path, overwrite: bool = True) -> None: + """ + Check that the output file path is writable. + + Optionally, ensure the output file does not exist. + + Raises: + FileExistsError: If the provided file path exists when `overwrite` is set to `False`. + FileNotFoundError: If the provided file path's parent directory does not exist. + IsADirectoryError: If the provided file path is a directory. + PermissionError: If the provided file path is not writable. + """ + + if path.exists(): + if not overwrite: + raise FileExistsError( + f"The output file already exists: {path}\n" + "Specify `overwrite=True` to overwrite the existing file." + ) + + if not path.is_file(): + raise IsADirectoryError(f"The output file path is a directory: {path}") + + if not access(path, W_OK): + raise PermissionError(f"The output file is not writable: {path}") + + else: + if not path.parent.exists(): + raise FileNotFoundError( + f"The specified directory for the output file path does not exist: {path.parent}" + ) + + if not access(path.parent, W_OK): + raise PermissionError( + f"The specified directory for the output file path is not writable: {path.parent}" + ) + + +def assert_file_is_appendable(path: Path, dataclass_type: type[DataclassInstance]) -> None: + if not path.exists(): + raise FileNotFoundError(f"The specified output file does not exist: {path}") + + if not path.is_file(): + raise IsADirectoryError(f"The specified output file path is a directory: {path}") + + if not access(path, W_OK): + raise PermissionError(f"The specified output file is not writable: {path}") + + if stat(path).st_size == 0: + raise ValueError(f"The specified output file is empty: {path}") + + if not access(path, R_OK): + raise PermissionError( + f"The specified output file is not readable: {path}\n" + "The output file must be readable to append to it. 
" + "The header of the existing output file is checked for consistency with the provided " + "dataclass before appending to it." + ) + + # TODO: pass delimiter and header_comment_char to get_header + with path.open("r") as f: + header = get_header(f) + if header is None: + raise ValueError(f"Could not find a header in the specified output file: {path}") + + if header.fieldnames != fieldnames(dataclass_type): + raise ValueError( + "The specified output file does not have the same field names as the provided " + f"dataclass {path}" + ) + + def assert_dataclass_is_valid(dataclass_type: type[DataclassInstance]) -> None: """ Check that the input type is a parseable dataclass. @@ -36,3 +112,24 @@ def assert_dataclass_is_valid(dataclass_type: type[DataclassInstance]) -> None: if not is_dataclass(dataclass_type): raise TypeError(f"The provided type must be a dataclass: {dataclass_type.__name__}") + + +def assert_fieldnames_are_dataclass_attributes( + specified_fieldnames: list[str], + dataclass_type: type[DataclassInstance], +) -> None: + """ + Check that all of the specified fields are attributes on the given dataclass. + + Raises: + ValueError: if any of the specified fieldnames are not an attribute on the given dataclass. 
+ """ + + invalid_fieldnames = [f for f in specified_fieldnames if f not in fieldnames(dataclass_type)] + + if len(invalid_fieldnames) > 0: + raise ValueError( + "One or more of the specified fields are not attributes on the dataclass " + + f"{dataclass_type.__name__}: " + + ", ".join(invalid_fieldnames) + ) diff --git a/dataclass_io/_lib/dataclass_extensions.py b/dataclass_io/_lib/dataclass_extensions.py index 71d3c70..5b4c4ee 100644 --- a/dataclass_io/_lib/dataclass_extensions.py +++ b/dataclass_io/_lib/dataclass_extensions.py @@ -1,4 +1,5 @@ - +from dataclasses import fields +from dataclasses import is_dataclass from typing import Any from typing import ClassVar from typing import Protocol @@ -16,3 +17,14 @@ class DataclassInstance(Protocol): """ __dataclass_fields__: ClassVar[dict[str, Any]] + + +def fieldnames(dataclass_type: type[DataclassInstance]) -> list[str]: + """ + Return the fieldnames of the specified dataclass. + """ + + if not is_dataclass(dataclass_type): + raise TypeError(f"The provided type must be a dataclass: {dataclass_type.__name__}") + + return [f.name for f in fields(dataclass_type)] diff --git a/dataclass_io/_lib/file.py b/dataclass_io/_lib/file.py new file mode 100644 index 0000000..61b4f5a --- /dev/null +++ b/dataclass_io/_lib/file.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from io import TextIOWrapper +from typing import IO +from typing import Optional +from typing import TextIO +from typing import TypeAlias + +ReadableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO + + +@dataclass(frozen=True, kw_only=True) +class FileHeader: + """ + Header of a file. + + A file's header contains an optional preface, consisting of lines prefixed by a comment + character and/or empty lines, and a required row of fieldnames before the data rows begin. + + Attributes: + preface: A list of any lines preceding the fieldnames. + fieldnames: The field names specified in the final line of the header. 
+ """ + + preface: list[str] + fieldnames: list[str] + + +def get_header( + reader: ReadableFileHandle, + delimiter: str = "\t", + header_comment_char: str = "#", +) -> Optional[FileHeader]: + """ + Read the header from an open file. + + The first row after any commented or empty lines will be used as the fieldnames. + + Lines preceding the fieldnames will be returned in the `preface`. + + NB: This function returns `Optional` instead of raising an error because the name of the + source file is not in scope, making it difficult to provide a helpful error message. It is + the responsibility of the caller to raise an error if the file is empty. + + See original proof-of-concept here: https://github.com/fulcrumgenomics/fgpyo/pull/103 + + Args: + reader: An open, readable file handle. + header_comment_char: The character which indicates the start of a comment line. + + Returns: + A `FileHeader` containing the field names and any preceding lines. + None if the file was empty or contained only comments or empty lines. 
+ """ + + preface: list[str] = [] + + for line in reader: + if line.startswith(header_comment_char) or line.strip() == "": + preface.append(line.strip()) + else: + break + else: + return None + + fieldnames = line.strip().split(delimiter) + + return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/reader.py b/dataclass_io/reader.py index 957adaf..d3907f9 100644 --- a/dataclass_io/reader.py +++ b/dataclass_io/reader.py @@ -1,38 +1,16 @@ from csv import DictReader -from dataclasses import dataclass from dataclasses import fields -from io import TextIOWrapper from pathlib import Path from types import TracebackType -from typing import IO from typing import Any -from typing import Optional -from typing import TextIO from typing import Type -from typing import TypeAlias from dataclass_io._lib.assertions import assert_dataclass_is_valid from dataclass_io._lib.assertions import assert_file_is_readable from dataclass_io._lib.dataclass_extensions import DataclassInstance - -ReadableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO - - -@dataclass(frozen=True, kw_only=True) -class FileHeader: - """ - Header of a file. - - A file's header contains an optional preface, consisting of lines prefixed by a comment - character and/or empty lines, and a required row of fieldnames before the data rows begin. - - Attributes: - preface: A list of any lines preceding the fieldnames. - fieldnames: The field names specified in the final line of the header. 
- """ - - preface: list[str] - fieldnames: list[str] +from dataclass_io._lib.dataclass_extensions import fieldnames +from dataclass_io._lib.file import FileHeader +from dataclass_io._lib.file import get_header class DataclassReader: @@ -65,11 +43,16 @@ def __init__( self._fin = path.open("r") - self._header = self._get_header(self._fin) + self._header: FileHeader = get_header( + self._fin, + delimiter=delimiter, + header_comment_char=header_comment_char, + ) + if self._header is None: raise ValueError(f"Could not find a header in the provided file: {path}") - if self._header.fieldnames != [f.name for f in fields(dataclass_type)]: + if self._header.fieldnames != fieldnames(dataclass_type): raise ValueError( "The provided file does not have the same field names as the provided dataclass:\n" f"\tDataclass: {dataclass_type.__name__}\n" @@ -116,43 +99,3 @@ def _row_to_dataclass(self, row: dict[str, str]) -> DataclassInstance: coerced_values[field.name] = field.type(value) return self.dataclass_type(**coerced_values) - - def _get_header( - self, - reader: ReadableFileHandle, - ) -> Optional[FileHeader]: - """ - Read the header from an open file. - - The first row after any commented or empty lines will be used as the fieldnames. - - Lines preceding the fieldnames will be returned in the `preface.` - - NB: This function returns `Optional` instead of raising an error because the name of the - source file is not in scope, making it difficult to provide a helpful error message. It is - the responsibility of the caller to raise an error if the file is empty. - - See original proof-of-concept here: https://github.com/fulcrumgenomics/fgpyo/pull/103 - - Args: - reader: An open, readable file handle. - comment_char: The character which indicates the start of a comment line. - - Returns: - A `FileHeader` containing the field names and any preceding lines. - None if the file was empty or contained only comments or empty lines. 
- """ - - preface: list[str] = [] - - for line in reader: - if line.startswith(self.header_comment_char) or line.strip() == "": - preface.append(line.strip()) - else: - break - else: - return None - - fieldnames = line.strip().split(self.delimiter) - - return FileHeader(preface=preface, fieldnames=fieldnames) diff --git a/dataclass_io/writer.py b/dataclass_io/writer.py new file mode 100644 index 0000000..6f3dff2 --- /dev/null +++ b/dataclass_io/writer.py @@ -0,0 +1,187 @@ +from csv import DictWriter +from dataclasses import asdict +from enum import Enum +from enum import unique +from io import TextIOWrapper +from pathlib import Path +from types import TracebackType +from typing import IO +from typing import Any +from typing import Iterable +from typing import TextIO +from typing import Type +from typing import TypeAlias + +from dataclass_io._lib.assertions import assert_dataclass_is_valid +from dataclass_io._lib.assertions import assert_fieldnames_are_dataclass_attributes +from dataclass_io._lib.assertions import assert_file_is_appendable +from dataclass_io._lib.assertions import assert_file_is_writable +from dataclass_io._lib.dataclass_extensions import DataclassInstance +from dataclass_io._lib.dataclass_extensions import fieldnames + +WritableFileHandle: TypeAlias = TextIOWrapper | IO | TextIO + + +@unique +class WriteMode(Enum): + """ + The mode in which to open the file. + + Attributes: + value: The mode. + abbreviation: The short version of the mode (used with Python's `open()`). + """ + + value: str + abbreviation: str + + def __new__(cls, value: str, abbreviation: str) -> "WriteMode": + enum = object.__new__(cls) + enum._value_ = value + + return enum + + # NB: Specifying the additional fields in the `__init__` method instead of `__new__` is + # necessary in order to construct `WriteMode` from only the value (e.g. `WriteMode("append")`). + # Otherwise, `mypy` complains about a missing positional argument. 
+ # https://stackoverflow.com/a/54732120 + def __init__(self, _: str, abbreviation: str = None): + self.abbreviation = abbreviation + + WRITE = "write", "w" + """Write to a new file.""" + + APPEND = "append", "a" + """Append to an existing file.""" + + +class DataclassWriter: + def __init__( + self, + path: Path, + dataclass_type: type[DataclassInstance], + mode: str = "write", + delimiter: str = "\t", + overwrite: bool = True, + include_fields: list[str] | None = None, + exclude_fields: list[str] | None = None, + **kwds: Any, + ) -> None: + """ + Args: + path: Path to the file to write. + dataclass_type: Dataclass type. + mode: Either `"write"` or `"append"`. + If `"write"`, the specified file `path` must not already exist unless + `overwrite=True` is specified. + If `"append"`, the specified file `path` must already exist and contain a header row + matching the specified dataclass and any specified `include_fields` or + `exclude_fields`. + delimiter: The output file delimiter. + overwrite: If `True`, and `mode="write"`, the file specified at `path` will be + overwritten if it exists. + include_fields: If specified, only the listed fieldnames will be included when writing + records to file. Fields will be written in the order provided. + May not be used together with `exclude_fields`. + exclude_fields: If specified, any listed fieldnames will be excluded when writing + records to file. + May not be used together with `include_fields`. + + Raises: + FileNotFoundError: If the parent directory of the output file does not exist. + IsADirectoryError: If the output file path is a directory. + PermissionError: If the output file is not writable. + TypeError: If the provided type is not a dataclass. 
+ """ + + try: + write_mode = WriteMode(mode) + except ValueError: + raise ValueError(f"`mode` must be either 'write' or 'append': {mode}") from None + + assert_dataclass_is_valid(dataclass_type) + + if include_fields is not None and exclude_fields is not None: + raise ValueError( + "Only one of `include_fields` and `exclude_fields` may be specified, not both." + ) + elif exclude_fields is not None: + assert_fieldnames_are_dataclass_attributes(exclude_fields, dataclass_type) + self._fieldnames = [f for f in fieldnames(dataclass_type) if f not in exclude_fields] + elif include_fields is not None: + assert_fieldnames_are_dataclass_attributes(include_fields, dataclass_type) + self._fieldnames = include_fields + else: + self._fieldnames = fieldnames(dataclass_type) + + if write_mode is WriteMode.WRITE: + assert_file_is_writable(path, overwrite=overwrite) + else: + assert_file_is_appendable(path, dataclass_type=dataclass_type) + raise NotImplementedError + + self._dataclass_type = dataclass_type + self._fout = path.open(write_mode.abbreviation) + + self._writer = DictWriter( + f=self._fout, + fieldnames=self._fieldnames, + delimiter=delimiter, + ) + + # TODO: permit writing comment/preface rows before header + # If we aren't appending, write the header before any rows + if write_mode is WriteMode.WRITE: + self._writer.writeheader() + + def __enter__(self) -> "DataclassWriter": + return self + + def __exit__( + self, + exc_type: Type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + self.close() + + def close(self) -> None: + self._fout.close() + + def write(self, dataclass_instance: DataclassInstance) -> None: + """ + Write a single dataclass instance to file. + + The dataclass is converted to a dictionary and then written using the underlying + `csv.DictWriter`. 
If the `DataclassWriter` was created using the `include_fields` or + `exclude_fields` arguments, the attributes of the dataclass are subset and/or reordered + accordingly before writing. + + Args: + dataclass_instance: An instance of the specified dataclass. + """ + + # TODO: consider permitting other dataclass types *if* they contain the required attributes + if not isinstance(dataclass_instance, self._dataclass_type): + raise ValueError(f"Must provide instances of {self._dataclass_type.__name__}") + + # Filter and/or re-order output fields if necessary + row = asdict(dataclass_instance) + row = {fieldname: row[fieldname] for fieldname in self._fieldnames} + + self._writer.writerow(row) + + def writeall(self, dataclass_instances: Iterable[DataclassInstance]) -> None: + """ + Write multiple dataclass instances to file. + + Each dataclass is converted to a dictionary and then written using the underlying + `csv.DictWriter`. If the `DataclassWriter` was created using the `include_fields` or + `exclude_fields` arguments, the attributes of each dataclass are subset and/or reordered + accordingly before writing. + + Args: + dataclass_instances: A sequence of instances of the specified dataclass. 
+ """ + for dataclass_instance in dataclass_instances: + self.write(dataclass_instance) diff --git a/tests/_lib/test_assertions.py b/tests/_lib/test_assertions.py index 339ee2c..f0c9d9e 100644 --- a/tests/_lib/test_assertions.py +++ b/tests/_lib/test_assertions.py @@ -5,6 +5,7 @@ from dataclass_io._lib.assertions import assert_dataclass_is_valid from dataclass_io._lib.assertions import assert_file_is_readable +from dataclass_io._lib.assertions import assert_file_is_writable def test_assert_dataclass_is_valid() -> None: @@ -80,3 +81,54 @@ def test_assert_file_is_readable_raises_if_file_is_unreadable(tmp_path: Path) -> with pytest.raises(PermissionError, match="The input file is not readable: "): assert_file_is_readable(fpath) + + +def test_assert_file_is_writable(tmp_path: Path) -> None: + """ + Test that we can validate if a file path is valid for writing. + """ + + # Non-existing files are writable + fpath = tmp_path / "test.txt" + try: + assert_file_is_writable(fpath, overwrite=False) + except Exception: + raise AssertionError("Failed to validate a valid file") from None + + # Existing files are writable if `overwrite=True` + fpath.touch() + try: + assert_file_is_writable(fpath, overwrite=True) + except Exception: + raise AssertionError("Failed to validate a valid file") from None + + +def test_assert_file_is_writable_raises_if_file_exists(tmp_path: Path) -> None: + """ + Test that we raise an error if the output file already exists when `overwrite=False`. + """ + + fpath = tmp_path / "test.txt" + fpath.touch() + + with pytest.raises(FileExistsError, match="The output file already exists: "): + assert_file_is_writable(tmp_path, overwrite=False) + + +def test_assert_file_is_writable_raises_if_file_is_directory(tmp_path: Path) -> None: + """ + Test that we raise an error if the output file path exists and is a directory. 
+ """ + with pytest.raises(IsADirectoryError, match="The output file path is a directory: "): + assert_file_is_writable(tmp_path, overwrite=True) + + +def test_assert_file_is_writable_raises_if_parent_directory_does_not_exist(tmp_path: Path) -> None: + """ + Test that we raise an error if the parent directory of the output file path does not exist. + """ + + fpath = tmp_path / "abc" / "test.txt" + + with pytest.raises(FileNotFoundError, match="The specified directory for the output"): + assert_file_is_writable(fpath, overwrite=True) diff --git a/tests/_lib/test_dataclass_extensions.py b/tests/_lib/test_dataclass_extensions.py new file mode 100644 index 0000000..5d2a377 --- /dev/null +++ b/tests/_lib/test_dataclass_extensions.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +import pytest + +from dataclass_io._lib.dataclass_extensions import fieldnames + + +def test_fieldnames() -> None: + """Test we can get the fieldnames of a dataclass.""" + + @dataclass + class FakeDataclass: + foo: str + bar: int + + assert fieldnames(FakeDataclass) == ["foo", "bar"] + + +def test_fieldnames_raises_if_not_a_dataclass() -> None: + """Test we raise if we get a non-dataclass.""" + + class BadDataclass: + foo: str + bar: int + + with pytest.raises(TypeError, match="The provided type must be a dataclass: BadDataclass"): + fieldnames(BadDataclass) # type: ignore[arg-type] diff --git a/tests/test_writer.py b/tests/test_writer.py new file mode 100644 index 0000000..e75bf4b --- /dev/null +++ b/tests/test_writer.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from dataclass_io.writer import DataclassWriter + + +@dataclass +class FakeDataclass: + foo: str + bar: int + + +def test_writer(tmp_path: Path) -> None: + fpath = tmp_path / "test.txt" + + with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + writer.write(FakeDataclass(foo="abc", bar=1)) + writer.write(FakeDataclass(foo="def", bar=2)) + 
+ with open(fpath, "r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + assert next(f) == "def\t2\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_writeall(tmp_path: Path) -> None: + fpath = tmp_path / "test.txt" + + data = [ + FakeDataclass(foo="abc", bar=1), + FakeDataclass(foo="def", bar=2), + ] + with DataclassWriter(path=fpath, mode="write", dataclass_type=FakeDataclass) as writer: + writer.writeall(data) + + with open(fpath, "r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + assert next(f) == "def\t2\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_include_fields(tmp_path: Path) -> None: + """Test that we can include only a subset of fields.""" + fpath = tmp_path / "test.txt" + + data = [ + FakeDataclass(foo="abc", bar=1), + FakeDataclass(foo="def", bar=2), + ] + with DataclassWriter( + path=fpath, + mode="write", + dataclass_type=FakeDataclass, + include_fields=["foo"], + ) as writer: + writer.writeall(data) + + with open(fpath, "r") as f: + assert next(f) == "foo\n" + assert next(f) == "abc\n" + assert next(f) == "def\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_include_fields_reorders(tmp_path: Path) -> None: + """Test that we can reorder the output fields.""" + fpath = tmp_path / "test.txt" + + data = [ + FakeDataclass(foo="abc", bar=1), + FakeDataclass(foo="def", bar=2), + ] + with DataclassWriter( + path=fpath, + mode="write", + dataclass_type=FakeDataclass, + include_fields=["bar", "foo"], + ) as writer: + writer.writeall(data) + + with open(fpath, "r") as f: + assert next(f) == "bar\tfoo\n" + assert next(f) == "1\tabc\n" + assert next(f) == "2\tdef\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_exclude_fields(tmp_path: Path) -> None: + """Test that we can exclude fields from being written.""" + + fpath = tmp_path / "test.txt" + + data = [ + FakeDataclass(foo="abc", bar=1), + FakeDataclass(foo="def", bar=2), + 
] + with DataclassWriter( + path=fpath, + mode="write", + dataclass_type=FakeDataclass, + exclude_fields=["bar"], + ) as writer: + writer.writeall(data) + + with open(fpath, "r") as f: + assert next(f) == "foo\n" + assert next(f) == "abc\n" + assert next(f) == "def\n" + with pytest.raises(StopIteration): + next(f)