Skip to content

Commit

Permalink
simplify type checking
Browse files Browse the repository at this point in the history
loop through attributes rather than sending indivudally
to helper function.
also, allow whole number floats instead for ints
  • Loading branch information
DaniBodor committed Nov 22, 2023
1 parent af172d7 commit 62bc8d4
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions eitprocessing/roi_selection/gridselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from dataclasses import field
from typing import Literal
from typing import get_type_hints
import numpy as np
from numpy.typing import NDArray
from . import ROISelection
Expand Down Expand Up @@ -97,29 +98,24 @@ class GridSelection(ROISelection):
ignore_nan_rows: bool = True
ignore_nan_columns: bool = True

def _check_attribute_type(self, name, type_):
"""Checks whether an attribute is an instance of the given type."""
attr = getattr(self, name)
if not isinstance(attr, type_):
message = f"Invalid type for `{name}`."
message += f"Should be {type_}, not {type(attr)}."
raise TypeError(message)

def __post_init__(self):
self._check_attribute_type("v_split", int)
self._check_attribute_type("h_split", int)
try:
if self.v_split == int(self.v_split):
self.v_split = int(self.v_split)
if self.h_split == int(self.h_split):
self.h_split = int(self.h_split)
finally:
for attr, type_ in get_type_hints(self).items():
if not isinstance(getattr(self, attr), type_):
raise TypeError(
f"Invalid type for `{attr}`. Should be {type_}, not {type(attr)}."
)

if self.v_split < 1:
raise InvalidVerticalDivision("`v_split` can't be smaller than 1.")

if self.h_split < 1:
raise InvalidHorizontalDivision("`h_split` can't be smaller than 1.")

self._check_attribute_type("split_columns", bool)
self._check_attribute_type("split_rows", bool)
self._check_attribute_type("ignore_nan_columns", bool)
self._check_attribute_type("ignore_nan_rows", bool)

def find_grid(self, data: NDArray) -> list[NDArray]:
"""
Create 2D arrays to split a grid into regions.
Expand Down

0 comments on commit 62bc8d4

Please sign in to comment.