Skip to content

Commit

Permalink
Handle null values in speed preset (#737)
Browse files Browse the repository at this point in the history
* Do not model null values in speed preset

* add unit tests

* Update dtype transformers

* cr and fix test
  • Loading branch information
katxiao authored Mar 23, 2022
1 parent 8aec125 commit 4232f12
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
35 changes: 33 additions & 2 deletions sdv/lite/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import sys
import warnings

import numpy as np
import rdt

from sdv.tabular import GaussianCopula
from sdv.tabular.base import BaseTabularModel

Expand All @@ -26,6 +29,7 @@ class TabularPreset(BaseTabularModel):
"""

_model = None
_null_percentages = None

def __init__(self, optimize_for=None, metadata=None):
if optimize_for is None:
Expand All @@ -42,22 +46,49 @@ def __init__(self, optimize_for=None, metadata=None):
if optimize_for == SPEED_PRESET:
self._model = GaussianCopula(
table_metadata=metadata,
categorical_transformer='categorical',
categorical_transformer='label_encoding',
default_distribution='gaussian',
rounding=None,
)

dtype_transformers = {
'i': rdt.transformers.NumericalTransformer(
dtype=np.int64, null_column=False),
'f': rdt.transformers.NumericalTransformer(
dtype=np.float64, null_column=False),
'O': rdt.transformers.CategoricalTransformer(fuzzy=True),
'b': rdt.transformers.BooleanTransformer(null_column=False),
'M': rdt.transformers.DatetimeTransformer(null_column=False),
}
self._model._metadata._dtype_transformers.update(dtype_transformers)

print('This config optimizes the modeling speed above all else.\n\n'
'Your exact runtime is dependent on the data. Benchmarks:\n'
'100K rows and 100 columns may take around 1 minute.\n'
'1M rows and 250 columns may take around 30 minutes.')

def fit(self, data):
"""Fit this model to the data."""
self._null_percentages = {}

for column, column_data in data.iteritems():
num_nulls = column_data.isna().sum()
if num_nulls > 0:
# Store null percentage for future reference.
self._null_percentages[column] = num_nulls / len(column_data)

self._model.fit(data)

def sample(self, num_rows):
"""Sample rows from this table."""
return self._model.sample(num_rows)
sampled = self._model.sample(num_rows)

if self._null_percentages:
for column, percentage in self._null_percentages.items():
sampled[column] = sampled[column].mask(
np.random.random((len(sampled), )) < percentage)

return sampled

@classmethod
def list_available_presets(cls, out=sys.stdout):
Expand Down
75 changes: 74 additions & 1 deletion tests/unit/lite/test_tabular.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -30,6 +31,7 @@ def test___init__invalid_optimize_for(self):
Input:
- optimize_for = invalid parameter
Side Effects:
- ValueError should be thrown
"""
Expand All @@ -54,49 +56,120 @@ def test__init__speed_passes_correct_parameters(self, gaussian_copula_mock):
# Assert
gaussian_copula_mock.assert_called_once_with(
table_metadata=None,
categorical_transformer='categorical',
categorical_transformer='label_encoding',
default_distribution='gaussian',
rounding=None,
)
metadata = gaussian_copula_mock.return_value._metadata
assert metadata._dtype_transformers.update.call_count == 1

def test_fit(self):
"""Test the ``TabularPreset.fit`` method.
Expect that the model's fit method is called with the expected args.
Input:
- fit data
Side Effects:
- The model's fit method is called with the same data.
"""
# Setup
metadata = Mock()
metadata.to_dict.return_value = {'fields': {}}
model = Mock()
model._metadata = metadata
preset = Mock()
preset._model = model
preset._null_percentages = None

# Run
TabularPreset.fit(preset, pd.DataFrame())

# Assert
model.fit.assert_called_once_with(DataFrameMatcher(pd.DataFrame()))
assert preset._null_percentages == {}

def test_fit_with_null_values(self):
"""Test the ``TabularPreset.fit`` method with null values.
Expect that the model's fit method is called with the expected args, and that
the null percentage is calculated correctly.
Input:
- fit data
Side Effects:
- The model's fit method is called with the same data.
"""
# Setup
metadata = Mock()
metadata.to_dict.return_value = {'fields': {'a': {}}}
model = Mock()
model._metadata = metadata
preset = Mock()
preset._model = model
preset._null_percentages = None

data = {'a': [1, 2, np.nan]}

# Run
TabularPreset.fit(preset, pd.DataFrame(data))

# Assert
model.fit.assert_called_once_with(DataFrameMatcher(pd.DataFrame(data)))
assert preset._null_percentages == {'a': 1.0 / 3}

def test_sample(self):
"""Test the ``TabularPreset.sample`` method.
Expect that the model's sample method is called with the expected args.
Input:
- num_rows=5
Side Effects:
- The model's sample method is called with the same data.
"""
# Setup
model = Mock()
preset = Mock()
preset._model = model
preset._null_percentages = None

# Run
TabularPreset.sample(preset, 5)

# Assert
model.sample.assert_called_once_with(5)

def test_sample_with_null_values(self):
"""Test the ``TabularPreset.sample`` method with null percentages.
Expect that the model's sample method is called with the expected args, and that
null values are inserted back into the sampled data.
Input:
- num_rows=5
Side Effects:
- The model's sample method is called with the expected number of rows.
"""
# Setup
model = Mock()
model.sample.return_value = pd.DataFrame({'a': [1, 2, 3, 4, 5]})
preset = Mock()
preset._model = model
# Convoluted example - 100% percent chance of nulls to make test deterministic.
preset._null_percentages = {'a': 1}

# Run
sampled = TabularPreset.sample(preset, 5)

# Assert
model.sample.assert_called_once_with(5)
assert sampled['a'].isna().sum() == 5

def test_list_available_presets(self):
"""Tests the ``TabularPreset.list_available_presets`` method.
Expand Down

0 comments on commit 4232f12

Please sign in to comment.