Skip to content

Commit

Permalink
columns_to_sdtypes
Browse files (browse the repository at this point in the history)
  • Loading branch information
R-Palazzo committed Sep 8, 2023
1 parent 9153334 commit ba302ba
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 33 deletions.
26 changes: 13 additions & 13 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,15 +488,15 @@ class BaseMultiColumnTransformer(BaseTransformer):
in order to create a new multi column transformer.
Attributes:
columns_to_sdtype (dict):
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
prefixes (dict):
Dictionary mapping each output column to its prefix.
"""

def __init__(self):
super().__init__()
self.columns_to_sdtype = {}
self.columns_to_sdtypes = {}
self.prefixes = {}

def get_input_column(self):
Expand Down Expand Up @@ -549,9 +549,9 @@ def _get_output_to_property(self, property_):

return output

def _validate_columns_to_sdtype(self, data, columns_to_sdtype):
"""Check that all the columns in ``columns_to_sdtype`` are present in the data."""
missing = set(columns_to_sdtype.keys()) - set(data.columns)
def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
"""Check that all the columns in ``columns_to_sdtypes`` are present in the data."""
missing = set(columns_to_sdtypes.keys()) - set(data.columns)
if missing:
missing_to_print = ', '.join(missing)
raise KeyError(f'Columns ({missing_to_print}) are not present in the data.')
Expand All @@ -566,35 +566,35 @@ def _fit(self, data):
raise NotImplementedError()

@random_state
def fit(self, data, columns_to_sdtype):
def fit(self, data, columns_to_sdtypes):
"""Fit the transformer to the columns of the ``data`` listed in ``columns_to_sdtypes``.
Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtype (dict):
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
"""
self._validate_columns_to_sdtype(data, columns_to_sdtype)
self.columns_to_sdtype = columns_to_sdtype
self._store_columns(list(self.columns_to_sdtype.keys()), data)
self._validate_columns_to_sdtypes(data, columns_to_sdtypes)
self.columns_to_sdtypes = columns_to_sdtypes
self._store_columns(list(self.columns_to_sdtypes.keys()), data)
self._set_seed(data)
columns_data = self._get_columns_data(data, self.columns)
self._fit(columns_data)
self._build_output_columns(data)

def fit_transform(self, data, columns_to_sdtype):
def fit_transform(self, data, columns_to_sdtypes):
"""Fit the transformer to the columns of the ``data`` and then transform them.
Args:
data (pandas.DataFrame):
The entire table.
columns_to_sdtype (dict):
columns_to_sdtypes (dict):
Dictionary mapping each column to its sdtype.
Returns:
pd.DataFrame:
The entire table, containing the transformed data.
"""
self.fit(data, columns_to_sdtype)
self.fit(data, columns_to_sdtypes)
return self.transform(data)
12 changes: 6 additions & 6 deletions tests/integration/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,15 @@ def _reverse_transform(self, data):
'col_3': [100, 200, 300]
})

columns_to_sdtype = {
columns_to_sdtypes = {
'col_1': 'numerical',
'col_2': 'numerical',
'col_3': 'numerical'
}
transformer = AdditionTransformer()

# Run
transformed = transformer.fit_transform(data_test, columns_to_sdtype)
transformed = transformer.fit_transform(data_test, columns_to_sdtypes)
reverse = transformer.reverse_transform(transformed)

# Assert
Expand Down Expand Up @@ -229,7 +229,7 @@ def _reverse_transform(self, data):
'col_4': ['J', 'K', 'L']
})

columns_to_sdtype = {
columns_to_sdtypes = {
'col_1': 'categorical',
'col_2': 'categorical',
'col_3': 'categorical',
Expand All @@ -238,7 +238,7 @@ def _reverse_transform(self, data):
transformer = ConcatenateTransformer()

# Run
transformer.fit(data_test, columns_to_sdtype)
transformer.fit(data_test, columns_to_sdtypes)
transformed = transformer.transform(data_test)
reverse = transformer.reverse_transform(transformed)

Expand Down Expand Up @@ -296,14 +296,14 @@ def _reverse_transform(self, data):
'col_2': ['GH', 'IJ', 'KL'],
})

columns_to_sdtype = {
columns_to_sdtypes = {
'col_1': 'categorical',
'col_2': 'categorical'
}
transformer = ExpandTransformer()

# Run
transformer.fit(data_test, columns_to_sdtype)
transformer.fit(data_test, columns_to_sdtypes)
transformed = transformer.transform(data_test)
reverse = transformer.reverse_transform(transformed)

Expand Down
28 changes: 14 additions & 14 deletions tests/unit/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ def test___init__(self):
transformer = BaseMultiColumnTransformer()

# Assert
assert transformer.columns_to_sdtype == {}
assert transformer.columns_to_sdtypes == {}
assert transformer.prefixes == {}

def test_get_input_column(self):
Expand Down Expand Up @@ -1396,25 +1396,25 @@ def test__get_output_to_property_with_prefix_none(self):
assert output == expected_output
transformer._get_prefix.assert_called_once_with()

def test__validate_columns_to_sdtype(self):
"""Test the ``_validate_columns_to_sdtype`` method."""
def test__validate_columns_to_sdtypes(self):
"""Test the ``_validate_columns_to_sdtypes`` method."""
# Setup
transformer = BaseMultiColumnTransformer()
data = pd.DataFrame({
'a': [1, 2, 3],
'b': ['a', 'b', 'c'],
'c': [True, False, True],
})
columns_to_sdtype = {
columns_to_sdtypes = {
'a': 'numerical',
'b': 'categorical',
'c': 'boolean',
}

# Run and Assert
transformer._validate_columns_to_sdtype(data, columns_to_sdtype)
transformer._validate_columns_to_sdtypes(data, columns_to_sdtypes)

wrong_columns_to_sdtype = {
wrong_columns_to_sdtypes = {
'a': 'numerical',
'b': 'categorical',
'd': 'boolean',
Expand All @@ -1423,7 +1423,7 @@ def test__validate_columns_to_sdtype(self):
'Columns (d) are not present in the data.'
)
with pytest.raises(KeyError, match=expected_error_msg):
transformer._validate_columns_to_sdtype(data, wrong_columns_to_sdtype)
transformer._validate_columns_to_sdtypes(data, wrong_columns_to_sdtypes)

def test__fit(self):
"""Test the ``_fit`` method.
Expand All @@ -1446,24 +1446,24 @@ def test_fit(self):
'a': [1, 2, 3],
'b': ['a', 'b', 'c'],
})
columns_to_sdtype = {
columns_to_sdtypes = {
'a': 'numerical',
'b': 'categorical',
}
transformer.columns = ['a', 'b']

transformer._validate_columns_to_sdtype = Mock()
transformer._validate_columns_to_sdtypes = Mock()
transformer._store_columns = Mock()
transformer._get_columns_data = Mock(return_value=data_transformer)
transformer._set_seed = Mock()
transformer._fit = Mock()
transformer._build_output_columns = Mock()

# Run
transformer.fit(data, columns_to_sdtype)
transformer.fit(data, columns_to_sdtypes)

# Assert
transformer._validate_columns_to_sdtype.assert_called_once_with(data, columns_to_sdtype)
transformer._validate_columns_to_sdtypes.assert_called_once_with(data, columns_to_sdtypes)
transformer._store_columns.assert_called_once_with(
['a', 'b'], data
)
Expand All @@ -1476,7 +1476,7 @@ def test_fit_transform(self):
"""Test the ``fit_transform`` method."""
# Setup
transformer = BaseMultiColumnTransformer()
columns_to_sdtype = ('a', 'b', 'c')
columns_to_sdtypes = ('a', 'b', 'c')
data = pd.DataFrame({
'a': [1, 2, 3],
'b': ['a', 'b', 'c'],
Expand All @@ -1488,8 +1488,8 @@ def test_fit_transform(self):
transformer.transform = mock_transform

# Run
transformer.fit_transform(data, columns_to_sdtype)
transformer.fit_transform(data, columns_to_sdtypes)

# Assert
mock_fit.assert_called_once_with(data, columns_to_sdtype)
mock_fit.assert_called_once_with(data, columns_to_sdtypes)
mock_transform.assert_called_once_with(data)

0 comments on commit ba302ba

Please sign in to comment.