Skip to content

Commit

Permalink
Added flag for unknown value error
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 18, 2020
1 parent 6431706 commit 04500ca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 10 additions & 1 deletion rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,20 @@ class OneHotEncodingTransformer(BaseTransformer):
is found and 0s on the rest.
Null values are considered just another category.
Args:
error_on_unknown (bool):
If True, raise a ``ValueError`` when ``transform`` receives a value that
was not seen during the fit stage. If False, such a value is encoded as
a row of all 0s instead. Defaults to True.
"""

dummy_na = None
dummies = None

def __init__(self, error_on_unknown=True):
self.error_on_unknown = error_on_unknown

def fit(self, data):
"""Fit the transformer to the data.
Expand All @@ -248,7 +257,7 @@ def transform(self, data):
dummies = pd.get_dummies(data, dummy_na=self.dummy_na)
array = dummies.reindex(columns=self.dummies, fill_value=0).values.astype(int)
for i, row in enumerate(array):
if np.all(row == 0):
if np.all(row == 0) and self.error_on_unknown:
raise ValueError(f"The value {data[i]} was not seen during the fit stage.")

return array
Expand Down
2 changes: 1 addition & 1 deletion tests/transformers/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,5 @@ def test_reverse_transform_not_null_values(self):
# Asserts
expected = np.array([True, False, True])

assert type(result) == pd.Series
assert isinstance(result, pd.Series)
np.testing.assert_equal(result.values, expected)

0 comments on commit 04500ca

Please sign in to comment.