Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions dataprofiler/labelers/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def set_params(self, **kwargs: Any) -> None:
self._parameters[param] = kwargs[param]

@abc.abstractmethod
def process(self, *args: Any) -> Any:
def process(self, *args: Any, **kwargs: Any) -> Any:
"""Process data."""
raise NotImplementedError()

Expand Down Expand Up @@ -169,13 +169,15 @@ def __init__(self, **parameters: Any) -> None:
super().__init__(**parameters)

@abc.abstractmethod
def process( # type: ignore
def process(
self,
data: np.ndarray,
labels: np.ndarray | None = None,
label_mapping: dict[str, int] | None = None,
batch_size: int = 32,
) -> Generator[tuple[np.ndarray, np.ndarray] | np.ndarray, None, None]:
) -> Generator[tuple[np.ndarray, np.ndarray] | np.ndarray, None, None] | tuple[
np.ndarray, np.ndarray
] | np.ndarray:
"""Preprocess data."""
raise NotImplementedError()

Expand All @@ -191,7 +193,7 @@ def __init__(self, **parameters: Any) -> None:
super().__init__(**parameters)

@abc.abstractmethod
def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down Expand Up @@ -240,7 +242,7 @@ def help(cls) -> None:
)
print(help_str)

def process( # type: ignore
def process(
self,
data: np.ndarray,
labels: np.ndarray | None = None,
Expand Down Expand Up @@ -668,7 +670,7 @@ def gen_none() -> Generator[None, None, None]:
if batch_data["samples"]:
yield batch_data

def process( # type: ignore
def process(
self,
data: np.ndarray,
labels: np.ndarray | None = None,
Expand Down Expand Up @@ -836,7 +838,7 @@ def _validate_parameters(self, parameters: dict) -> None:
if errors:
raise ValueError("\n".join(errors))

def process( # type: ignore
def process(
self,
data: np.ndarray,
labels: np.ndarray | None = None,
Expand Down Expand Up @@ -1269,7 +1271,7 @@ def match_sentence_lengths(

return results

def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down Expand Up @@ -1439,7 +1441,7 @@ def convert_to_unstructured_format(

return text, entities

def process( # type: ignore
def process(
self,
data: np.ndarray,
labels: np.ndarray | None = None,
Expand Down Expand Up @@ -1800,7 +1802,7 @@ def convert_to_structured_analysis(

return results

def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down Expand Up @@ -2022,7 +2024,7 @@ def split_prediction(results: dict) -> None:
pred, axis=1, ord=1, keepdims=True
)

def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down Expand Up @@ -2160,7 +2162,7 @@ def _save_processor(self, dirpath: str) -> None:
) as fp:
json.dump(params, fp)

def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down Expand Up @@ -2253,7 +2255,7 @@ def help(cls) -> None:
)
print(help_str)

def process( # type: ignore
def process(
self,
data: np.ndarray,
results: dict,
Expand Down