Skip to content

Commit

Permalink
Create ways for HyperTransformer to know which transformers to apply …
Browse files Browse the repository at this point in the history
…to each data type #232 (#239)

* adding get_transformers_by_type function

* adding other attributes and fixing typo

* pr comments

* adding default transformers method

* pr comments

* adding caching and some cleanup
  • Loading branch information
amontanez24 authored and csala committed Oct 13, 2021
1 parent 296a898 commit 963b1eb
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
7 changes: 7 additions & 0 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class HyperTransformer:
'b': 'boolean',
'M': 'datetime',
}
_DTYPES_TO_DATA_TYPES = {
'i': 'integer',
'f': 'float',
'O': 'categorical',
'b': 'boolean',
'M': 'datetime',
}

def __init__(self, transformers=None, copy=True, dtypes=None, dtype_transformers=None):
self.transformers = transformers
Expand Down
64 changes: 63 additions & 1 deletion rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Transformers module."""

from collections import defaultdict
from functools import lru_cache

from rdt.transformers.base import BaseTransformer
from rdt.transformers.boolean import BooleanTransformer
from rdt.transformers.categorical import (
Expand All @@ -23,7 +26,15 @@

TRANSFORMERS = {
transformer.__name__: transformer
for transformer in BaseTransformer.__subclasses__()
for transformer in BaseTransformer.get_subclasses()
}
DEFAULT_TRANSFORMERS = {
'numerical': NumericalTransformer,
'integer': NumericalTransformer(dtype=int),
'float': NumericalTransformer(dtype=float),
'categorical': CategoricalTransformer(fuzzy=True),
'boolean': BooleanTransformer,
'datetime': DatetimeTransformer,
}


Expand Down Expand Up @@ -77,3 +88,54 @@ def load_transformers(transformers):
name: load_transformer(transformer)
for name, transformer in transformers.items()
}


def get_transformers_by_type():
"""Build a ``dict`` mapping data types to valid existing transformers for that type.
Returns:
dict:
Mapping of data types to a list of existing transformers that take that
type as an input.
"""
data_type_transformers = defaultdict(list)
transformer_classes = BaseTransformer.get_subclasses()
for transformer in transformer_classes:
try:
input_type = transformer.get_input_type()
data_type_transformers[input_type].append(transformer)
except AttributeError:
pass

return data_type_transformers


@lru_cache()
def get_default_transformers():
"""Build a ``dict`` mapping data types to a default transformer for that type.
Returns:
dict:
Mapping of data types to a transformer.
"""
transformers_by_type = get_transformers_by_type()
defaults = {}
for (data_type, transformers) in transformers_by_type.items():
if data_type in DEFAULT_TRANSFORMERS:
defaults[data_type] = DEFAULT_TRANSFORMERS[data_type]
else:
defaults[data_type] = transformers[0]

return defaults


@lru_cache()
def get_default_transformer(data_type):
"""Get default transformer for a data type.
Returns:
Transformer:
Default transformer for data type.
"""
default_transformers = get_default_transformers()
return default_transformers[data_type]
18 changes: 18 additions & 0 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""BaseTransformer module."""
import abc


class BaseTransformer:
Expand All @@ -20,6 +21,23 @@ class BaseTransformer:
column_prefix = None
output_columns = None

@classmethod
def get_subclasses(cls):
"""Recursively find subclasses of this Baseline.
Returns:
list:
List of all subclasses of this class.
"""
subclasses = []
for subclass in cls.__subclasses__():
if abc.ABC not in subclass.__bases__:
subclasses.append(subclass)

subclasses += subclass.get_subclasses()

return subclasses

@classmethod
def get_input_type(cls):
"""Return the input type supported by the transformer.
Expand Down

0 comments on commit 963b1eb

Please sign in to comment.