-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multi-column transformers #692
Changes from all commits
febb95b
b0ad789
44f5cc1
18d48ad
43b2fa0
be30d90
054d9eb
afbf35d
acfd2c8
9153334
ba302ba
83af68f
e8e2d9d
8f88b1e
c0d6294
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
from rdt.transformers import ( | ||
BaseTransformer, get_class_by_transformer_name, get_default_transformer, | ||
get_transformers_by_type) | ||
from rdt.transformers.utils import flatten_column_list | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
@@ -25,7 +26,7 @@ def __repr__(self): | |
"""Pretty print the dictionary.""" | ||
config = { | ||
'sdtypes': self['sdtypes'], | ||
'transformers': {k: repr(v) for k, v in self['transformers'].items()} | ||
'transformers': {str(k): repr(v) for k, v in self['transformers'].items()} | ||
} | ||
|
||
printed = json.dumps(config, indent=4) | ||
|
@@ -115,7 +116,7 @@ def __init__(self): | |
self._specified_fields = set() | ||
self._validate_field_transformers() | ||
self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES | ||
self._multi_column_fields = self._create_multi_column_fields() | ||
self._multi_column_fields = {} | ||
self._transformers_sequence = [] | ||
self._output_columns = [] | ||
self._input_columns = [] | ||
|
@@ -205,7 +206,18 @@ def _validate_config(config): | |
|
||
sdtypes = config['sdtypes'] | ||
transformers = config['transformers'] | ||
if set(sdtypes.keys()) != set(transformers.keys()): | ||
|
||
sdtype_keys = sdtypes.keys() | ||
transformer_keys = flatten_column_list(transformers.keys()) | ||
|
||
is_transformer_keys_unique = len(transformer_keys) == len(set(transformer_keys)) | ||
if not is_transformer_keys_unique: | ||
raise InvalidConfigError( | ||
'Error: Invalid config. Please provide unique keys for the sdtypes ' | ||
'and transformers.' | ||
) | ||
|
||
if set(sdtype_keys) != set(transformer_keys): | ||
raise InvalidConfigError( | ||
"The column names in the 'sdtypes' dictionary must match the " | ||
"column names in the 'transformers' dictionary." | ||
|
@@ -216,10 +228,14 @@ def _validate_config(config): | |
|
||
mismatched_columns = [] | ||
for column_name, transformer in transformers.items(): | ||
if transformer is not None: | ||
sdtype = sdtypes.get(column_name) | ||
if transformer is None: | ||
continue | ||
|
||
columns = column_name if isinstance(column_name, tuple) else [column_name] | ||
for column in columns: | ||
sdtype = sdtypes.get(column) | ||
if sdtype not in transformer.get_supported_sdtypes(): | ||
mismatched_columns.append(column_name) | ||
mismatched_columns.append(column) | ||
|
||
if mismatched_columns: | ||
raise InvalidConfigError( | ||
|
@@ -519,6 +535,19 @@ def detect_initial_config(self, data): | |
LOGGER.info('Config:') | ||
LOGGER.info(str(config)) | ||
|
||
def _get_columns_to_sdtypes(self, field): | ||
"""Generate the ``columns_to_sdtypes`` dict for the given field. | ||
|
||
Args: | ||
field (tuple): | ||
Names of the columns for the multi column transformer. | ||
""" | ||
columns_to_sdtypes = {} | ||
for column in field: | ||
columns_to_sdtypes[column] = self.field_sdtypes[column] | ||
|
||
return columns_to_sdtypes | ||
|
||
def _fit_field_transformer(self, data, field, transformer): | ||
"""Fit a transformer to its corresponding field. | ||
|
||
|
@@ -541,7 +570,12 @@ def _fit_field_transformer(self, data, field, transformer): | |
self._output_columns.append(field) | ||
|
||
else: | ||
transformer.fit(data, field) | ||
if isinstance(field, tuple): | ||
columns_to_sdtypes = self._get_columns_to_sdtypes(field) | ||
transformer.fit(data, columns_to_sdtypes) | ||
else: | ||
transformer.fit(data, field) | ||
|
||
self._transformers_sequence.append(transformer) | ||
data = transformer.transform(data) | ||
|
||
|
@@ -603,8 +637,8 @@ def fit(self, data): | |
self._validate_detect_config_called(data) | ||
self._unfit() | ||
self._input_columns = list(data.columns) | ||
for field in self._input_columns: | ||
data = self._fit_field_transformer(data, field, self.field_transformers[field]) | ||
for field_column, field_transformer in self.field_transformers.items(): | ||
data = self._fit_field_transformer(data, field_column, field_transformer) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this is going to work alone. The … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the multi-column case, do we pass the columns as a tuple or as a dict? I assumed it was a tuple, so this looked fine to me based on the docstring of
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the base class, you made fit take in data and a dictionary. It needs to be a dict; otherwise the values will always have to be passed in a specific order. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @R-Palazzo I think this is still the case. The issue states that the |
||
|
||
self._validate_all_fields_fitted() | ||
self._fitted = True | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Break line above would be appreciated.