Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 6, 2024
1 parent 0c98a2e commit 863b29e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,14 @@ def _warn_quality_and_performance(self, column_name_to_transformer):
f"Replacing the default transformer for column '{column}' "
'might impact the quality of your synthetic data.'
)

def _warn_unable_to_enforce_rounding(self, column_name_to_transformer):
if self.enforce_rounding:
for column, transformer in column_name_to_transformer.items():
if not transformer.learn_rounding_scheme:
if (
hasattr(transformer, 'learn_rounding_scheme')
and not transformer.learn_rounding_scheme
):
warnings.warn(
f"Unable to turn off rounding scheme for column '{column}', "
'because the overall synthesizer is enforcing rounding. We '
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,22 @@ def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)


@patch('sdv.single_table.base.warnings')
def test_update_transformers(warning_mock):
"""Test the proper warning is raised."""
# Setup
data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')

# Run
synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.auto_assign_transformers(data)
synthesizer.update_transformers({'amenities_fee': FloatFormatter(learn_rounding_scheme=False)})

# Assert
warning_mock.warn.assert_called_once_with(
"Unable to turn off rounding scheme for column 'amenities_fee', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
33 changes: 33 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,39 @@ def test_update_transformers(self):
assert isinstance(field_transformers['col1'], GaussianNormalizer)
assert isinstance(field_transformers['col2'], GaussianNormalizer)

def test_update_transformers_warns_rounding(self):
"""Test warning is raised if model cannot round."""
# Setup
column_name_to_transformer = {
'col1': GaussianNormalizer(learn_rounding_scheme=False),
'col2': GaussianNormalizer(learn_rounding_scheme=True),
'col3': GaussianNormalizer(learn_rounding_scheme=False),
}
metadata = Metadata()
instance = BaseSingleTableSynthesizer(metadata)
instance._validate_transformers = MagicMock()
instance._warn_quality_and_performance = MagicMock()
instance._data_processor = MagicMock()
instance.enforce_rounding = True
instance._fitted = False

# Run
with pytest.warns(UserWarning) as record:
instance.update_transformers(column_name_to_transformer)

# Assert
assert len(record) == 2
assert str(record[0].message) == (
"Unable to turn off rounding scheme for column 'col1', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
assert str(record[1].message) == (
"Unable to turn off rounding scheme for column 'col3', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)

@patch('sdv.single_table.base.DataProcessor')
def test__set_random_state(self, mock_data_processor):
"""Test that ``_model.set_random_state`` is being called with the input value.
Expand Down

0 comments on commit 863b29e

Please sign in to comment.