Skip to content

Commit

Permalink
unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Aug 22, 2023
1 parent 61e1447 commit 635fae4
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 9 deletions.
6 changes: 3 additions & 3 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _store_columns(self, columns, data):
if missing:
raise KeyError(f'Columns {missing} were not present in the data.')

self.columns = columns
self.columns = column_names

@staticmethod
def _get_columns_data(data, columns):
Expand Down Expand Up @@ -490,7 +490,7 @@ class BaseMultiColumnTransformer(BaseTransformer):
"""Base class for all multi column transformers.
The ``BaseMultiColumnTransformer`` class contains methods that must be implemented
in order to create a new mulit column transformer.
in order to create a new multi column transformer.
"""

def get_input_column(self):
Expand All @@ -499,7 +499,7 @@ def get_input_column(self):
Raise an error because for multi column transformers, ``get_input_columns``
must be used instead.
"""
raise ValueError(
raise NotImplementedError(
'MultiColumnTransformers does not have a single input column.'
'Please use ``get_input_columns`` instead.'
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _is_valid_transformer(transformer_name):
"""Determine if transformer should be tested or not."""
invalid_names = [
'IdentityTransformer', 'Dummy', 'OrderedLabelEncoder', 'CustomLabelEncoder',
'OrderedUniformEncoder',
'OrderedUniformEncoder', 'BaseMultiColumnTransformer'
]
return all(invalid_name not in transformer_name for invalid_name in invalid_names)

Expand Down
70 changes: 65 additions & 5 deletions tests/unit/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from rdt.errors import TransformerInputError
from rdt.transformers import BaseTransformer, NullTransformer
from rdt.transformers import BaseMultiColumnTransformer, BaseTransformer, NullTransformer
from rdt.transformers.base import random_state, set_random_states


Expand Down Expand Up @@ -1278,13 +1278,73 @@ def _reverse_transform(self, data):
class TestBaseMultiColumnTransformer:

def test_get_input_column(self):
pass
"""Test the ``get_input_column`` method.
When the ``get_input_column`` method is called, it should raise a ``NotImplementedError``.
"""
# Setup
expected_message = (
'MultiColumnTransformers does not have a single input column.'
'Please use ``get_input_columns`` instead.'
)

# Run and Assert
with pytest.raises(NotImplementedError, match=expected_message):
BaseMultiColumnTransformer().get_input_column()

def test_get_input_columns(self):
pass
"""Test the ``get_input_columns`` method."""
# Setup
transformer = BaseMultiColumnTransformer()
transformer.columns = ['a', 'b', 'c']

# Run
output = transformer.get_input_columns()

# Assert
assert output == ['a', 'b', 'c']

def test__fit(self):
pass
"""Test the ``_fit`` method.
Check that an error is raised when the ``_fit`` method is called.
"""
# Setup
transformer = BaseMultiColumnTransformer()

# Run and Assert
with pytest.raises(NotImplementedError):
transformer._fit(None, None)

def test_fit(self):
pass
"""Test the ``fit`` method."""
# Setup
transformer = BaseMultiColumnTransformer()
data = Mock()
columns_to_sdtypes = {
'a': 'numerical',
'b': 'categorical',
'c': 'boolean'
}
data_transformer = pd.DataFrame({
'a': [1, 2, 3],
'b': ['a', 'b', 'c'],
})
transformer.columns = ['a', 'b']
transformer._store_columns = Mock()
transformer._get_columns_data = Mock(return_value=data_transformer)
transformer._set_seed = Mock()
transformer._fit = Mock()
transformer._build_output_columns = Mock()

# Run
transformer.fit(data, columns_to_sdtypes)

# Assert
transformer._store_columns.assert_called_once_with(
list(columns_to_sdtypes.keys()), data
)
transformer._set_seed.assert_called_once_with(data)
transformer._get_columns_data.assert_called_once_with(data, ['a', 'b'])
transformer._fit.assert_called_once_with(data_transformer, columns_to_sdtypes)
transformer._build_output_columns.assert_called_once_with(data)

0 comments on commit 635fae4

Please sign in to comment.