From f242a3b5f3e92227462c894bbc021a589c6e2cad Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 21 Aug 2023 11:16:55 +0200 Subject: [PATCH] use get_supported_sdtypes --- rdt/transformers/__init__.py | 2 +- rdt/transformers/base.py | 6 +++++- tests/contributing.py | 2 +- tests/integration/test_transformers.py | 6 +++--- tests/unit/test_hyper_transformer.py | 2 +- tests/unit/transformers/test_base.py | 11 ++++++++--- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/rdt/transformers/__init__.py b/rdt/transformers/__init__.py index af10c11bb..75a8b5dd9 100644 --- a/rdt/transformers/__init__.py +++ b/rdt/transformers/__init__.py @@ -145,7 +145,7 @@ def get_transformers_by_type(): sdtype_transformers = defaultdict(list) transformer_classes = BaseTransformer.get_subclasses() for transformer in transformer_classes: - input_sdtype = transformer.get_input_sdtype() + input_sdtype = transformer.get_supported_sdtypes()[0] sdtype_transformers[input_sdtype].append(transformer) return sdtype_transformers diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index fed362a5c..f1e28ea4e 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -184,7 +184,11 @@ def get_input_sdtype(cls): string: Accepted input sdtype of the transformer. """ - return cls.INPUT_SDTYPE + warnings.warn( + '``get_input_sdtype`` is deprecated. Please use ``get_supported_sdtypes`` instead.', + FutureWarning + ) + return cls.get_supported_sdtypes() @classmethod def get_supported_sdtypes(cls): diff --git a/tests/contributing.py b/tests/contributing.py index 5bedd1184..4e8586844 100644 --- a/tests/contributing.py +++ b/tests/contributing.py @@ -365,7 +365,7 @@ def validate_transformer_performance(transformer): print(f'Validating Performance for transformer {transformer.get_name()}\n') - sdtype = transformer.get_input_sdtype() + sdtype = transformer.get_supported_sdtypes()[0] transformers = get_transformers_by_type().get(sdtype, []) dataset_generators = get_dataset_generators_by_type().get(sdtype, []) diff --git a/tests/integration/test_transformers.py b/tests/integration/test_transformers.py index 4ed5021df..8a1cf1307 100644 --- a/tests/integration/test_transformers.py +++ b/tests/integration/test_transformers.py @@ -130,7 +130,7 @@ def _validate_reverse_transformed_data(transformer, reversed_data, input_dtype): Expect that the dtype is equal to the dtype of the input data. """ - expected_sdtype = transformer.get_input_sdtype() + expected_sdtype = transformer.get_supported_sdtypes()[0] message = f'Reverse transformed data is not the expected sdtype {expected_sdtype}' assert reversed_data.dtypes[TEST_COL].kind in SDTYPE_TO_DTYPES[expected_sdtype], message @@ -181,7 +181,7 @@ def _validate_hypertransformer_transformed_data(transformed_data): def _validate_hypertransformer_reverse_transformed_data(transformer, reversed_data): """Check that the reverse transformed data has the same dtype as the input.""" - expected_sdtype = transformer().get_input_sdtype() + expected_sdtype = transformer().get_supported_sdtypes()[0] message = f'Reversed transformed data is not the expected sdtype {expected_sdtype}' assert reversed_data.dtype.kind in SDTYPE_TO_DTYPES[expected_sdtype], message @@ -250,7 +250,7 @@ def validate_transformer(transformer, steps=None, subtests=None): subtests: Whether or not to test with subtests. """ - input_sdtype = transformer.get_input_sdtype() + input_sdtype = transformer.get_supported_sdtypes()[0] dataset_generators = _find_dataset_generators(input_sdtype, generators) _validate_helper(_validate_dataset_generators, [dataset_generators], steps) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 3add05603..516390a6a 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -2246,7 +2246,7 @@ def test_update_transformers_no_field_transformers(self): instance = HyperTransformer() instance._fitted = False mock_transformer = Mock() - mock_transformer.get_supported_sdtype.return_value = ['datetime'] + mock_transformer.get_supported_sdtypes.return_value = ['datetime'] column_name_to_transformer = { 'my_column': mock_transformer } diff --git a/tests/unit/transformers/test_base.py b/tests/unit/transformers/test_base.py index d666414fc..33de0f9e2 100644 --- a/tests/unit/transformers/test_base.py +++ b/tests/unit/transformers/test_base.py @@ -155,10 +155,10 @@ class Child(Parent): assert Child in subclasses assert Parent not in subclasses - def test_get_input_sdtype(self): + def test_get_input_sdtype_raises_warning(self): """Test the ``get_input_sdtype`` method. - This method should return the value defined in the ``INPUT_SDTYPE`` of the child classes. + This method should raise a FutureWarning and then call ``get_supported_sdtypes_`` method. Setup: - create a ``Dummy`` class which inherits from the ``BaseTransformer``, @@ -172,7 +172,12 @@ class Dummy(BaseTransformer): INPUT_SDTYPE = 'categorical' # Run - input_sdtype = Dummy.get_input_sdtype() + expected_message = ( + '``get_input_sdtype`` is deprecated. Please use ' + '``get_supported_sdtypes`` instead.' + ) + with pytest.warns(FutureWarning, match=expected_message): + input_sdtype = Dummy.get_input_sdtype()[0] # Assert assert input_sdtype == 'categorical'