Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create ways for HyperTransformer to know which transformers to apply to each data type #232 #239

Merged
merged 9 commits into from
Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class HyperTransformer:
'b': 'boolean',
'M': 'datetime',
}
_DTYPES_TO_DATA_TYPES = {
'i': 'integer',
'f': 'float',
'O': 'categorical',
'b': 'boolean',
'M': 'datetime',
}

def __init__(self, transformers=None, copy=True, dtypes=None, dtype_transformers=None):
self.transformers = transformers
Expand Down
64 changes: 63 additions & 1 deletion rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Transformers module."""

from collections import defaultdict
from functools import lru_cache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether functools is part of the standard library. If it is not, we should add it to setup.py even if it is already installed because one of our dependencies uses it; that way we do not crash if that dependency drops it in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is part of the standard library


from rdt.transformers.base import BaseTransformer
from rdt.transformers.boolean import BooleanTransformer
from rdt.transformers.categorical import (
Expand All @@ -23,7 +26,15 @@

TRANSFORMERS = {
transformer.__name__: transformer
for transformer in BaseTransformer.__subclasses__()
for transformer in BaseTransformer.get_subclasses()
}
# Default transformer to use for each data type, keyed by data type name.
# Values are a mix of transformer classes ('numerical', 'boolean', 'datetime')
# and pre-configured instances ('integer', 'float', 'categorical').
# NOTE(review): consumers receive either a class or an instance — confirm both are handled.
DEFAULT_TRANSFORMERS = {
'numerical': NumericalTransformer,
'integer': NumericalTransformer(dtype=int),
'float': NumericalTransformer(dtype=float),
'categorical': CategoricalTransformer(fuzzy=True),
'boolean': BooleanTransformer,
'datetime': DatetimeTransformer,
}


Expand Down Expand Up @@ -77,3 +88,54 @@ def load_transformers(transformers):
name: load_transformer(transformer)
for name, transformer in transformers.items()
}


def get_transformers_by_type():
    """Group the existing transformer classes by the data type they accept.

    Transformer classes are discovered through ``BaseTransformer.get_subclasses()``
    and each one is filed under the value returned by its ``get_input_type()``.

    Returns:
        dict:
            ``defaultdict(list)`` mapping each input data type to the list of
            transformer classes that reported that type.
    """
    by_type = defaultdict(list)
    for candidate in BaseTransformer.get_subclasses():
        try:
            by_type[candidate.get_input_type()].append(candidate)
        except AttributeError:
            # Classes for which this lookup raises AttributeError are skipped.
            continue

    return by_type


@lru_cache()
def get_default_transformers():
    """Choose one default transformer for every data type that has transformers.

    For each data type returned by ``get_transformers_by_type``, the entry from
    ``DEFAULT_TRANSFORMERS`` is used when one exists; otherwise the first
    transformer listed for that type is taken. The result is memoized by
    ``lru_cache``, so repeated calls return the same ``dict`` object.

    Returns:
        dict:
            Mapping of data types to a transformer.
    """
    return {
        data_type: (
            DEFAULT_TRANSFORMERS[data_type]
            if data_type in DEFAULT_TRANSFORMERS
            else candidates[0]
        )
        for data_type, candidates in get_transformers_by_type().items()
    }


@lru_cache()
def get_default_transformer(data_type):
    """Look up the default transformer for a single data type.

    Args:
        data_type:
            Key into the mapping built by ``get_default_transformers``. It must be
            hashable, since calls are memoized with ``lru_cache``.

    Returns:
        Transformer:
            Default transformer for data type.

    Raises:
        KeyError:
            If the mapping has no entry for ``data_type``.
    """
    return get_default_transformers()[data_type]
18 changes: 18 additions & 0 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""BaseTransformer module."""
import abc


class BaseTransformer:
Expand All @@ -20,6 +21,23 @@ class BaseTransformer:
_column_prefix = None
_output_columns = None

@classmethod
def get_subclasses(cls):
"""Recursively find subclasses of this class.

Returns:
list:
List of all subclasses of this class.
"""
subclasses = []
for subclass in cls.__subclasses__():
if abc.ABC not in subclass.__bases__:
subclasses.append(subclass)

subclasses += subclass.get_subclasses()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a blank line above this one?


return subclasses

@classmethod
def get_input_type(cls):
"""Return the input type supported by the transformer.
Expand Down