Skip to content

Commit

Permalink
Support 2d ndarrays for Boolean and Datetime transformers (#161)
Browse files Browse the repository at this point in the history
* Add 2d ndarray and upgrade tests

* Fix LabelEncodingTransformer and update tests.
  • Loading branch information
pvk-developer authored Mar 29, 2021
1 parent d693d4c commit 2de0b85
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 5 deletions.
3 changes: 3 additions & 0 deletions rdt/transformers/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def reverse_transform(self, data):
data = self.null_transformer.reverse_transform(data)

if isinstance(data, np.ndarray):
if data.ndim == 2:
data = data[:, 0]

data = pd.Series(data)

data[pd.notnull(data)] = np.round(data[pd.notnull(data)]).astype(bool)
Expand Down
3 changes: 3 additions & 0 deletions rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,7 @@ def reverse_transform(self, data):
Returns:
pandas.Series
"""
if isinstance(data, np.ndarray) and (data.ndim == 2):
data = data[:, 0]

return pd.Series(data).astype(int).map(self.values_to_categories)
3 changes: 3 additions & 0 deletions rdt/transformers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def reverse_transform(self, data):
if self.nan is not None:
data = self.null_transformer.reverse_transform(data)

if isinstance(data, np.ndarray) and (data.ndim == 2):
data = data[:, 0]

data[pd.notnull(data)] = np.round(data[pd.notnull(data)]).astype(np.int64)
if self.strip_constant:
data = data.astype(float) * self.divider
Expand Down
125 changes: 120 additions & 5 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rdt.transformers import OneHotEncodingTransformer


def get_input_data():
def get_input_data_with_nan():
data = pd.DataFrame({
'integer': [1, 2, 1, 3, 1],
'float': [0.1, 0.2, 0.1, np.nan, 0.1],
Expand All @@ -21,7 +21,40 @@ def get_input_data():
return data


def get_input_data_without_nan():
    """Build the NaN-free input fixture used by the integration tests.

    Returns:
        pandas.DataFrame:
            Four rows with the columns ``integer``, ``float``, ``categorical``,
            ``bool``, ``datetime`` and ``names``. The ``datetime`` column is
            parsed with ``pd.to_datetime`` and the ``bool`` column is stored
            with object dtype.
    """
    timestamps = ['2010-02-01', '2010-01-01', '2010-02-01', '2010-01-01']
    frame = pd.DataFrame({
        'integer': [1, 2, 1, 3],
        'float': [0.1, 0.2, 0.1, 0.1],
        'categorical': ['a', 'b', 'b', 'a'],
        'bool': [False, False, True, False],
        'datetime': timestamps,
        'names': ['Jon', 'Arya', 'Sansa', 'Jon'],
    })
    frame['datetime'] = pd.to_datetime(frame['datetime'])

    # The boolean transformer returns object dtype instead of bool, so the
    # fixture is stored the same way.
    frame['bool'] = frame['bool'].astype('O')

    return frame


def get_transformed_data():
    """Build the expected transformed frame for the four-row fixture.

    Returns:
        pandas.DataFrame:
            Four rows of numeric values, one column per input column
            (``integer``, ``float``, ``categorical``, ``bool``, ``datetime``
            and ``names``).
    """
    expected_columns = {
        'integer': [1, 2, 1, 3],
        'float': [0.1, 0.2, 0.1, 0.1],
        'categorical': [0.75, 0.25, 0.25, 0.75],
        'bool': [0.0, 0.0, 1.0, 0.0],
        'datetime': [
            1.2649824e+18,
            1.262304e+18,
            1.2649824e+18,
            1.262304e+18,
        ],
        'names': [0.25, 0.875, 0.625, 0.25],
    }
    return pd.DataFrame(expected_columns)


def get_transformed_nan_data():
return pd.DataFrame({
'integer': [1, 2, 1, 3, 1],
'float': [0.1, 0.2, 0.1, 0.125, 0.1],
Expand Down Expand Up @@ -68,7 +101,7 @@ def get_transformers():


def test_hypertransformer_with_transformers():
data = get_input_data()
data = get_input_data_without_nan()
transformers = get_transformers()

ht = HyperTransformer(transformers)
Expand All @@ -93,8 +126,34 @@ def test_hypertransformer_with_transformers():
assert name not in reversed_names


def test_hypertransformer_with_transformers_nan_data():
    """Round-trip the NaN fixture through a HyperTransformer with explicit transformers.

    The transformed output is compared against ``get_transformed_nan_data()``
    and the reverse-transformed frame (minus the ``names`` column) is compared
    against the original input.
    """
    data = get_input_data_with_nan()

    ht = HyperTransformer(get_transformers())
    ht.fit(data)
    transformed = ht.transform(data)

    # Sort columns on both sides so the comparison ignores column order.
    expected_values = get_transformed_nan_data().sort_index(axis=1).values
    np.testing.assert_allclose(transformed.sort_index(axis=1).values, expected_values)

    reversed_data = ht.reverse_transform(transformed)

    # ``names`` is removed from both frames and checked separately below.
    original_names = data.pop('names')
    reversed_names = reversed_data.pop('names')

    pd.testing.assert_frame_equal(
        data.sort_index(axis=1),
        reversed_data.sort_index(axis=1),
    )

    # NOTE(review): ``in`` on a pandas Series tests the index labels, not the
    # values, so this check may pass regardless of the reversed names --
    # confirm whether ``reversed_names.values`` was intended.
    assert all(name not in reversed_names for name in original_names)


def test_hypertransformer_without_transformers():
data = get_input_data()
data = get_input_data_without_nan()

ht = HyperTransformer()
ht.fit(data)
Expand All @@ -118,6 +177,31 @@ def test_hypertransformer_without_transformers():
assert name not in reversed_names


def test_hypertransformer_without_transformers_nan_data():
    """Round-trip the NaN fixture through a HyperTransformer built with no arguments.

    The transformed output is compared against ``get_transformed_nan_data()``
    and the reverse-transformed frame (minus the ``names`` column) is compared
    against the original input.
    """
    data = get_input_data_with_nan()

    ht = HyperTransformer()
    ht.fit(data)
    transformed = ht.transform(data)

    # Sort columns on both sides so the comparison ignores column order.
    expected_values = get_transformed_nan_data().sort_index(axis=1).values
    np.testing.assert_allclose(transformed.sort_index(axis=1).values, expected_values)

    reversed_data = ht.reverse_transform(transformed)

    # ``names`` is removed from both frames and checked separately below.
    original_names = data.pop('names')
    reversed_names = reversed_data.pop('names')

    pd.testing.assert_frame_equal(
        data.sort_index(axis=1),
        reversed_data.sort_index(axis=1),
    )

    # NOTE(review): ``in`` on a pandas Series tests the index labels, not the
    # values, so this check may pass regardless of the reversed names --
    # confirm what this assertion is meant to verify.
    assert all(name not in reversed_names for name in original_names)


def test_single_category():
ht = HyperTransformer(transformers={
'a': OneHotEncodingTransformer()
Expand Down Expand Up @@ -149,7 +233,21 @@ def test_dtype_category():

def test_empty_transformers():
    """If transformers is an empty dict, do nothing.

    Fits a ``HyperTransformer`` created with ``transformers={}`` on the
    NaN-free fixture and checks that both ``transform`` and
    ``reverse_transform`` return a frame equal to the input.
    """
    # Only the NaN-free fixture is loaded; the earlier, immediately
    # overwritten ``get_input_data()`` call was dead code.
    data = get_input_data_without_nan()

    ht = HyperTransformer(transformers={})
    ht.fit(data)

    transformed = ht.transform(data)
    reverse = ht.reverse_transform(transformed)

    pd.testing.assert_frame_equal(data, transformed)
    pd.testing.assert_frame_equal(data, reverse)


def test_empty_transformers_nan_data():
"""If transformers is an empty dict, do nothing."""
data = get_input_data_with_nan()

ht = HyperTransformer(transformers={})
ht.fit(data)
Expand All @@ -166,7 +264,24 @@ def test_subset_of_columns():
See https://github.com/sdv-dev/RDT/issues/152
"""
data = get_input_data()
data = get_input_data_without_nan()

ht = HyperTransformer()
ht.fit(data)

subset = data[[data.columns[0]]]
transformed = ht.transform(subset)
reverse = ht.reverse_transform(transformed)

pd.testing.assert_frame_equal(subset, reverse)


def test_subset_of_columns_nan_data():
"""HyperTransform should be able to transform a subset of the training columns.
See https://github.com/sdv-dev/RDT/issues/152
"""
data = get_input_data_with_nan()

ht = HyperTransformer()
ht.fit(data)
Expand Down
13 changes: 13 additions & 0 deletions tests/integration/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ def test_one_hot_numerical_nans():
pd.testing.assert_series_equal(reverse, data)


def test_label_numerical_2d_array():
    """Ensure LabelEncodingTransformer reverse-transforms a 2d ndarray with one column."""
    expected = pd.Series([1, 2, 3, 4])

    encoder = LabelEncodingTransformer()
    encoder.fit(expected)

    # Shape (4, 1): one value per row, passed straight to reverse_transform.
    codes = np.array([[0], [1], [2], [3]])
    recovered = encoder.reverse_transform(codes)

    pd.testing.assert_series_equal(recovered, expected)


def test_label_numerical_nans():
"""Ensure LabelEncodingTransformer works on numerical + nan only columns."""

Expand Down
17 changes: 17 additions & 0 deletions tests/transformers/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,20 @@ def test_reverse_transform_not_null_values(self):

assert isinstance(result, pd.Series)
np.testing.assert_equal(result.values, expected)

def test_reverse_transform_2d_ndarray(self):
    """Test reverse_transform with a 2d ndarray that has a single column."""
    # Setup
    column = np.array([[1.], [0.], [1.]])

    # A Mock stands in for the transformer instance; only ``nan`` is set.
    instance = Mock()
    instance.nan = None

    # Run
    result = BooleanTransformer.reverse_transform(instance, column)

    # Asserts
    assert isinstance(result, pd.Series)
    np.testing.assert_equal(result.values, np.array([True, False, True]))
11 changes: 11 additions & 0 deletions tests/transformers/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@ def test_reverse_transform_all_none(self):

expected = pd.to_datetime(['NaT'])
pd.testing.assert_series_equal(output.to_series(), expected.to_series())

def test_reverse_transform_2d_ndarray(self):
    """Test reverse_transform with a 2d ndarray that has a single column."""
    dates = ['2020-01-01', '2020-02-01', '2020-03-01']

    dtt = DatetimeTransformer(strip_constant=True)
    dtt.fit(pd.to_datetime(dates))

    # Shape (3, 1). 18262, 18293 and 18322 are the day offsets of the three
    # dates above from 1970-01-01.
    transformed = np.array([[18262.], [18293.], [18322.]])
    output = dtt.reverse_transform(transformed)

    expected = pd.to_datetime(dates)
    pd.testing.assert_series_equal(output.to_series(), expected.to_series())

0 comments on commit 2de0b85

Please sign in to comment.