Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning when unable to turn off rounding scheme for a column #2279

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def _validate_transformers(self, column_name_to_transformer):
f"Transformer for column '{column}' has already been fit on data."
)

def _warn_for_update_transformers(self, column_name_to_transformer):
"""Raise warnings for update_transformers.
def _warn_quality_and_performance(self, column_name_to_transformer):
"""Raise warning if the quality/performance may be impacted.

Args:
column_name_to_transformer (dict):
Expand All @@ -259,6 +259,20 @@ def _warn_for_update_transformers(self, column_name_to_transformer):
'might impact the quality of your synthetic data.'
)

def _warn_unable_to_enforce_rounding(self, column_name_to_transformer):
if self.enforce_rounding:
for column, transformer in column_name_to_transformer.items():
if (
hasattr(transformer, 'learn_rounding_scheme')
and not transformer.learn_rounding_scheme
):
warnings.warn(
f"Unable to turn off rounding scheme for column '{column}', "
'because the overall synthesizer is enforcing rounding. We '
"recommend setting the synthesizer's 'enforce_rounding' "
'parameter to False.'
)

def update_transformers(self, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.

Expand All @@ -267,7 +281,8 @@ def update_transformers(self, column_name_to_transformer):
Dict mapping column names to transformers to be used for that column.
"""
self._validate_transformers(column_name_to_transformer)
self._warn_for_update_transformers(column_name_to_transformer)
self._warn_quality_and_performance(column_name_to_transformer)
self._warn_unable_to_enforce_rounding(column_name_to_transformer)
self._data_processor.update_transformers(column_name_to_transformer)
if self._fitted:
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'
Expand Down
4 changes: 2 additions & 2 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def _fit_model(self, processed_data):
warnings.filterwarnings('ignore', module='scipy')
self._model.fit(processed_data)

def _warn_for_update_transformers(self, column_name_to_transformer):
"""Raise warnings for update_transformers.
def _warn_quality_and_performance(self, column_name_to_transformer):
"""Raise warning if the quality/performance may be impacted.
Args:
column_name_to_transformer (dict):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,22 @@ def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)


@patch('sdv.single_table.base.warnings')
def test_update_transformers(warning_mock):
"""Test the proper warning is raised."""
# Setup
data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')

# Run
synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.auto_assign_transformers(data)
synthesizer.update_transformers({'amenities_fee': FloatFormatter(learn_rounding_scheme=False)})

# Assert
warning_mock.warn.assert_called_once_with(
"Unable to turn off rounding scheme for column 'amenities_fee', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
33 changes: 33 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,39 @@ def test_update_transformers(self):
assert isinstance(field_transformers['col1'], GaussianNormalizer)
assert isinstance(field_transformers['col2'], GaussianNormalizer)

def test_update_transformers_warns_rounding(self):
"""Test warning is raised if model cannot round."""
# Setup
column_name_to_transformer = {
'col1': GaussianNormalizer(learn_rounding_scheme=False),
'col2': GaussianNormalizer(learn_rounding_scheme=True),
'col3': GaussianNormalizer(learn_rounding_scheme=False),
}
metadata = Metadata()
instance = BaseSingleTableSynthesizer(metadata)
instance._validate_transformers = MagicMock()
instance._warn_quality_and_performance = MagicMock()
instance._data_processor = MagicMock()
instance.enforce_rounding = True
instance._fitted = False

# Run
with pytest.warns(UserWarning) as record:
instance.update_transformers(column_name_to_transformer)

# Assert
assert len(record) == 2
assert str(record[0].message) == (
"Unable to turn off rounding scheme for column 'col1', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)
assert str(record[1].message) == (
"Unable to turn off rounding scheme for column 'col3', because the overall "
"synthesizer is enforcing rounding. We recommend setting the synthesizer's "
"'enforce_rounding' parameter to False."
)

@patch('sdv.single_table.base.DataProcessor')
def test__set_random_state(self, mock_data_processor):
"""Test that ``_model.set_random_state`` is being called with the input value.
Expand Down
Loading