diff --git a/rdt/transformers/datetime.py b/rdt/transformers/datetime.py index 7daeb1d9..28ad451f 100644 --- a/rdt/transformers/datetime.py +++ b/rdt/transformers/datetime.py @@ -39,16 +39,23 @@ class UnixTimestampEncoder(BaseTransformer): value was missing. Then use it to recreate missing values. * ``None``: Do nothing with the missing values on the reverse transform. Simply pass whatever data we get through. + enforce_min_max_values (bool): + Whether or not to clip the data returned by ``reverse_transform`` to the min and + max values seen during ``fit``. Defaults to ``False``. """ INPUT_SDTYPE = 'datetime' null_transformer = None + _min_value = None + _max_value = None def __init__(self, missing_value_replacement='mean', model_missing_values=None, - datetime_format=None, missing_value_generation='random'): + datetime_format=None, missing_value_generation='random', + enforce_min_max_values=False): super().__init__() self._set_missing_value_replacement('mean', missing_value_replacement) self._set_missing_value_generation(missing_value_generation) + self.enforce_min_max_values = enforce_min_max_values if model_missing_values is not None: self._set_model_missing_values(model_missing_values) @@ -124,6 +131,10 @@ def _fit(self, data): self.datetime_format = _guess_datetime_format_for_array(datetime_array) transformed = self._transform_helper(data) + if self.enforce_min_max_values: + self._min_value = transformed.min() + self._max_value = transformed.max() + self.null_transformer = NullTransformer( self.missing_value_replacement, self.missing_value_generation @@ -155,6 +166,9 @@ def _reverse_transform(self, data): Returns: pandas.Series """ + if self.enforce_min_max_values: + data = data.clip(self._min_value, self._max_value) + data = self._reverse_transform_helper(data) datetime_data = pd.to_datetime(data) if self.datetime_format: @@ -208,14 +222,19 @@ class OptimizedTimestampEncoder(UnixTimestampEncoder): value was missing. Then use it to recreate missing values. * ``None``: Do nothing with the missing values on the reverse transform. Simply pass whatever data we get through. + enforce_min_max_values (bool): + Whether or not to clip the data returned by ``reverse_transform`` to the min and + max values seen during ``fit``. Defaults to ``False``. """ divider = None def __init__(self, missing_value_replacement=None, model_missing_values=None, - datetime_format=None, missing_value_generation='random'): + datetime_format=None, missing_value_generation='random', + enforce_min_max_values=False): super().__init__(missing_value_replacement=missing_value_replacement, missing_value_generation=missing_value_generation, + enforce_min_max_values=enforce_min_max_values, model_missing_values=model_missing_values, datetime_format=datetime_format) diff --git a/tests/integration/transformers/test_datetime.py b/tests/integration/transformers/test_datetime.py index ceb4c948..59e189f1 100644 --- a/tests/integration/transformers/test_datetime.py +++ b/tests/integration/transformers/test_datetime.py @@ -156,6 +156,26 @@ def test_unixtimestampencoder_with_nans(self): pd.testing.assert_frame_equal(expected_transformed, transformed) pd.testing.assert_frame_equal(reverted, data) + def test_with_enforce_min_max_values_true(self): + """Test that the transformer properly clipped out of bounds values.""" + # Setup + ute = UnixTimestampEncoder(enforce_min_max_values=True) + data = pd.DataFrame({'column': ['Feb 03, 1981', 'Oct 17, 1996', 'May 23, 1965']}) + ute.fit(data, column='column') + + # Run + transformed = ute.transform(data) + min_val = transformed['column'].min() + max_val = transformed['column'].max() + transformed.loc[transformed['column'] == min_val, 'column'] = min_val - 1e17 + transformed.loc[transformed['column'] == max_val, 'column'] = max_val + 1e17 + reverted = ute.reverse_transform(transformed) + + # Asserts + assert ute._min_value == min_val + assert ute._max_value == max_val + pd.testing.assert_frame_equal(reverted, data) + class TestOptimizedTimestampEncoder: def test_optimizedtimestampencoder(self): diff --git a/tests/unit/transformers/test_datetime.py b/tests/unit/transformers/test_datetime.py index 685504ff..d0070723 100644 --- a/tests/unit/transformers/test_datetime.py +++ b/tests/unit/transformers/test_datetime.py @@ -17,13 +17,15 @@ def test___init__(self): transformer = UnixTimestampEncoder( missing_value_replacement='mode', missing_value_generation='from_column', - datetime_format='%M-%d-%Y' + datetime_format='%M-%d-%Y', + enforce_min_max_values=True, ) # Asserts assert transformer.missing_value_replacement == 'mode' assert transformer.missing_value_generation == 'from_column' assert transformer.datetime_format == '%M-%d-%Y' + assert transformer.enforce_min_max_values is True def test___init__with_model_missing_values(self): """Test the ``__init__`` method and the passed arguments are stored as attributes.""" @@ -267,6 +269,23 @@ def test__fit(self, null_transformer_mock): np.array([1.577837e+18, 1.580515e+18, 1.583021e+18]), rtol=1e-5 ) + def test__fit_enforce_min_max_values(self): + """Test the ``_fit`` method when enforce_min_max_values is True. + + It should compute the min and max values of the integer conversion + of the datetimes. + """ + # Setup + data = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01']) + transformer = UnixTimestampEncoder(enforce_min_max_values=True) + + # Run + transformer._fit(data) + + # Assert + assert transformer._min_value == 1.5778368e+18 + assert transformer._max_value == 1.5830208e+18 + def test__fit_calls_transform_helper(self): """Test the ``_fit`` method. @@ -381,6 +400,30 @@ def test__reverse_transform(self): expected = pd.Series(pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01'])) pd.testing.assert_series_equal(output, expected) + def test__reverse_transform_enforce_min_max_values(self): + """Test the ``_reverse_transform`` with enforce_min_max_values True. + + All the values that are outside the min and max values should be clipped to the min and + max values. + """ + # Setup + ute = UnixTimestampEncoder(enforce_min_max_values=True) + transformed = np.array([ + 1.5678367e+18, 1.5778368e+18, 1.5805152e+18, 1.5830208e+18, 1.5930209e+18 + ]) + ute.null_transformer = NullTransformer('mean') + ute._min_value = 1.5778368e+18 + ute._max_value = 1.5830208e+18 + + # Run + output = ute._reverse_transform(transformed) + + # Assert + expected = pd.Series(pd.to_datetime([ + '2020-01-01', '2020-01-01', '2020-02-01', '2020-03-01', '2020-03-01' + ])) + pd.testing.assert_series_equal(output, expected) + def test__reverse_transform_datetime_format_dtype_is_datetime(self): """Test the ``_reverse_transform`` method returns the correct datetime format.""" # Setup @@ -464,6 +507,24 @@ def test__reverse_transform_only_nans(self): class TestOptimizedTimestampEncoder: + def test___init__(self): + """Test the ``__init__`` method.""" + # Run + transformer = OptimizedTimestampEncoder( + missing_value_replacement='mode', + missing_value_generation='from_column', + datetime_format='%M-%d-%Y', + enforce_min_max_values=True, + ) + + # Asserts + assert transformer.enforce_min_max_values is True + assert transformer.missing_value_replacement == 'mode' + assert transformer.missing_value_generation == 'from_column' + assert transformer.datetime_format == '%M-%d-%Y' + assert transformer.divider is None + assert transformer.null_transformer is None + def test__find_divider(self): """Test the ``_find_divider`` method.