diff --git a/rdt/errors.py b/rdt/errors.py index d531984d6..9e36119a7 100644 --- a/rdt/errors.py +++ b/rdt/errors.py @@ -7,3 +7,7 @@ class NotFittedError(Exception): class Error(Exception): """Error to raise when ``HyperTransformer`` produces a controlled error message.""" + + +class TransformerInputError(Exception): + """Error to raise when ``HyperTransformer`` receives incorrect input.""" diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 48a8467bd..8020a2def 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -9,7 +9,7 @@ import pandas as pd import yaml -from rdt.errors import Error, NotFittedError +from rdt.errors import Error, NotFittedError, TransformerInputError from rdt.transformers import ( BaseTransformer, get_class_by_transformer_name, get_default_transformer, get_transformer_instance, get_transformers_by_type) @@ -414,21 +414,17 @@ def update_transformers(self, column_name_to_transformer): self._validate_update_columns(update_columns) self._validate_transformers(column_name_to_transformer) - incompatible_sdtypes = [] for column_name, transformer in column_name_to_transformer.items(): if transformer is not None: current_sdtype = self.field_sdtypes.get(column_name) if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes(): - incompatible_sdtypes.append(column_name) + raise TransformerInputError( + f"Column '{column_name}' is a {current_sdtype} column, which is " + f"incompatible with the '{transformer.get_name()}' transformer." + ) self.field_transformers[column_name] = transformer - if incompatible_sdtypes: - warnings.warn( - "Some transformers you've assigned are not compatible with the sdtypes. " - f"Use 'update_sdtypes' to update: {incompatible_sdtypes}" - ) - self._modified_config = True def remove_transformers(self, column_names): @@ -575,7 +571,7 @@ def _get_transformer_tree_yaml(self): """ modified_tree = deepcopy(self._transformers_tree) for field in modified_tree: - class_name = modified_tree[field]['transformer'].__class__.__name__ + class_name = modified_tree[field]['transformer'].__class__.get_name() modified_tree[field]['transformer'] = class_name return yaml.safe_dump(dict(modified_tree)) diff --git a/rdt/performance/performance.py b/rdt/performance/performance.py index c24d8b3b9..1b10565ac 100644 --- a/rdt/performance/performance.py +++ b/rdt/performance/performance.py @@ -82,7 +82,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal pandas.DataFrame: The performance test results. """ - transformer_args = TRANSFORMER_ARGS.get(transformer.__name__, {}) + transformer_args = TRANSFORMER_ARGS.get(transformer.get_name(), {}) transformer_instance = transformer(**transformer_args) sizes = _get_dataset_sizes(dataset_generator.SDTYPE) @@ -102,7 +102,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal performance['Number of fit rows'] = fit_size performance['Number of transform rows'] = transform_size performance['Dataset'] = dataset_generator.__name__ - performance['Transformer'] = f'{transformer.__module__ }.{transformer.__name__}' + performance['Transformer'] = f'{transformer.__module__ }.{transformer.get_name()}' out.append(performance) diff --git a/rdt/transformers/__init__.py b/rdt/transformers/__init__.py index 543c819d8..82faa85cd 100644 --- a/rdt/transformers/__init__.py +++ b/rdt/transformers/__init__.py @@ -75,7 +75,7 @@ def get_transformer_name(transformer): The path of the transformer. """ if inspect.isclass(transformer): - return transformer.__module__ + '.' + transformer.__name__ + return transformer.__module__ + '.' + transformer.get_name() raise ValueError(f'The transformer {transformer} must be passed as a class.') @@ -106,7 +106,7 @@ def get_class_by_transformer_name(): BaseTransformer: BaseTransformer subclass class object. """ - return {class_.__name__: class_ for class_ in BaseTransformer.get_subclasses()} + return {class_.get_name(): class_ for class_ in BaseTransformer.get_subclasses()} def get_transformer_class(transformer): diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index 79815c3f3..7414280f6 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -26,6 +26,16 @@ class BaseTransformer: column_prefix = None output_columns = None + @classmethod + def get_name(cls): + """Return transformer name. + + Returns: + str: + Transformer name. + """ + return cls.__name__ + @classmethod def get_subclasses(cls): """Recursively find subclasses of this Baseline. @@ -208,7 +218,7 @@ def __repr__(self): str: The name of the transformer followed by any non-default parameters. """ - class_name = self.__class__.__name__ + class_name = self.__class__.get_name() custom_args = [] args = inspect.getfullargspec(self.__init__) keys = args.args[1:] diff --git a/rdt/transformers/pii/anonymizer.py b/rdt/transformers/pii/anonymizer.py index 4a4b14796..284b61bdf 100644 --- a/rdt/transformers/pii/anonymizer.py +++ b/rdt/transformers/pii/anonymizer.py @@ -161,7 +161,7 @@ def __repr__(self): str: The name of the transformer followed by any non-default parameters. """ - class_name = self.__class__.__name__ + class_name = self.__class__.get_name() custom_args = [] args = inspect.getfullargspec(self.__init__) keys = args.args[1:] diff --git a/tests/code_style.py b/tests/code_style.py index d2818f9de..efe213fbe 100644 --- a/tests/code_style.py +++ b/tests/code_style.py @@ -76,7 +76,7 @@ def validate_transformer_addon(transformer): module_py = True elif document.match('config.json'): config_json_exist = True - _validate_config_json(document, transformer.__name__) + _validate_config_json(document, transformer.get_name()) assert init_file_exist, 'Missing __init__.py file within the addon folder.' assert config_json_exist, 'Missing the config.json file within the addon folder.' @@ -85,7 +85,7 @@ def validate_transformer_addon(transformer): def validate_transformer_importable_from_parent_module(transformer): """Validate wheter the transformer can be imported from the parent module.""" - name = transformer.__name__ + name = transformer.get_name() module = getattr(transformer, '__module__', '') module = module.rsplit('.', 1)[0] imported_transformer = getattr(importlib.import_module(module), name, None) @@ -156,7 +156,7 @@ def validate_test_names(transformer): test_file = get_test_location(transformer) module = _load_module_from_path(test_file) - test_class = getattr(module, f'Test{transformer.__name__}', None) + test_class = getattr(module, f'Test{transformer.get_name()}', None) assert test_class is not None, 'The expected test class was not found.' test_functions = inspect.getmembers(test_class, predicate=inspect.isfunction) diff --git a/tests/contributing.py b/tests/contributing.py index 1596e99a3..2d4de4a3a 100644 --- a/tests/contributing.py +++ b/tests/contributing.py @@ -89,7 +89,7 @@ def validate_transformer_integration(transformer): if isinstance(transformer, str): transformer = get_transformer_class(transformer) - print(f'Validating Integration Tests for transformer {transformer.__name__}\n') + print(f'Validating Integration Tests for transformer {transformer.get_name()}\n') steps = [] validation_error = None @@ -384,14 +384,14 @@ def validate_transformer_quality(transformer): if isinstance(transformer, str): transformer = get_transformer_class(transformer) - print(f'Validating Quality Tests for transformer {transformer.__name__}\n') + print(f'Validating Quality Tests for transformer {transformer.get_name()}\n') input_sdtype = transformer.get_input_sdtype() test_cases = get_test_cases({input_sdtype}) regression_scores = get_regression_scores(test_cases, get_transformers_by_type()) results = get_results_table(regression_scores) - transformer_results = results[results['transformer_name'] == transformer.__name__] + transformer_results = results[results['transformer_name'] == transformer.get_name()] transformer_results = transformer_results.drop('transformer_name', axis=1) transformer_results['Acceptable'] = False passing_relative_scores = transformer_results['score_relative_to_average'] > TEST_THRESHOLD @@ -430,7 +430,7 @@ def validate_transformer_performance(transformer): if isinstance(transformer, str): transformer = get_transformer_class(transformer) - print(f'Validating Performance for transformer {transformer.__name__}\n') + print(f'Validating Performance for transformer {transformer.get_name()}\n') sdtype = transformer.get_input_sdtype() transformers = get_transformers_by_type().get(sdtype, []) @@ -445,8 +445,8 @@ def validate_transformer_performance(transformer): results = pd.DataFrame({ 'Value': performance.to_numpy(), 'Valid': valid, - 'transformer': current_transformer.__name__, - 'dataset': dataset_generator.__name__, + 'transformer': current_transformer.get_name(), + 'dataset': dataset_generator.get_name(), }) results['Evaluation Metric'] = performance.index total_results = total_results.append(results) @@ -456,10 +456,10 @@ def validate_transformer_performance(transformer): else: print('ERROR: One or more Performance Tests were NOT successful.') - other_results = total_results[total_results.transformer != transformer.__name__] + other_results = total_results[total_results.transformer != transformer.get_name()] average = other_results.groupby('Evaluation Metric')['Value'].mean() - total_results = total_results[total_results.transformer == transformer.__name__] + total_results = total_results[total_results.transformer == transformer.get_name()] final_results = total_results.groupby('Evaluation Metric').agg({ 'Value': 'mean', 'Valid': 'any' diff --git a/tests/integration/test_transformers.py b/tests/integration/test_transformers.py index 5ae16721a..d4dbaf8d5 100644 --- a/tests/integration/test_transformers.py +++ b/tests/integration/test_transformers.py @@ -254,8 +254,15 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps TEST_COL: transformer_class } - hypertransformer.detect_initial_config(input_data) - hypertransformer.update_transformers(field_transformers) + sdtypes = {} + for field, transformer in field_transformers.items(): + sdtypes[field] = transformer.get_supported_sdtypes()[0] + + config = { + 'sdtypes': sdtypes, + 'transformers': field_transformers + } + hypertransformer.set_config(config) hypertransformer.fit(input_data) transformed = hypertransformer.transform(input_data) diff --git a/tests/quality/test_quality.py b/tests/quality/test_quality.py index 710dbfedc..f7a004be7 100644 --- a/tests/quality/test_quality.py +++ b/tests/quality/test_quality.py @@ -101,7 +101,7 @@ def get_transformer_regression_scores(data, sdtype, dataset_name, transformers, transformed_features = transformed_features[~nans] score = get_regression_score(transformed_features, target) row = pd.Series({ - 'transformer_name': transformer.__name__, + 'transformer_name': transformer.get_name(), 'dataset_name': dataset_name, 'column': column, 'score': score diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index edd97ebe9..ee33bebd7 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -10,7 +10,7 @@ import pytest from rdt import HyperTransformer -from rdt.errors import Error, NotFittedError +from rdt.errors import Error, NotFittedError, TransformerInputError from rdt.transformers import ( AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer, LabelEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder) @@ -2484,16 +2484,15 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): 'my_column': transformer } - # Run - instance.update_transformers(column_name_to_transformer) - - # Assert - expected_call = ( - "Some transformers you've assigned are not compatible with the sdtypes. " - f"Use 'update_sdtypes' to update: {'my_column'}" + # Run and Assert + err_msg = re.escape( + "Column 'my_column' is a categorical column, which is incompatible " + "with the 'BinaryEncoder' transformer." ) + with pytest.raises(TransformerInputError, match=err_msg): + instance.update_transformers(column_name_to_transformer) - assert mock_warnings.called_once_with(expected_call) + assert mock_warnings.called_once_with(err_msg) instance._validate_transformers.assert_called_once_with(column_name_to_transformer) def test_update_transformers_transformer_is_none(self):