From 03ecfec63da929dfb148778a97581da15fca945b Mon Sep 17 00:00:00 2001 From: Carles Sala Date: Thu, 23 Sep 2021 23:43:06 +0200 Subject: [PATCH] Small fixes for the BaseTransformer (#245) --- rdt/transformers/base.py | 134 +++++++++++++++++++-------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index 5198e0cb3..a6bcb9960 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -1,5 +1,4 @@ """BaseTransformer module.""" -import pandas as pd class BaseTransformer: @@ -42,11 +41,11 @@ def _add_prefix(self, dictionary): return output def get_output_types(self): - """Return the output types supported by the transformer. + """Return the output types produced by this transformer. Returns: dict: - Mapping from the transformed column names to supported data types. + Mapping from the transformed column names to the produced data types. """ return self._add_prefix(self.OUTPUT_TYPES) @@ -86,24 +85,7 @@ def get_next_transformers(self): """ return self._add_prefix(self.NEXT_TRANSFORMERS) - @staticmethod - def _convert_if_length_one(columns): - """Convert columns to string if it's a list of length one.""" - if len(columns) == 1: - columns = columns[0] - - return columns - - def fit(self, data, columns): - """Fit the transformer to the `columns` of the `data`. - - Args: - data (pandas.DataFrame): - The entire table. - columns (list): - Column names. Must be present in the data. - """ - # make sure columns is a list where every column is in the data + def _store_columns(self, columns, data): if isinstance(columns, tuple) and columns not in data: columns = list(columns) elif not isinstance(columns, list): @@ -113,18 +95,32 @@ def fit(self, data, columns): if missing: raise KeyError(f'Columns {missing} were not present in the data.') - self._column_prefix = '#'.join(columns) + self._columns = columns + + @staticmethod + def _get_columns_data(data, columns): + if len(columns) == 1: + columns = columns[0] + + return data[columns] + + @staticmethod + def _set_columns_data(data, columns_data, columns): + if len(columns_data.shape) == 1: + data[columns[0]] = columns_data + else: + data[columns] = columns_data + + def _build_output_columns(self, data): + self._column_prefix = '#'.join(self._columns) self._output_columns = list(self.get_output_types().keys()) # make sure none of the generated `output_columns` exists in the data - while any(output_column in data for output_column in self._output_columns): + data_columns = set(data.columns) + while data_columns & set(self._output_columns): self._column_prefix += '#' self._output_columns = list(self.get_output_types().keys()) - self._columns = columns - columns = self._convert_if_length_one(self._columns) - self._fit(data[columns]) - def _fit(self, columns_data): """Fit the transformer to the data. @@ -134,13 +130,34 @@ def _fit(self, columns_data): """ raise NotImplementedError() - @staticmethod - def _convert_if_series(columns, data): - """Convert columns to pandas.Series if it's a list of length one.""" - if isinstance(data, pd.Series): - columns = columns[0] + def fit(self, data, columns): + """Fit the transformer to the `columns` of the `data`. - return columns + Args: + data (pandas.DataFrame): + The entire table. + columns (list): + Column names. Must be present in the data. + """ + self._store_columns(columns, data) + + columns_data = self._get_columns_data(data, self._columns) + self._fit(columns_data) + + self._build_output_columns(data) + + def _transform(self, columns_data): + """Transform the data. + + Args: + columns_data (pandas.DataFrame or pandas.Series): + Data to transform. + + Returns: + pandas.DataFrame or pandas.Series: + Transformed data. + """ + raise NotImplementedError() def transform(self, data): """Transform the `self._columns` of the `data`. @@ -159,29 +176,14 @@ def transform(self, data): data = data.copy() - columns = self._convert_if_length_one(self._columns) - columns_data = data[columns] + columns_data = self._get_columns_data(data, self._columns) transformed_data = self._transform(columns_data) - output_columns = self._convert_if_series(self._output_columns, transformed_data) - data[output_columns] = transformed_data + self._set_columns_data(data, transformed_data, self._output_columns) data.drop(self._columns, axis=1, inplace=True) return data - def _transform(self, columns_data): - """Transform the data. - - Args: - columns_data (pandas.DataFrame or pandas.Series): - Data to transform. - - Returns: - pandas.DataFrame or pandas.Series: - Transformed data. - """ - raise NotImplementedError() - def fit_transform(self, data, columns): """Fit the transformer to the `columns` of the `data` and then transform them. @@ -200,6 +202,19 @@ def fit_transform(self, data, columns): self.fit(data, columns) return self.transform(data) + def _reverse_transform(self, columns_data): + """Revert the transformations to the original values. + + Args: + columns_data (pandas.DataFrame or pandas.Series): + Data to revert. + + Returns: + pandas.DataFrame or pandas.Series: + Reverted data. + """ + raise NotImplementedError() + def reverse_transform(self, data): """Revert the transformations to the original values. @@ -217,25 +232,10 @@ def reverse_transform(self, data): data = data.copy() - output_columns = self._convert_if_length_one(self._output_columns) - columns_data = data[output_columns] + columns_data = self._get_columns_data(data, self._output_columns) reversed_data = self._reverse_transform(columns_data) - columns = self._convert_if_series(self._columns, reversed_data) - data[columns] = reversed_data + self._set_columns_data(data, reversed_data, self._columns) data.drop(self._output_columns, axis=1, inplace=True) return data - - def _reverse_transform(self, columns_data): - """Revert the transformations to the original values. - - Args: - columns_data (pandas.DataFrame): - Data to transform. - - Returns: - pandas.Series: - Reverted data. - """ - raise NotImplementedError()