Skip to content

Commit

Permalink
adding caching and some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Sep 24, 2021
1 parent ad3cdb5 commit dbb1162
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
24 changes: 19 additions & 5 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Transformers module."""

from collections import defaultdict
from functools import lru_cache

from rdt.transformers.base import BaseTransformer
from rdt.transformers.boolean import BooleanTransformer
Expand Down Expand Up @@ -93,9 +94,9 @@ def get_transformers_by_type():
"""Build a ``dict`` mapping data types to valid existing transformers for that type.
Returns:
dict:
Mapping of data types to a list of existing transformers that take that
type as an input.
dict:
Mapping of data types to a list of existing transformers that take that
type as an input.
"""
data_type_transformers = defaultdict(list)
transformer_classes = BaseTransformer.get_subclasses()
Expand All @@ -109,12 +110,13 @@ def get_transformers_by_type():
return data_type_transformers


@lru_cache()
def get_default_transformers():
"""Build a ``dict`` mapping data types to a default transformer for that type.
Returns:
dict:
Mapping of data types to a transformer.
dict:
Mapping of data types to a transformer.
"""
transformers_by_type = get_transformers_by_type()
defaults = {}
Expand All @@ -125,3 +127,15 @@ def get_default_transformers():
defaults[data_type] = transformers[0]

return defaults


@lru_cache()
def get_default_transformer(data_type):
"""Gets default transformer for a data type.
Returns:
Transformer:
Default transformer for data type.
"""
default_transformers = get_default_transformers()
return default_transformers[data_type]
1 change: 1 addition & 0 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_subclasses(cls):
for subclass in cls.__subclasses__():
if abc.ABC not in subclass.__bases__:
subclasses.append(subclass)

subclasses += subclass.get_subclasses()

return subclasses
Expand Down

0 comments on commit dbb1162

Please sign in to comment.