Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enforce_min_max_values to datetime transformers #741

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions rdt/transformers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class UnixTimestampEncoder(BaseTransformer):
datetime_format (str):
The strftime to use for parsing time. For more information, see
https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior.
enforce_min_max_values (bool):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should add this as the last parameter. The main reason is that if someone runs scripts without using the key-word argument names, it might crash after this release since the 4th argument used to be a string and now it is a boolean with a different purpose. This isn't a huge deal though

Whether or not to clip the data returned by ``reverse_transform`` to the min and
max values seen during ``fit``. Defaults to ``False``.
missing_value_generation (str or None):
The way missing values are being handled. There are three strategies:

Expand All @@ -43,12 +46,16 @@ class UnixTimestampEncoder(BaseTransformer):

INPUT_SDTYPE = 'datetime'
null_transformer = None
_min_value = None
_max_value = None

def __init__(self, missing_value_replacement='mean', model_missing_values=None,
datetime_format=None, missing_value_generation='random'):
datetime_format=None, enforce_min_max_values=False,
missing_value_generation='random'):
super().__init__()
self._set_missing_value_replacement('mean', missing_value_replacement)
self._set_missing_value_generation(missing_value_generation)
self.enforce_min_max_values = enforce_min_max_values
if model_missing_values is not None:
self._set_model_missing_values(model_missing_values)

Expand Down Expand Up @@ -124,6 +131,10 @@ def _fit(self, data):
self.datetime_format = _guess_datetime_format_for_array(datetime_array)

transformed = self._transform_helper(data)
if self.enforce_min_max_values:
self._min_value = transformed.min()
self._max_value = transformed.max()

self.null_transformer = NullTransformer(
self.missing_value_replacement,
self.missing_value_generation
Expand Down Expand Up @@ -155,6 +166,9 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
if self.enforce_min_max_values:
data = data.clip(self._min_value, self._max_value)

data = self._reverse_transform_helper(data)
datetime_data = pd.to_datetime(data)
if self.datetime_format:
Expand Down Expand Up @@ -199,6 +213,9 @@ class OptimizedTimestampEncoder(UnixTimestampEncoder):
datetime_format (str):
The strftime to use for parsing time. For more information, see
https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior.
enforce_min_max_values (bool):
Whether or not to clip the data returned by ``reverse_transform`` to the min and
max values seen during ``fit``. Defaults to ``False``.
missing_value_generation (str or None):
The way missing values are being handled. There are three strategies:

Expand All @@ -213,9 +230,11 @@ class OptimizedTimestampEncoder(UnixTimestampEncoder):
divider = None

def __init__(self, missing_value_replacement=None, model_missing_values=None,
datetime_format=None, missing_value_generation='random'):
datetime_format=None, enforce_min_max_values=False,
missing_value_generation='random'):
super().__init__(missing_value_replacement=missing_value_replacement,
missing_value_generation=missing_value_generation,
enforce_min_max_values=enforce_min_max_values,
model_missing_values=model_missing_values,
datetime_format=datetime_format)

Expand Down
20 changes: 20 additions & 0 deletions tests/integration/transformers/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ def test_unixtimestampencoder_with_nans(self):
pd.testing.assert_frame_equal(expected_transformed, transformed)
pd.testing.assert_frame_equal(reverted, data)

def test_with_enforce_min_max_values_true(self):
"""Test that the transformer properly clipped out of bounds values."""
# Setup
ute = UnixTimestampEncoder(enforce_min_max_values=True)
data = pd.DataFrame({'column': ['Feb 03, 1981', 'Oct 17, 1996', 'May 23, 1965']})
ute.fit(data, column='column')

# Run
transformed = ute.transform(data)
min_val = transformed['column'].min()
max_val = transformed['column'].max()
transformed.loc[transformed['column'] == min_val, 'column'] = min_val - 1e17
transformed.loc[transformed['column'] == max_val, 'column'] = max_val + 1e17
reverted = ute.reverse_transform(transformed)

# Asserts
assert ute._min_value == min_val
assert ute._max_value == max_val
pd.testing.assert_frame_equal(reverted, data)


class TestOptimizedTimestampEncoder:
def test_optimizedtimestampencoder(self):
Expand Down
63 changes: 62 additions & 1 deletion tests/unit/transformers/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ def test___init__(self):
transformer = UnixTimestampEncoder(
missing_value_replacement='mode',
missing_value_generation='from_column',
datetime_format='%M-%d-%Y'
datetime_format='%M-%d-%Y',
enforce_min_max_values=True,
)

# Asserts
assert transformer.missing_value_replacement == 'mode'
assert transformer.missing_value_generation == 'from_column'
assert transformer.datetime_format == '%M-%d-%Y'
assert transformer.enforce_min_max_values is True

def test___init__with_model_missing_values(self):
"""Test the ``__init__`` method and the passed arguments are stored as attributes."""
Expand Down Expand Up @@ -267,6 +269,23 @@ def test__fit(self, null_transformer_mock):
np.array([1.577837e+18, 1.580515e+18, 1.583021e+18]), rtol=1e-5
)

def test__fit_enforce_min_max_values(self):
"""Test the ``_fit`` method when enforce_min_max_values is True.

It should compute the min and max values of the integer conversion
of the datetimes.
"""
# Setup
data = pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01'])
transformer = UnixTimestampEncoder(enforce_min_max_values=True)

# Run
transformer._fit(data)

# Assert
assert transformer._min_value == 1.5778368e+18
assert transformer._max_value == 1.5830208e+18

def test__fit_calls_transform_helper(self):
"""Test the ``_fit`` method.

Expand Down Expand Up @@ -381,6 +400,30 @@ def test__reverse_transform(self):
expected = pd.Series(pd.to_datetime(['2020-01-01', '2020-02-01', '2020-03-01']))
pd.testing.assert_series_equal(output, expected)

def test__reverse_transform_enforce_min_max_values(self):
"""Test the ``_reverse_transform`` with enforce_min_max_values True.

All the values that are outside the min and max values should be clipped to the min and
max values.
"""
# Setup
ute = UnixTimestampEncoder(enforce_min_max_values=True)
transformed = np.array([
1.5678367e+18, 1.5778368e+18, 1.5805152e+18, 1.5830208e+18, 1.5930209e+18
])
ute.null_transformer = NullTransformer('mean')
ute._min_value = 1.5778368e+18
ute._max_value = 1.5830208e+18

# Run
output = ute._reverse_transform(transformed)

# Assert
expected = pd.Series(pd.to_datetime([
'2020-01-01', '2020-01-01', '2020-02-01', '2020-03-01', '2020-03-01'
]))
pd.testing.assert_series_equal(output, expected)

def test__reverse_transform_datetime_format_dtype_is_datetime(self):
"""Test the ``_reverse_transform`` method returns the correct datetime format."""
# Setup
Expand Down Expand Up @@ -464,6 +507,24 @@ def test__reverse_transform_only_nans(self):

class TestOptimizedTimestampEncoder:

def test___init__(self):
"""Test the ``__init__`` method."""
# Run
transformer = OptimizedTimestampEncoder(
missing_value_replacement='mode',
missing_value_generation='from_column',
datetime_format='%M-%d-%Y',
enforce_min_max_values=True,
)

# Asserts
assert transformer.enforce_min_max_values is True
assert transformer.missing_value_replacement == 'mode'
assert transformer.missing_value_generation == 'from_column'
assert transformer.datetime_format == '%M-%d-%Y'
assert transformer.divider is None
assert transformer.null_transformer is None

def test__find_divider(self):
"""Test the ``_find_divider`` method.

Expand Down
Loading