Skip to content

Commit

Permalink
better handle numerical id
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 25, 2024
1 parent b2e5087 commit 420e883
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 41 deletions.
23 changes: 13 additions & 10 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
)
from sdv.data_processing.datetime_formatter import DatetimeFormatter
from sdv.data_processing.errors import InvalidConstraintsError, NotFittedError
from sdv.data_processing.numerical_formatter import NumericalFormatter
from sdv.data_processing.numerical_formatter import INTEGER_BOUNDS, NumericalFormatter
from sdv.data_processing.utils import load_module_from_path
from sdv.errors import SynthesizerInputError, log_exc_stacktrace
from sdv.metadata.single_table import SingleTableMetadata

LOGGER = logging.getLogger(__name__)
INTEGER_BOUNDS = {str(key).lower(): value for key, value in INTEGER_BOUNDS.items()}


class DataProcessor:
Expand Down Expand Up @@ -561,6 +562,7 @@ def _create_config(self, data, columns_created_by_constraints):
)

if sdtype == 'id':
function_name = 'bothify'
column_dtype = data[column].dtype
is_numeric = pd.api.types.is_numeric_dtype(column_dtype)
if column_metadata.get('regex_format', False):
Expand All @@ -570,24 +572,25 @@ def _create_config(self, data, columns_created_by_constraints):
sdtypes[column] = 'text'

else:
bothify_format = 'sdv-id-??????'
if is_numeric:
function_name = 'random_int'
column_dtype = str(column_dtype).lower()
if 'int8' in column_dtype:
bothify_format = '##'
elif 'int16' in column_dtype:
bothify_format = '####'
else:
bothify_format = '#########'
for key in INTEGER_BOUNDS:
if key in column_dtype:
min_value, max_value = INTEGER_BOUNDS[key]
function_kwargs = {'min': min_value, 'max': max_value}

else:
function_kwargs = {'text': 'sdv-id-??????'}

cardinality_rule = None
if column in self._keys:
cardinality_rule = 'unique'

transformers[column] = AnonymizedFaker(
provider_name=None,
function_name='bothify',
function_kwargs={'text': bothify_format},
function_name=function_name,
function_kwargs=function_kwargs,
cardinality_rule=cardinality_rule,
)

Expand Down
50 changes: 25 additions & 25 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,31 +347,31 @@ def test_numerical_columns_gets_pii():

# Assert
expected_sampled = pd.DataFrame({
'id': {
0: 807994768,
1: 746439230,
2: 201363792,
3: 364823003,
4: 726973888,
5: 693331380,
6: 795819284,
7: 607278621,
8: 783746695,
9: 162118876,
},
'city': {
0: 'Danielfort',
1: 'Glendaside',
2: 'Port Jenniferchester',
3: 'Port Susan',
4: 'West Michellemouth',
5: 'West Jason',
6: 'Ryanfort',
7: 'West Stephenland',
8: 'Davidland',
9: 'Port Christopher',
},
'numerical': {0: 22, 1: 24, 2: 22, 3: 23, 4: 22, 5: 24, 6: 23, 7: 24, 8: 24, 9: 24},
'id': [
2099712954613693783,
-152666675184636528,
4268567557886801441,
8596895119661928307,
6592419880288711333,
7082988828807721487,
-1204202270621625701,
3831203727630084512,
-4724549333540186445,
496832674032232864,
],
'city': [
'Danielfort',
'Glendaside',
'Port Jenniferchester',
'Port Susan',
'West Michellemouth',
'West Jason',
'Ryanfort',
'West Stephenland',
'Davidland',
'Port Christopher',
],
'numerical': [22, 24, 22, 23, 22, 24, 23, 24, 24, 24],
})
pd.testing.assert_frame_equal(expected_sampled, sampled)

Expand Down
15 changes: 9 additions & 6 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,19 +1244,22 @@ def test__create_config(self):

id_numeric_int_8_transformer = config['transformers']['id_numeric_int8']
assert isinstance(id_numeric_int_8_transformer, AnonymizedFaker)
assert id_numeric_int_8_transformer.function_name == 'bothify'
assert id_numeric_int_8_transformer.function_kwargs == {'text': '##'}
assert id_numeric_int_8_transformer.function_name == 'random_int'
assert id_numeric_int_8_transformer.function_kwargs == {'min': -128, 'max': 127}
assert id_numeric_int_8_transformer.cardinality_rule == 'unique'

id_numeric_int_16_transformer = config['transformers']['id_numeric_int16']
assert isinstance(id_numeric_int_16_transformer, AnonymizedFaker)
assert id_numeric_int_16_transformer.function_name == 'bothify'
assert id_numeric_int_16_transformer.function_kwargs == {'text': '####'}
assert id_numeric_int_16_transformer.function_name == 'random_int'
assert id_numeric_int_16_transformer.function_kwargs == {'min': -32768, 'max': 32767}

id_numeric_int_32_transformer = config['transformers']['id_numeric_int32']
assert isinstance(id_numeric_int_32_transformer, AnonymizedFaker)
assert id_numeric_int_32_transformer.function_name == 'bothify'
assert id_numeric_int_32_transformer.function_kwargs == {'text': '#########'}
assert id_numeric_int_32_transformer.function_name == 'random_int'
assert id_numeric_int_32_transformer.function_kwargs == {
'min': -2147483648,
'max': 2147483647,
}

id_column_transformer = config['transformers']['id_column']
assert isinstance(id_column_transformer, AnonymizedFaker)
Expand Down

0 comments on commit 420e883

Please sign in to comment.