Skip to content

Commit

Permalink
Update update_transformers validation (#563)
Browse files Browse the repository at this point in the history
* Add sdtype validation

* Add get_name method + rename error message

* Fix lint

* Change error name

* Fix one liner
  • Loading branch information
fealho authored Oct 13, 2022
1 parent 601b678 commit 9b89282
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 39 deletions.
4 changes: 4 additions & 0 deletions rdt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class NotFittedError(Exception):

class Error(Exception):
    """Error to raise when ``HyperTransformer`` produces a controlled error message."""


class TransformerInputError(Exception):
    """Error to raise when ``HyperTransformer`` receives incorrect input.

    For example, it is raised when a transformer assigned to a column does not
    support that column's sdtype.
    """
16 changes: 6 additions & 10 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import yaml

from rdt.errors import Error, NotFittedError
from rdt.errors import Error, NotFittedError, TransformerInputError
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformer_instance, get_transformers_by_type)
Expand Down Expand Up @@ -414,21 +414,17 @@ def update_transformers(self, column_name_to_transformer):
self._validate_update_columns(update_columns)
self._validate_transformers(column_name_to_transformer)

incompatible_sdtypes = []
for column_name, transformer in column_name_to_transformer.items():
if transformer is not None:
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes():
incompatible_sdtypes.append(column_name)
raise TransformerInputError(
f"Column '{column_name}' is a {current_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)

self.field_transformers[column_name] = transformer

if incompatible_sdtypes:
warnings.warn(
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {incompatible_sdtypes}"
)

self._modified_config = True

def remove_transformers(self, column_names):
Expand Down Expand Up @@ -575,7 +571,7 @@ def _get_transformer_tree_yaml(self):
"""
modified_tree = deepcopy(self._transformers_tree)
for field in modified_tree:
class_name = modified_tree[field]['transformer'].__class__.__name__
class_name = modified_tree[field]['transformer'].__class__.get_name()
modified_tree[field]['transformer'] = class_name

return yaml.safe_dump(dict(modified_tree))
Expand Down
4 changes: 2 additions & 2 deletions rdt/performance/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal
pandas.DataFrame:
The performance test results.
"""
transformer_args = TRANSFORMER_ARGS.get(transformer.__name__, {})
transformer_args = TRANSFORMER_ARGS.get(transformer.get_name(), {})
transformer_instance = transformer(**transformer_args)

sizes = _get_dataset_sizes(dataset_generator.SDTYPE)
Expand All @@ -102,7 +102,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal
performance['Number of fit rows'] = fit_size
performance['Number of transform rows'] = transform_size
performance['Dataset'] = dataset_generator.__name__
performance['Transformer'] = f'{transformer.__module__ }.{transformer.__name__}'
performance['Transformer'] = f'{transformer.__module__ }.{transformer.get_name()}'

out.append(performance)

Expand Down
4 changes: 2 additions & 2 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_transformer_name(transformer):
The path of the transformer.
"""
if inspect.isclass(transformer):
return transformer.__module__ + '.' + transformer.__name__
return transformer.__module__ + '.' + transformer.get_name()

raise ValueError(f'The transformer {transformer} must be passed as a class.')

Expand Down Expand Up @@ -106,7 +106,7 @@ def get_class_by_transformer_name():
BaseTransformer:
BaseTransformer subclass class object.
"""
return {class_.__name__: class_ for class_ in BaseTransformer.get_subclasses()}
return {class_.get_name(): class_ for class_ in BaseTransformer.get_subclasses()}


def get_transformer_class(transformer):
Expand Down
12 changes: 11 additions & 1 deletion rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class BaseTransformer:
column_prefix = None
output_columns = None

@classmethod
def get_name(cls):
    """Return transformer name.

    The name is the name of the class this method is called on, so
    subclasses report their own class name.

    Returns:
        str:
            Transformer name.
    """
    return cls.__name__

@classmethod
def get_subclasses(cls):
"""Recursively find subclasses of this class.
Expand Down Expand Up @@ -208,7 +218,7 @@ def __repr__(self):
str:
The name of the transformer followed by any non-default parameters.
"""
class_name = self.__class__.__name__
class_name = self.__class__.get_name()
custom_args = []
args = inspect.getfullargspec(self.__init__)
keys = args.args[1:]
Expand Down
2 changes: 1 addition & 1 deletion rdt/transformers/pii/anonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __repr__(self):
str:
The name of the transformer followed by any non-default parameters.
"""
class_name = self.__class__.__name__
class_name = self.__class__.get_name()
custom_args = []
args = inspect.getfullargspec(self.__init__)
keys = args.args[1:]
Expand Down
6 changes: 3 additions & 3 deletions tests/code_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def validate_transformer_addon(transformer):
module_py = True
elif document.match('config.json'):
config_json_exist = True
_validate_config_json(document, transformer.__name__)
_validate_config_json(document, transformer.get_name())

assert init_file_exist, 'Missing __init__.py file within the addon folder.'
assert config_json_exist, 'Missing the config.json file within the addon folder.'
Expand All @@ -85,7 +85,7 @@ def validate_transformer_addon(transformer):

def validate_transformer_importable_from_parent_module(transformer):
"""Validate whether the transformer can be imported from the parent module."""
name = transformer.__name__
name = transformer.get_name()
module = getattr(transformer, '__module__', '')
module = module.rsplit('.', 1)[0]
imported_transformer = getattr(importlib.import_module(module), name, None)
Expand Down Expand Up @@ -156,7 +156,7 @@ def validate_test_names(transformer):
test_file = get_test_location(transformer)
module = _load_module_from_path(test_file)

test_class = getattr(module, f'Test{transformer.__name__}', None)
test_class = getattr(module, f'Test{transformer.get_name()}', None)
assert test_class is not None, 'The expected test class was not found.'

test_functions = inspect.getmembers(test_class, predicate=inspect.isfunction)
Expand Down
16 changes: 8 additions & 8 deletions tests/contributing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validate_transformer_integration(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Integration Tests for transformer {transformer.__name__}\n')
print(f'Validating Integration Tests for transformer {transformer.get_name()}\n')

steps = []
validation_error = None
Expand Down Expand Up @@ -384,14 +384,14 @@ def validate_transformer_quality(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Quality Tests for transformer {transformer.__name__}\n')
print(f'Validating Quality Tests for transformer {transformer.get_name()}\n')

input_sdtype = transformer.get_input_sdtype()
test_cases = get_test_cases({input_sdtype})
regression_scores = get_regression_scores(test_cases, get_transformers_by_type())
results = get_results_table(regression_scores)

transformer_results = results[results['transformer_name'] == transformer.__name__]
transformer_results = results[results['transformer_name'] == transformer.get_name()]
transformer_results = transformer_results.drop('transformer_name', axis=1)
transformer_results['Acceptable'] = False
passing_relative_scores = transformer_results['score_relative_to_average'] > TEST_THRESHOLD
Expand Down Expand Up @@ -430,7 +430,7 @@ def validate_transformer_performance(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Performance for transformer {transformer.__name__}\n')
print(f'Validating Performance for transformer {transformer.get_name()}\n')

sdtype = transformer.get_input_sdtype()
transformers = get_transformers_by_type().get(sdtype, [])
Expand All @@ -445,8 +445,8 @@ def validate_transformer_performance(transformer):
results = pd.DataFrame({
'Value': performance.to_numpy(),
'Valid': valid,
'transformer': current_transformer.__name__,
'dataset': dataset_generator.__name__,
'transformer': current_transformer.get_name(),
'dataset': dataset_generator.get_name(),
})
results['Evaluation Metric'] = performance.index
total_results = total_results.append(results)
Expand All @@ -456,10 +456,10 @@ def validate_transformer_performance(transformer):
else:
print('ERROR: One or more Performance Tests were NOT successful.')

other_results = total_results[total_results.transformer != transformer.__name__]
other_results = total_results[total_results.transformer != transformer.get_name()]
average = other_results.groupby('Evaluation Metric')['Value'].mean()

total_results = total_results[total_results.transformer == transformer.__name__]
total_results = total_results[total_results.transformer == transformer.get_name()]
final_results = total_results.groupby('Evaluation Metric').agg({
'Value': 'mean',
'Valid': 'any'
Expand Down
11 changes: 9 additions & 2 deletions tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,15 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
TEST_COL: transformer_class
}

hypertransformer.detect_initial_config(input_data)
hypertransformer.update_transformers(field_transformers)
sdtypes = {}
for field, transformer in field_transformers.items():
sdtypes[field] = transformer.get_supported_sdtypes()[0]

config = {
'sdtypes': sdtypes,
'transformers': field_transformers
}
hypertransformer.set_config(config)
hypertransformer.fit(input_data)

transformed = hypertransformer.transform(input_data)
Expand Down
2 changes: 1 addition & 1 deletion tests/quality/test_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_transformer_regression_scores(data, sdtype, dataset_name, transformers,
transformed_features = transformed_features[~nans]
score = get_regression_score(transformed_features, target)
row = pd.Series({
'transformer_name': transformer.__name__,
'transformer_name': transformer.get_name(),
'dataset_name': dataset_name,
'column': column,
'score': score
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from rdt import HyperTransformer
from rdt.errors import Error, NotFittedError
from rdt.errors import Error, NotFittedError, TransformerInputError
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer,
LabelEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder)
Expand Down Expand Up @@ -2484,16 +2484,15 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
'my_column': transformer
}

# Run
instance.update_transformers(column_name_to_transformer)

# Assert
expected_call = (
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {'my_column'}"
# Run and Assert
err_msg = re.escape(
"Column 'my_column' is a categorical column, which is incompatible "
"with the 'BinaryEncoder' transformer."
)
with pytest.raises(TransformerInputError, match=err_msg):
instance.update_transformers(column_name_to_transformer)

assert mock_warnings.called_once_with(expected_call)
assert mock_warnings.called_once_with(err_msg)
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)

def test_update_transformers_transformer_is_none(self):
Expand Down

0 comments on commit 9b89282

Please sign in to comment.