diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py
index 12e0271ef..67629bb2c 100644
--- a/rdt/hyper_transformer.py
+++ b/rdt/hyper_transformer.py
@@ -73,6 +73,13 @@ class HyperTransformer:
         'b': 'boolean',
         'M': 'datetime',
     }
+    _DTYPES_TO_DATA_TYPES = {
+        'i': 'integer',
+        'f': 'float',
+        'O': 'categorical',
+        'b': 'boolean',
+        'M': 'datetime',
+    }
 
     def __init__(self, transformers=None, copy=True, dtypes=None, dtype_transformers=None):
         self.transformers = transformers
diff --git a/rdt/transformers/__init__.py b/rdt/transformers/__init__.py
index 33fd2f531..7327e0dbe 100644
--- a/rdt/transformers/__init__.py
+++ b/rdt/transformers/__init__.py
@@ -1,5 +1,8 @@
 """Transformers module."""
 
+from collections import defaultdict
+from functools import lru_cache
+
 from rdt.transformers.base import BaseTransformer
 from rdt.transformers.boolean import BooleanTransformer
 from rdt.transformers.categorical import (
@@ -23,7 +26,15 @@
 
 TRANSFORMERS = {
     transformer.__name__: transformer
-    for transformer in BaseTransformer.__subclasses__()
+    for transformer in BaseTransformer.get_subclasses()
+}
+DEFAULT_TRANSFORMERS = {
+    'numerical': NumericalTransformer,
+    'integer': NumericalTransformer(dtype=int),
+    'float': NumericalTransformer(dtype=float),
+    'categorical': CategoricalTransformer(fuzzy=True),
+    'boolean': BooleanTransformer,
+    'datetime': DatetimeTransformer,
 }
 
 
@@ -77,3 +88,54 @@ def load_transformers(transformers):
         name: load_transformer(transformer)
         for name, transformer in transformers.items()
     }
+
+
+def get_transformers_by_type():
+    """Build a ``dict`` mapping data types to valid existing transformers for that type.
+
+    Returns:
+        dict:
+            Mapping of data types to a list of existing transformers that take that
+            type as an input.
+    """
+    data_type_transformers = defaultdict(list)
+    transformer_classes = BaseTransformer.get_subclasses()
+    for transformer in transformer_classes:
+        try:
+            input_type = transformer.get_input_type()
+            data_type_transformers[input_type].append(transformer)
+        except AttributeError:
+            pass
+
+    return data_type_transformers
+
+
+@lru_cache()
+def get_default_transformers():
+    """Build a ``dict`` mapping data types to a default transformer for that type.
+
+    Returns:
+        dict:
+            Mapping of data types to a transformer.
+    """
+    transformers_by_type = get_transformers_by_type()
+    defaults = {}
+    for (data_type, transformers) in transformers_by_type.items():
+        if data_type in DEFAULT_TRANSFORMERS:
+            defaults[data_type] = DEFAULT_TRANSFORMERS[data_type]
+        else:
+            defaults[data_type] = transformers[0]
+
+    return defaults
+
+
+@lru_cache()
+def get_default_transformer(data_type):
+    """Get default transformer for a data type.
+
+    Returns:
+        Transformer:
+            Default transformer for data type.
+    """
+    default_transformers = get_default_transformers()
+    return default_transformers[data_type]
diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py
index a70ce0cb4..1102a932a 100644
--- a/rdt/transformers/base.py
+++ b/rdt/transformers/base.py
@@ -1,4 +1,5 @@
 """BaseTransformer module."""
+import abc
 
 
 class BaseTransformer:
@@ -20,6 +21,23 @@ class BaseTransformer:
     column_prefix = None
     output_columns = None
 
+    @classmethod
+    def get_subclasses(cls):
+        """Recursively find subclasses of this class.
+
+        Returns:
+            list:
+                List of all subclasses of this class.
+        """
+        subclasses = []
+        for subclass in cls.__subclasses__():
+            if abc.ABC not in subclass.__bases__:
+                subclasses.append(subclass)
+
+            subclasses += subclass.get_subclasses()
+
+        return subclasses
+
     @classmethod
     def get_input_type(cls):
         """Return the input type supported by the transformer.