-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtest_pategan.py
39 lines (30 loc) · 1.29 KB
/
test_pategan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import subprocess
import os
import pytest
import string
import pandas as pd
# try:
from snsynth.pytorch import PytorchDPSynthesizer
from snsynth.pytorch.nn import PATEGAN
# except:
# import logging
# test_logger = logging.getLogger(__name__)
# test_logger.warning("Requires torch and torchdp")
# Locate the repository root through git so the dataset paths below resolve
# regardless of the directory pytest is launched from.
git_root_dir = (
    subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    .decode("utf-8")
    .strip()
)
# Metadata file for the PUMS_pid dataset; not referenced by the tests in this file.
meta_path = os.path.join(git_root_dir, "datasets", "PUMS_pid.yaml")
csv_path = os.path.join(git_root_dir, "datasets", "PUMS_pid.csv")
# Loaded once at import time and shared (read-only) by every test below.
df = pd.read_csv(csv_path)
@pytest.mark.torch
class TestDPGAN:
    """Smoke tests for PATEGAN wrapped in ``PytorchDPSynthesizer``.

    Both tests train on the four categorical PUMS columns taken from the
    module-level ``df``.
    """

    # NOTE(review): the class name says DPGAN but it exercises PATEGAN; the
    # name is kept so existing ``-k TestDPGAN`` selections keep working.

    # Training columns; every one of them is declared categorical. Stored as a
    # tuple and copied into fresh lists on each use so no call can mutate it.
    _CATEGORICAL = ('sex', 'educ', 'race', 'married')

    def setup_method(self):
        """Create a fresh synthesizer before each test."""
        # Arguments: 1.0, a PATEGAN built with 1.0, and None as the third
        # positional argument (presumably epsilon / gan / preprocessor —
        # TODO confirm against the PytorchDPSynthesizer signature).
        self.pategan = PytorchDPSynthesizer(1.0, PATEGAN(1.0), None)

    def _fit_categorical(self):
        """Fit ``self.pategan`` on the categorical subset of ``df``.

        Returns:
            The DataFrame subset that was used for training.
        """
        training = df[list(self._CATEGORICAL)]
        self.pategan.fit(training, categorical_columns=list(self._CATEGORICAL))
        return training

    def test_fit(self):
        """After fitting, the wrapped GAN exposes a truthy ``generator``."""
        self._fit_categorical()
        assert self.pategan.gan.generator

    def test_sample(self):
        """Sampling as many rows as were trained on yields a same-shaped frame."""
        training = self._fit_categorical()
        synth_data = self.pategan.sample(len(training))
        assert synth_data.shape == training.shape