-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtest_quail.py
58 lines (45 loc) · 1.62 KB
/
test_quail.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import subprocess
import os
import pytest
import pandas as pd
from diffprivlib.models import LogisticRegression as DPLR
from snsynth.pytorch import PytorchDPSynthesizer
from snsynth.pytorch.nn import PATECTGAN
from snsynth.quail import QUAILSynthesizer
# Locate the repository root via git so the dataset paths below do not
# depend on the directory pytest is launched from.
_git_toplevel_bytes = subprocess.check_output(
    ["git", "rev-parse", "--show-toplevel"]
)
git_root_dir = _git_toplevel_bytes.decode("utf-8").strip()

# Both PUMS artefacts live under <repo>/datasets.
meta_path = os.path.join(git_root_dir, "datasets", "PUMS.yaml")
csv_path = os.path.join(git_root_dir, "datasets", "PUMS.csv")

# Shared test frame: the PUMS table with its "income" column removed.
df = pd.read_csv(csv_path).drop(columns="income")
@pytest.mark.torch
class TestQUAIL:
    """Smoke tests for ``QUAILSynthesizer`` on the PUMS table (minus ``income``).

    ``setup_method`` builds a fresh synthesizer before every test with a total
    epsilon of 10.0 and ``eps_split=0.8``. It pairs a PATE-CTGAN-backed
    ``PytorchDPSynthesizer`` with a diffprivlib logistic-regression classifier
    for ``target_column``.
    """

    # Column the QUAIL classifier is built for; every other column of ``df``
    # is passed to ``fit`` as categorical.
    target_column = "married"

    def setup_method(self):
        """Create a new QUAILSynthesizer for each test method."""

        def QuailClassifier(epsilon):
            # DP classifier factory; QUAIL calls it with its epsilon share.
            return DPLR(epsilon=epsilon)

        def QuailSynth(epsilon):
            # DP synthesizer factory; QUAIL calls it with its epsilon share.
            return PytorchDPSynthesizer(
                epsilon=epsilon,
                preprocessor=None,
                gan=PATECTGAN(loss="cross_entropy", batch_size=50, pac=1),
            )

        self.quail = QUAILSynthesizer(
            10.0, QuailSynth, QuailClassifier, self.target_column, eps_split=0.8
        )

    def _categorical_columns(self):
        """Return every column name of ``df`` except the target column."""
        return [col for col in df.columns if col != self.target_column]

    def test_fit(self):
        """Fitting with an explicit preprocessor budget creates the private synth."""
        self.quail.fit(
            df,
            categorical_columns=self._categorical_columns(),
            preprocessor_eps=0.5,
        )
        # Check that fit actually produced a synthesizer, rather than relying
        # on the truthiness of an arbitrary model object.
        assert self.quail.private_synth is not None

    def test_sample(self):
        """Sampling len(df) rows yields a frame with the same shape as ``df``."""
        self.quail.fit(df, categorical_columns=self._categorical_columns())
        sample_size = len(df)
        synth_data = self.quail.sample(sample_size)
        assert synth_data.shape == df.shape