-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtest_factory.py
30 lines (25 loc) · 1.14 KB
/
test_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import subprocess
import pandas as pd
from sklearn import preprocessing
from unittest import TestCase
from snsynth import *
# Locate the repository root via git so the dataset path does not depend on
# the current working directory the tests are launched from.
git_root_dir = (
    subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    .decode("utf-8")
    .strip()
)

# The PUMS sample lives under <repo>/datasets/PUMS.csv.
csv_path = os.path.join(git_root_dir, "datasets", "PUMS.csv")

# Loaded once at import time and shared by the test cases in this module.
df = pd.read_csv(csv_path, index_col=None)
class TestFactory(TestCase):
    """Smoke tests for ``Synthesizer.create`` across every registered synthesizer."""

    def test_create_empty(self):
        """Every name from ``Synthesizer.list_synthesizers()`` can be instantiated."""
        for synth_name in Synthesizer.list_synthesizers():
            # subTest reports which synthesizer failed instead of aborting the
            # whole loop at the first failure.
            with self.subTest(synth=synth_name):
                _ = Synthesizer.create(synth_name, epsilon=1.0)

    def test_fit_with_data_frame(self):
        """Each synthesizer fits a narrowed PUMS frame and samples plausible rows."""
        # fit income by marital status
        narrow_df = df.drop(["age", "sex", "race", "educ"], axis=1)
        for synth_name in Synthesizer.list_synthesizers():
            with self.subTest(synth=synth_name):
                # Keep the name and the instance in separate variables so the
                # progress message and subTest label show the synthesizer name.
                synth = Synthesizer.create(synth_name, epsilon=2.0)
                print(f"Fitting {synth_name}...")
                synth.fit(narrow_df, preprocessor_eps=0.5)
                rows = synth.sample(100)

                self.assertIsInstance(rows, pd.DataFrame)

                # A NaN mean fails both comparisons, so it fails here as well.
                income_mean = rows['income'].mean()
                self.assertTrue(
                    1000 < income_mean < 250000,
                    f"mean income {income_mean!r} outside (1000, 250000)",
                )

                married_count = int((rows['married'] == 1).sum())
                self.assertGreater(married_count, 1)
                # NOTE(review): the original also had a disabled check that more
                # than one sampled row has married == 0; presumably it was flaky
                # with only 100 samples. Confirm before re-enabling.