Skip to content

Commit

Permalink
Warn the user if they are trying to save an unfit synthesizer (#2147)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Jul 29, 2024
1 parent 20b76f1 commit a0e0a76
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 1 deletion.
10 changes: 10 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,13 +665,23 @@ def get_info(self):

return info

def _validate_fit_before_save(self):
    """Warn if the synthesizer is being saved before it has been fitted.

    Nothing is raised: when ``self._fitted`` is falsy a ``UserWarning`` is
    emitted and execution continues; when it is truthy the method does nothing.
    """
    if self._fitted:
        return

    # stacklevel=3 skips this helper and the ``save`` method that calls it, so the
    # warning is attributed to the line of user code that invoked ``save``.
    warnings.warn(
        'You are saving a synthesizer that has not yet been fitted. You will not be able '
        'to sample synthetic data without fitting. We recommend fitting the synthesizer '
        'first and then saving.',
        UserWarning,
        stacklevel=3,
    )

def save(self, filepath):
"""Save this instance to the given path using cloudpickle.
Args:
filepath (str):
Path where the instance will be serialized.
"""
self._validate_fit_before_save()
synthesizer_id = getattr(self, '_synthesizer_id', None)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Save',
Expand Down
10 changes: 10 additions & 0 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,23 @@ def fit(self, data):
processed_data = self.preprocess(data)
self.fit_processed_data(processed_data)

def _validate_fit_before_save(self):
    """Warn if the synthesizer is being saved before it has been fitted.

    Nothing is raised: when ``self._fitted`` is falsy a ``UserWarning`` is
    emitted and execution continues; when it is truthy the method does nothing.
    """
    if self._fitted:
        return

    # stacklevel=3 skips this helper and the ``save`` method that calls it, so the
    # warning is attributed to the line of user code that invoked ``save``.
    warnings.warn(
        'You are saving a synthesizer that has not yet been fitted. You will not be able '
        'to sample synthetic data without fitting. We recommend fitting the synthesizer '
        'first and then saving.',
        UserWarning,
        stacklevel=3,
    )

def save(self, filepath):
"""Save this model instance to the given path using cloudpickle.
Args:
filepath (str):
Path where the synthesizer instance will be serialized.
"""
self._validate_fit_before_save()
synthesizer_id = getattr(self, '_synthesizer_id', None)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Save',
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import re
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -1521,6 +1522,21 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog):
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

def test_save_warning(self, tmp_path):
    """Saving a never-fitted multi-table synthesizer must emit the 'not yet fitted' warning."""
    # Setup
    unfitted = BaseMultiTableSynthesizer(MultiTableMetadata())
    expected_warning = re.escape(
        'You are saving a synthesizer that has not yet been fitted. You will not be able '
        'to sample synthetic data without fitting. We recommend fitting the synthesizer '
        'first and then saving.'
    )
    target = os.path.join(tmp_path, 'output.pkl')

    # Run and Assert
    with pytest.warns(Warning, match=expected_warning):
        unfitted.save(target)

@patch('sdv.multi_table.base.datetime')
@patch('sdv.multi_table.base.generate_synthesizer_id')
@patch('sdv.multi_table.base.check_synthesizer_version')
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import re
from datetime import date, datetime
from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch
Expand Down Expand Up @@ -32,7 +33,7 @@
GaussianCopulaSynthesizer,
TVAESynthesizer,
)
from sdv.single_table.base import COND_IDX, BaseSingleTableSynthesizer
from sdv.single_table.base import COND_IDX, BaseSingleTableSynthesizer, BaseSynthesizer
from tests.utils import catch_sdv_logs


Expand Down Expand Up @@ -1809,6 +1810,21 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog):
'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

def test_save_warning(self, tmp_path):
    """Saving a never-fitted ``BaseSynthesizer`` must emit the 'not yet fitted' warning."""
    # Setup
    unfitted = BaseSynthesizer(SingleTableMetadata())
    expected_warning = re.escape(
        'You are saving a synthesizer that has not yet been fitted. You will not be able '
        'to sample synthetic data without fitting. We recommend fitting the synthesizer '
        'first and then saving.'
    )
    target = os.path.join(tmp_path, 'output.pkl')

    # Run and Assert
    with pytest.warns(Warning, match=expected_warning):
        unfitted.save(target)

@patch('sdv.single_table.base.datetime')
@patch('sdv.single_table.base.generate_synthesizer_id')
@patch('sdv.single_table.base.check_synthesizer_version')
Expand Down

0 comments on commit a0e0a76

Please sign in to comment.