From 04500caa454789cbde69ead1457d54711dea693a Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Tue, 17 Nov 2020 21:12:08 -0800 Subject: [PATCH] Added flag for unknown value error --- rdt/transformers/categorical.py | 11 ++++++++++- tests/transformers/test_boolean.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/rdt/transformers/categorical.py b/rdt/transformers/categorical.py index 40ad408f8..2b0a460e2 100644 --- a/rdt/transformers/categorical.py +++ b/rdt/transformers/categorical.py @@ -218,11 +218,20 @@ class OneHotEncodingTransformer(BaseTransformer): is found and 0s on the rest. Null values are considered just another category. + + Args: + error_on_unknown (bool): + + If a value that was not seen during the fit stage is passed to + transform, then an error will be raised if this is True. """ dummy_na = None dummies = None + def __init__(self, error_on_unknown=True): + self.error_on_unknown = error_on_unknown + def fit(self, data): """Fit the transformer to the data. @@ -248,7 +257,7 @@ def transform(self, data): dummies = pd.get_dummies(data, dummy_na=self.dummy_na) array = dummies.reindex(columns=self.dummies, fill_value=0).values.astype(int) for i, row in enumerate(array): - if np.all(row == 0): + if np.all(row == 0) and self.error_on_unknown: raise ValueError(f"The value {data[i]} was not seen during the fit stage.") return array diff --git a/tests/transformers/test_boolean.py b/tests/transformers/test_boolean.py index ef49d1f1a..201f3ff9b 100644 --- a/tests/transformers/test_boolean.py +++ b/tests/transformers/test_boolean.py @@ -180,5 +180,5 @@ def test_reverse_transform_not_null_values(self): # Asserts expected = np.array([True, False, True]) - assert type(result) == pd.Series + assert isinstance(result, pd.Series) np.testing.assert_equal(result.values, expected)