Skip to content

Commit

Permalink
use get_supported_sdtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Aug 21, 2023
1 parent 931e4a7 commit f242a3b
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_transformers_by_type():
sdtype_transformers = defaultdict(list)
transformer_classes = BaseTransformer.get_subclasses()
for transformer in transformer_classes:
input_sdtype = transformer.get_input_sdtype()
input_sdtype = transformer.get_supported_sdtypes()[0]
sdtype_transformers[input_sdtype].append(transformer)

return sdtype_transformers
Expand Down
6 changes: 5 additions & 1 deletion rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ def get_input_sdtype(cls):
string:
Accepted input sdtype of the transformer.
"""
return cls.INPUT_SDTYPE
warnings.warn(
'``get_input_sdtype`` is deprecated. Please use ``get_supported_sdtypes`` instead.',
FutureWarning
)
return cls.get_supported_sdtypes()

@classmethod
def get_supported_sdtypes(cls):
Expand Down
2 changes: 1 addition & 1 deletion tests/contributing.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def validate_transformer_performance(transformer):

print(f'Validating Performance for transformer {transformer.get_name()}\n')

sdtype = transformer.get_input_sdtype()
sdtype = transformer.get_supported_sdtypes()[0]
transformers = get_transformers_by_type().get(sdtype, [])
dataset_generators = get_dataset_generators_by_type().get(sdtype, [])

Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _validate_reverse_transformed_data(transformer, reversed_data, input_dtype):
Expect that the dtype is equal to the dtype of the input data.
"""
expected_sdtype = transformer.get_input_sdtype()
expected_sdtype = transformer.get_supported_sdtypes()[0]
message = f'Reverse transformed data is not the expected sdtype {expected_sdtype}'
assert reversed_data.dtypes[TEST_COL].kind in SDTYPE_TO_DTYPES[expected_sdtype], message

Expand Down Expand Up @@ -181,7 +181,7 @@ def _validate_hypertransformer_transformed_data(transformed_data):

def _validate_hypertransformer_reverse_transformed_data(transformer, reversed_data):
"""Check that the reverse transformed data has the same dtype as the input."""
expected_sdtype = transformer().get_input_sdtype()
expected_sdtype = transformer().get_supported_sdtypes()[0]
message = f'Reversed transformed data is not the expected sdtype {expected_sdtype}'
assert reversed_data.dtype.kind in SDTYPE_TO_DTYPES[expected_sdtype], message

Expand Down Expand Up @@ -250,7 +250,7 @@ def validate_transformer(transformer, steps=None, subtests=None):
subtests:
Whether or not to test with subtests.
"""
input_sdtype = transformer.get_input_sdtype()
input_sdtype = transformer.get_supported_sdtypes()[0]

dataset_generators = _find_dataset_generators(input_sdtype, generators)
_validate_helper(_validate_dataset_generators, [dataset_generators], steps)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,7 +2246,7 @@ def test_update_transformers_no_field_transformers(self):
instance = HyperTransformer()
instance._fitted = False
mock_transformer = Mock()
mock_transformer.get_supported_sdtype.return_value = ['datetime']
mock_transformer.get_supported_sdtypes.return_value = ['datetime']
column_name_to_transformer = {
'my_column': mock_transformer
}
Expand Down
11 changes: 8 additions & 3 deletions tests/unit/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ class Child(Parent):
assert Child in subclasses
assert Parent not in subclasses

def test_get_input_sdtype(self):
def test_get_input_sdtype_raises_warning(self):
"""Test the ``get_input_sdtype`` method.
This method should return the value defined in the ``INPUT_SDTYPE`` of the child classes.
This method should raise a FutureWarning and then call ``get_supported_sdtypes_`` method.
Setup:
- create a ``Dummy`` class which inherits from the ``BaseTransformer``,
Expand All @@ -172,7 +172,12 @@ class Dummy(BaseTransformer):
INPUT_SDTYPE = 'categorical'

# Run
input_sdtype = Dummy.get_input_sdtype()
expected_message = (
'``get_input_sdtype`` is deprecated. Please use '
'``get_supported_sdtypes`` instead.'
)
with pytest.warns(FutureWarning, match=expected_message):
input_sdtype = Dummy.get_input_sdtype()[0]

# Assert
assert input_sdtype == 'categorical'
Expand Down

0 comments on commit f242a3b

Please sign in to comment.