Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,11 @@ class OneHotEncodingTransformer(BaseTransformer):
"""

dummies = None
dummy_na = None
num_dummies = None
dummy_encoded = False
indexer = None
decoder = None
_dummy_na = None
_num_dummies = None
_dummy_encoded = False
_indexer = None
_uniques = None

def __init__(self, error_on_unknown=True):
self.error_on_unknown = error_on_unknown
Expand Down Expand Up @@ -297,19 +297,19 @@ def _prepare_data(data):
return data

def _transform(self, data):
if self.dummy_encoded:
coder = self.indexer
codes = pd.Categorical(data, categories=self.dummies).codes
if self._dummy_encoded:
coder = self._indexer
codes = pd.Categorical(data, categories=self._uniques).codes
else:
coder = self.dummies
coder = self._uniques
codes = data

rows = len(data)
dummies = np.broadcast_to(coder, (rows, self.num_dummies))
coded = np.broadcast_to(codes, (self.num_dummies, rows)).T
dummies = np.broadcast_to(coder, (rows, self._num_dummies))
coded = np.broadcast_to(codes, (self._num_dummies, rows)).T
array = (coded == dummies).astype(int)

if self.dummy_na:
if self._dummy_na:
null = np.zeros((rows, 1), dtype=int)
null[pd.isnull(data)] = 1
array = np.append(array, null, axis=1)
Expand All @@ -328,17 +328,17 @@ def fit(self, data):
data = self._prepare_data(data)

null = pd.isnull(data)
self.dummy_na = null.any()
self.dummies = list(pd.unique(data[~null]))
self.num_dummies = len(self.dummies)
self.indexer = list(range(self.num_dummies))
self.decoder = self.dummies.copy()
self._uniques = list(pd.unique(data[~null]))
self._dummy_na = null.any()
self._num_dummies = len(self._uniques)
self._indexer = list(range(self._num_dummies))
self.dummies = self._uniques.copy()

if not np.issubdtype(data.dtype, np.number):
self.dummy_encoded = True
self._dummy_encoded = True

if self.dummy_na:
self.decoder.append(np.nan)
if self._dummy_na:
self.dummies.append(np.nan)

def transform(self, data):
"""Replace each category with the OneHot vectors.
Expand Down Expand Up @@ -375,7 +375,7 @@ def reverse_transform(self, data):
data = data.reshape(-1, 1)

indices = np.argmax(data, axis=1)
return pd.Series(indices).map(self.decoder.__getitem__)
return pd.Series(indices).map(self.dummies.__getitem__)


class LabelEncodingTransformer(BaseTransformer):
Expand Down
125 changes: 82 additions & 43 deletions tests/unit/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,45 @@ def test__prepare_data_pandas_series(self):
expected = pd.Series(['a', 'b', 'c'])
np.testing.assert_array_equal(out, expected)

def test_fit_dummies_no_nans(self):
    """Check that ``fit`` leaves ``dummies`` free of nans when the data has no nulls.

    Setup:
        - A default ``OneHotEncodingTransformer`` instance.

    Input:
        - Series of mixed-type values containing no nulls.

    Side Effects:
        - ``dummies`` holds exactly the values that were fitted, with no
          ``np.nan`` entry.
    """
    # Setup
    values = ['a', 2, 'c']
    transformer = OneHotEncodingTransformer()

    # Run
    transformer.fit(pd.Series(values))

    # Assert
    np.testing.assert_array_equal(transformer.dummies, values)

def test_fit_dummies_nans(self):
    """Check that ``fit`` adds ``np.nan`` to ``dummies`` when the data has nulls.

    Setup:
        - A default ``OneHotEncodingTransformer`` instance.

    Input:
        - Series of mixed-type values that also contains a ``None``.

    Side Effects:
        - ``dummies`` holds the non-null values followed by ``np.nan``.
    """
    # Setup
    transformer = OneHotEncodingTransformer()
    series = pd.Series(['a', 2, 'c', None])

    # Run
    transformer.fit(series)

    # Assert
    expected = ['a', 2, 'c', np.nan]
    np.testing.assert_array_equal(transformer.dummies, expected)

def test_fit_no_nans(self):
"""Test the ``fit`` method without nans.

Expand All @@ -535,9 +574,9 @@ def test_fit_no_nans(self):

# Assert
np.testing.assert_array_equal(ohet.dummies, ['a', 'b', 'c'])
np.testing.assert_array_equal(ohet.decoder, ['a', 'b', 'c'])
assert ohet.dummy_encoded
assert not ohet.dummy_na
np.testing.assert_array_equal(ohet._uniques, ['a', 'b', 'c'])
assert ohet._dummy_encoded
assert not ohet._dummy_na

def test_fit_no_nans_numeric(self):
"""Test the ``fit`` method without nans.
Expand All @@ -559,9 +598,9 @@ def test_fit_no_nans_numeric(self):

# Assert
np.testing.assert_array_equal(ohet.dummies, [1, 2, 3])
np.testing.assert_array_equal(ohet.decoder, [1, 2, 3])
assert not ohet.dummy_encoded
assert not ohet.dummy_na
np.testing.assert_array_equal(ohet._uniques, [1, 2, 3])
assert not ohet._dummy_encoded
assert not ohet._dummy_na

def test_fit_nans(self):
"""Test the ``fit`` method with nans.
Expand All @@ -582,10 +621,10 @@ def test_fit_nans(self):
ohet.fit(data)

# Assert
np.testing.assert_array_equal(ohet.dummies, ['a', 'b'])
np.testing.assert_array_equal(ohet.decoder, ['a', 'b', np.nan])
assert ohet.dummy_encoded
assert ohet.dummy_na
np.testing.assert_array_equal(ohet.dummies, ['a', 'b', np.nan])
np.testing.assert_array_equal(ohet._uniques, ['a', 'b'])
assert ohet._dummy_encoded
assert ohet._dummy_na

def test_fit_nans_numeric(self):
"""Test the ``fit`` method with nans.
Expand All @@ -606,10 +645,10 @@ def test_fit_nans_numeric(self):
ohet.fit(data)

# Assert
np.testing.assert_array_equal(ohet.dummies, [1, 2])
np.testing.assert_array_equal(ohet.decoder, [1, 2, np.nan])
assert not ohet.dummy_encoded
assert ohet.dummy_na
np.testing.assert_array_equal(ohet.dummies, [1, 2, np.nan])
np.testing.assert_array_equal(ohet._uniques, [1, 2])
assert not ohet._dummy_encoded
assert ohet._dummy_na

def test_fit_single(self):
# Setup
Expand All @@ -636,8 +675,8 @@ def test__transform_no_nan(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series(['a', 'b', 'c'])
ohet.dummies = ['a', 'b', 'c']
ohet.num_dummies = 3
ohet._uniques = ['a', 'b', 'c']
ohet._num_dummies = 3

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -665,10 +704,10 @@ def test__transform_no_nan_categorical(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series(['a', 'b', 'c'])
ohet.dummies = ['a', 'b', 'c']
ohet.indexer = [0, 1, 2]
ohet.num_dummies = 3
ohet.dummy_encoded = True
ohet._uniques = ['a', 'b', 'c']
ohet._indexer = [0, 1, 2]
ohet._num_dummies = 3
ohet._dummy_encoded = True

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -696,9 +735,9 @@ def test__transform_nans(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series([np.nan, None, 'a', 'b'])
ohet.dummies = ['a', 'b']
ohet.dummy_na = True
ohet.num_dummies = 2
ohet._uniques = ['a', 'b']
ohet._dummy_na = True
ohet._num_dummies = 2

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -728,11 +767,11 @@ def test__transform_nans_categorical(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series([np.nan, None, 'a', 'b'])
ohet.dummies = ['a', 'b']
ohet.indexer = [0, 1]
ohet.dummy_na = True
ohet.num_dummies = 2
ohet.dummy_encoded = True
ohet._uniques = ['a', 'b']
ohet._indexer = [0, 1]
ohet._dummy_na = True
ohet._num_dummies = 2
ohet._dummy_encoded = True

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -761,8 +800,8 @@ def test__transform_single(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series(['a', 'a', 'a'])
ohet.dummies = ['a']
ohet.num_dummies = 1
ohet._uniques = ['a']
ohet._num_dummies = 1

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -791,10 +830,10 @@ def test__transform_single_categorical(self):
# Setup
ohet = OneHotEncodingTransformer()
data = pd.Series(['a', 'a', 'a'])
ohet.dummies = ['a']
ohet.indexer = [0]
ohet.num_dummies = 1
ohet.dummy_encoded = True
ohet._uniques = ['a']
ohet._indexer = [0]
ohet._num_dummies = 1
ohet._dummy_encoded = True

# Run
out = ohet._transform(data)
Expand Down Expand Up @@ -822,8 +861,8 @@ def test__transform_zeros(self):
# Setup
ohet = OneHotEncodingTransformer()
pd.Series(['a'])
ohet.dummies = ['a']
ohet.num_dummies = 1
ohet._uniques = ['a']
ohet._num_dummies = 1

# Run
out = ohet._transform(pd.Series(['b', 'b', 'b']))
Expand Down Expand Up @@ -852,9 +891,9 @@ def test__transform_zeros_categorical(self):
# Setup
ohet = OneHotEncodingTransformer()
pd.Series(['a'])
ohet.dummies = ['a']
ohet.indexer = [0]
ohet.num_dummies = 1
ohet._uniques = ['a']
ohet._indexer = [0]
ohet._num_dummies = 1
ohet.dummy_encoded = True

# Run
Expand Down Expand Up @@ -883,9 +922,9 @@ def test__transform_unknown_nan(self):
# Setup
ohet = OneHotEncodingTransformer()
pd.Series(['a'])
ohet.dummies = ['a']
ohet.dummy_na = True
ohet.num_dummies = 1
ohet._uniques = ['a']
ohet._dummy_na = True
ohet._num_dummies = 1

# Run
out = ohet._transform(pd.Series(['b', 'b', np.nan]))
Expand Down Expand Up @@ -1023,7 +1062,7 @@ def test_transform_numeric(self):
out = ohet.transform(data)

# Assert
assert not ohet.dummy_encoded
assert not ohet._dummy_encoded
np.testing.assert_array_equal(out, expected)

def test_reverse_transform_no_nans(self):
Expand Down