Skip to content

Commit

Permalink
Fix calling nulltransformer when model missing values is set to True (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Sep 9, 2022
1 parent 857959a commit 3e91980
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion rdt/transformers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _reverse_transform_helper(self, data):
if not isinstance(data, np.ndarray):
data = data.to_numpy()

if self.missing_value_replacement is not None:
if self.model_missing_values or self.missing_value_replacement is not None:
data = self.null_transformer.reverse_transform(data)

data = np.round(data.astype(np.float64))
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/transformers/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,36 @@ def test__reverse_transform_helper_nulls(self):
datetimes = transformer.null_transformer.reverse_transform.mock_calls[0][1][0]
np.testing.assert_array_equal(data.to_numpy(), datetimes)

def test__reverse_transform_helper_model_missing_values_true(self):
"""Test the ``_reverse_transform_helper`` with null values.
Setup:
- Mock the ``instance.null_transformer``.
- Set the ``model_missing_values``.
Input:
- a pandas series.
Output:
- a pandas datetime index.
Expected behavior:
- The mock should call its ``reverse_transform`` method.
"""
# Setup
data = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01'])
transformer = UnixTimestampEncoder(model_missing_values=True)
transformer.null_transformer = Mock()
transformer.null_transformer.reverse_transform.return_value = pd.Series([1, 2, 3])

# Run
transformer._reverse_transform_helper(data)

# Assert
transformer.null_transformer.reverse_transform.assert_called_once()
datetimes = transformer.null_transformer.reverse_transform.mock_calls[0][1][0]
np.testing.assert_array_equal(data.to_numpy(), datetimes)

@patch('rdt.transformers.datetime.NullTransformer')
def test__fit(self, null_transformer_mock):
"""Test the ``_fit`` method for numpy arrays.
Expand Down

0 comments on commit 3e91980

Please sign in to comment.