Skip to content

Commit

Permalink
Small fixes for the BaseTransformer (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
csala authored Sep 23, 2021
1 parent 4a0fe1b commit 03ecfec
Showing 1 changed file with 67 additions and 67 deletions.
134 changes: 67 additions & 67 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""BaseTransformer module."""
import pandas as pd


class BaseTransformer:
Expand Down Expand Up @@ -42,11 +41,11 @@ def _add_prefix(self, dictionary):
return output

def get_output_types(self):
"""Return the output types supported by the transformer.
"""Return the output types produced by this transformer.
Returns:
dict:
Mapping from the transformed column names to supported data types.
Mapping from the transformed column names to the produced data types.
"""
return self._add_prefix(self.OUTPUT_TYPES)

Expand Down Expand Up @@ -86,24 +85,7 @@ def get_next_transformers(self):
"""
return self._add_prefix(self.NEXT_TRANSFORMERS)

@staticmethod
def _convert_if_length_one(columns):
"""Convert columns to string if it's a list of length one."""
if len(columns) == 1:
columns = columns[0]

return columns

def fit(self, data, columns):
"""Fit the transformer to the `columns` of the `data`.
Args:
data (pandas.DataFrame):
The entire table.
columns (list):
Column names. Must be present in the data.
"""
# make sure columns is a list where every column is in the data
def _store_columns(self, columns, data):
if isinstance(columns, tuple) and columns not in data:
columns = list(columns)
elif not isinstance(columns, list):
Expand All @@ -113,18 +95,32 @@ def fit(self, data, columns):
if missing:
raise KeyError(f'Columns {missing} were not present in the data.')

self._column_prefix = '#'.join(columns)
self._columns = columns

@staticmethod
def _get_columns_data(data, columns):
if len(columns) == 1:
columns = columns[0]

return data[columns]

@staticmethod
def _set_columns_data(data, columns_data, columns):
if len(columns_data.shape) == 1:
data[columns[0]] = columns_data
else:
data[columns] = columns_data

def _build_output_columns(self, data):
self._column_prefix = '#'.join(self._columns)
self._output_columns = list(self.get_output_types().keys())

# make sure none of the generated `output_columns` exists in the data
while any(output_column in data for output_column in self._output_columns):
data_columns = set(data.columns)
while data_columns & set(self._output_columns):
self._column_prefix += '#'
self._output_columns = list(self.get_output_types().keys())

self._columns = columns
columns = self._convert_if_length_one(self._columns)
self._fit(data[columns])

def _fit(self, columns_data):
"""Fit the transformer to the data.
Expand All @@ -134,13 +130,34 @@ def _fit(self, columns_data):
"""
raise NotImplementedError()

@staticmethod
def _convert_if_series(columns, data):
"""Convert columns to pandas.Series if it's a list of length one."""
if isinstance(data, pd.Series):
columns = columns[0]
def fit(self, data, columns):
"""Fit the transformer to the `columns` of the `data`.
return columns
Args:
data (pandas.DataFrame):
The entire table.
columns (list):
Column names. Must be present in the data.
"""
self._store_columns(columns, data)

columns_data = self._get_columns_data(data, self._columns)
self._fit(columns_data)

self._build_output_columns(data)

def _transform(self, columns_data):
"""Transform the data.
Args:
columns_data (pandas.DataFrame or pandas.Series):
Data to transform.
Returns:
pandas.DataFrame or pandas.Series:
Transformed data.
"""
raise NotImplementedError()

def transform(self, data):
"""Transform the `self._columns` of the `data`.
Expand All @@ -159,29 +176,14 @@ def transform(self, data):

data = data.copy()

columns = self._convert_if_length_one(self._columns)
columns_data = data[columns]
columns_data = self._get_columns_data(data, self._columns)
transformed_data = self._transform(columns_data)

output_columns = self._convert_if_series(self._output_columns, transformed_data)
data[output_columns] = transformed_data
self._set_columns_data(data, transformed_data, self._output_columns)
data.drop(self._columns, axis=1, inplace=True)

return data

def _transform(self, columns_data):
"""Transform the data.
Args:
columns_data (pandas.DataFrame or pandas.Series):
Data to transform.
Returns:
pandas.DataFrame or pandas.Series:
Transformed data.
"""
raise NotImplementedError()

def fit_transform(self, data, columns):
"""Fit the transformer to the `columns` of the `data` and then transform them.
Expand All @@ -200,6 +202,19 @@ def fit_transform(self, data, columns):
self.fit(data, columns)
return self.transform(data)

def _reverse_transform(self, columns_data):
"""Revert the transformations to the original values.
Args:
columns_data (pandas.DataFrame or pandas.Series):
Data to revert.
Returns:
pandas.DataFrame or pandas.Series:
Reverted data.
"""
raise NotImplementedError()

def reverse_transform(self, data):
"""Revert the transformations to the original values.
Expand All @@ -217,25 +232,10 @@ def reverse_transform(self, data):

data = data.copy()

output_columns = self._convert_if_length_one(self._output_columns)
columns_data = data[output_columns]
columns_data = self._get_columns_data(data, self._output_columns)
reversed_data = self._reverse_transform(columns_data)

columns = self._convert_if_series(self._columns, reversed_data)
data[columns] = reversed_data
self._set_columns_data(data, reversed_data, self._columns)
data.drop(self._output_columns, axis=1, inplace=True)

return data

def _reverse_transform(self, columns_data):
"""Revert the transformations to the original values.
Args:
columns_data (pandas.DataFrame):
Data to transform.
Returns:
pandas.Series:
Reverted data.
"""
raise NotImplementedError()

0 comments on commit 03ecfec

Please sign in to comment.