Skip to content

Commit

Permalink
HyperTransformer can’t detect UInt or uint (#867)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Aug 14, 2024
1 parent e15fb59 commit 1ca5c0e
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 1 deletion.
1 change: 1 addition & 0 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class HyperTransformer:
# pylint: disable=too-many-instance-attributes

_DTYPES_TO_SDTYPES = {
'u': 'numerical',
'i': 'numerical',
'f': 'numerical',
'O': 'categorical',
Expand Down
6 changes: 5 additions & 1 deletion rdt/transformers/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def _reverse_transform(self, data):
data = data.clip(min_bound, max_bound)

is_integer = pd.api.types.is_integer_dtype(self._dtype)
np_integer_with_nans = isinstance(data, np.ndarray) and is_integer and pd.isna(data).any()
np_integer_with_nans = (
not pd.api.types.is_extension_array_dtype(self._dtype)
and is_integer
and pd.isna(data).any()
)
if self.learn_rounding_scheme and self._rounding_digits is not None:
data = data.round(self._rounding_digits)
elif is_integer:
Expand Down
148 changes: 148 additions & 0 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,3 +2071,151 @@ def _validate_sdtypes(cls, columns_to_sdtype):
expected_multi_columns = {}
assert ht._multi_column_fields == expected_multi_columns
assert repr(new_config) == repr(expected_config)

def test_detect_unsigned_integer_dtypes(self):
"""Test that the HyperTransformer can detect unsigned integer dtypes."""
# Setup
data = pd.DataFrame({
'Int8': pd.Series([1, 2, -3, pd.NA], dtype='Int8'),
'Int16': pd.Series([1, 2, -3, pd.NA], dtype='Int16'),
'Int32': pd.Series([1, 2, -3, pd.NA], dtype='Int32'),
'Int64': pd.Series([1, 2, -3, pd.NA], dtype='Int64'),
'UInt8': pd.Series([1, 2, 3, pd.NA], dtype='UInt8'),
'UInt16': pd.Series([1, 2, 3, pd.NA], dtype='UInt16'),
'UInt32': pd.Series([1, 2, 3, pd.NA], dtype='UInt32'),
'UInt64': pd.Series([1, 2, 3, pd.NA], dtype='UInt64'),
'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'),
'Float64': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float64'),
'uint8': np.array([1, 2, 3, 4], dtype='uint8'),
'uint16': np.array([1, 2, 3, 4], dtype='uint16'),
'uint32': np.array([1, 2, 3, 4], dtype='uint32'),
'uint64': np.array([1, 2, 3, 4], dtype='uint64'),
})
ht = HyperTransformer()

# Run
ht.detect_initial_config(data)

# Assert
config = ht.get_config()
for column_name, sdtype in config['sdtypes'].items():
assert sdtype == 'numerical'
assert config['transformers'][column_name].__class__.__name__ == 'FloatFormatter'

def test_numerical_dtype_handling(self):
"""Test that the HyperTransformer correctly handle all numerical dtypes."""
# Setup
original_data = pd.DataFrame({
'Int8': pd.Series([1, 2, 3, pd.NA], dtype='Int8'),
'Int16': pd.Series([1, 2, 3, pd.NA], dtype='Int16'),
'Int32': pd.Series([1, 2, 3, pd.NA], dtype='Int32'),
'Int64': pd.Series([1, 2, 3, pd.NA], dtype='Int64'),
'UInt8': pd.Series([1, 2, 3, pd.NA], dtype='UInt8'),
'UInt16': pd.Series([1, 2, 3, pd.NA], dtype='UInt16'),
'UInt32': pd.Series([1, 2, 3, pd.NA], dtype='UInt32'),
'UInt64': pd.Series([1, 2, 3, pd.NA], dtype='UInt64'),
'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'),
'Float64': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float64'),
'uint8': np.array([1, 2, 3, 4], dtype='uint8'),
'uint16': np.array([1, 2, 3, 4], dtype='uint16'),
'uint32': np.array([1, 2, 3, 4], dtype='uint32'),
'uint64': np.array([1, 2, 3, 4], dtype='uint64'),
'float': np.array([1.1, 2.2, 3.3, 4.4], dtype='float'),
'int8': np.array([1, 2, 3, 4], dtype='int8'),
'int16': np.array([1, 2, 3, 4], dtype='int16'),
'int32': np.array([1, 2, 3, 4], dtype='int32'),
'int64': np.array([1, 2, 3, 4], dtype='int64'),
})

ht = HyperTransformer()

# Run
ht.detect_initial_config(original_data)
ht.fit(original_data)
transformed_data = ht.transform(original_data)
reverse_transformed_data = ht.reverse_transform(transformed_data)

# Assert
assert transformed_data.dtypes.unique() == 'float'
for column in original_data.columns:
assert reverse_transformed_data[column].dtype == column

def test_numerical_handling_with_nans(self):
"""Test all numerical dtypes handling when there is NaN in the transformed data."""
# Setup
original_data = pd.DataFrame({
'Int8': pd.Series([1, 2, 3, pd.NA], dtype='Int8'),
'Int16': pd.Series([1, 2, 3, pd.NA], dtype='Int16'),
'Int32': pd.Series([1, 2, 3, pd.NA], dtype='Int32'),
'Int64': pd.Series([1, 2, 3, pd.NA], dtype='Int64'),
'UInt8': pd.Series([1, 2, 3, pd.NA], dtype='UInt8'),
'UInt16': pd.Series([1, 2, 3, pd.NA], dtype='UInt16'),
'UInt32': pd.Series([1, 2, 3, pd.NA], dtype='UInt32'),
'UInt64': pd.Series([1, 2, 3, pd.NA], dtype='UInt64'),
'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'),
'Float64': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float64'),
'uint8': np.array([1, 2, 3, 4], dtype='uint8'),
'uint16': np.array([1, 2, 3, 4], dtype='uint16'),
'uint32': np.array([1, 2, 3, 4], dtype='uint32'),
'uint64': np.array([1, 2, 3, 4], dtype='uint64'),
'float': np.array([1.1, 2.2, 3.3, 4.4], dtype='float'),
'int8': np.array([1, 2, 3, 4], dtype='int8'),
'int16': np.array([1, 2, 3, 4], dtype='int16'),
'int32': np.array([1, 2, 3, 4], dtype='int32'),
'int64': np.array([1, 2, 3, 4], dtype='int64'),
})

data_with_nans = pd.DataFrame({
'Int8': [1.1, 2.2, 3.3, np.nan],
'Int16': [1.1, 2.2, 3.3, np.nan],
'Int32': [1.1, 2.2, 3.3, np.nan],
'Int64': [1.1, 2.2, 3.3, np.nan],
'UInt8': [1.1, 2.2, 3.3, np.nan],
'UInt16': [1.1, 2.2, 3.3, np.nan],
'UInt32': [1.1, 2.2, 3.3, np.nan],
'UInt64': [1.1, 2.2, 3.3, np.nan],
'Float32': [1.1, 2.2, 3.3, np.nan],
'Float64': [1.1, 2.2, 3.3, np.nan],
'uint8': [1.1, 2.2, 3.3, np.nan],
'uint16': [1.1, 2.2, 3.3, np.nan],
'uint32': [1.1, 2.2, 3.3, np.nan],
'uint64': [1.1, 2.2, 3.3, np.nan],
'float': [1.1, 2.2, 3.3, np.nan],
'int8': [1.1, 2.2, 3.3, np.nan],
'int16': [1.1, 2.2, 3.3, np.nan],
'int32': [1.1, 2.2, 3.3, np.nan],
'int64': [1.1, 2.2, 3.3, np.nan],
})

ht = HyperTransformer()
ht.detect_initial_config(original_data)
ht.fit(original_data)

# Run
reverse_transformed_data = ht.reverse_transform(data_with_nans)

# Assert
expected_output_dtypes = {
'Int8': 'Int8',
'Int16': 'Int16',
'Int32': 'Int32',
'Int64': 'Int64',
'UInt8': 'UInt8',
'UInt16': 'UInt16',
'UInt32': 'UInt32',
'UInt64': 'UInt64',
'Float32': 'Float32',
'Float64': 'Float64',
'uint8': 'float',
'uint16': 'float',
'uint32': 'float',
'uint64': 'float',
'float': 'float',
'int8': 'float',
'int16': 'float',
'int32': 'float',
'int64': 'float',
}
assert data_with_nans.dtypes.unique() == 'float'
for column_name, expected_dtype in expected_output_dtypes.items():
assert reverse_transformed_data[column_name].dtype == expected_dtype
28 changes: 28 additions & 0 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,34 @@ def test___init__(self, validation_mock):
assert ht._modified_config is False
validation_mock.assert_called_once()

def test__set_field_sdtype_numerical(self):
"""Test the ``_set_field_sdtype`` method for numerical data."""
# Setup
data = pd.DataFrame({
'Int8': pd.Series([1, 2, -3, pd.NA], dtype='Int8'),
'Int16': pd.Series([1, 2, -3, pd.NA], dtype='Int16'),
'Int32': pd.Series([1, 2, -3, pd.NA], dtype='Int32'),
'Int64': pd.Series([1, 2, -3, pd.NA], dtype='Int64'),
'UInt8': pd.Series([1, 2, 3, pd.NA], dtype='UInt8'),
'UInt16': pd.Series([1, 2, 3, pd.NA], dtype='UInt16'),
'UInt32': pd.Series([1, 2, 3, pd.NA], dtype='UInt32'),
'UInt64': pd.Series([1, 2, 3, pd.NA], dtype='UInt64'),
'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'),
'Float64': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float64'),
'uint8': np.array([1, 2, 3, 4], dtype='uint8'),
'uint16': np.array([1, 2, 3, 4], dtype='uint16'),
'uint32': np.array([1, 2, 3, 4], dtype='uint32'),
'uint64': np.array([1, 2, 3, 4], dtype='uint64'),
})
ht = HyperTransformer()

# Run
for column in data.columns:
ht._set_field_sdtype(data, column)

# Assert
assert ht.field_sdtypes == {column: 'numerical' for column in data.columns}

def test__unfit(self):
"""Test the ``_unfit`` method.
Expand Down

0 comments on commit 1ca5c0e

Please sign in to comment.