-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtest_mst.py
37 lines (27 loc) · 1.01 KB
/
test_mst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import subprocess
import os
import numpy as np
import pandas as pd
from snsynth.mst import MSTSynthesizer
# Ask git for the repository root so the dataset path does not depend on
# the directory the tests are launched from.
git_root_dir = (
    subprocess.check_output("git rev-parse --show-toplevel".split(" "))
    .decode("utf-8")
    .strip()
)
csv_path = os.path.join(git_root_dir, "datasets", "PUMS.csv")

# Load the CSV, drop the "income" column, then return every row in a
# shuffled order (frac=1) using a fixed seed so runs are repeatable.
df = (
    pd.read_csv(csv_path)
    .drop(columns=["income"])
    .sample(frac=1, random_state=42)
)
class TestMST:
    """Tests that fit and sample ``MSTSynthesizer`` on a four-column subset of ``df``."""

    @classmethod
    def setup_class(cls) -> None:
        """Build one synthesizer, store it on the class as ``cls.mst``, and print it."""
        print("Setting up class")
        cls.mst = MSTSynthesizer()
        print("Setup class")
        print(cls.mst)

    def test_fit(self):
        """Fit on the selected columns and assert the synthesizer object is truthy."""
        columns = ['sex', 'educ', 'race', 'married']
        subset = df[columns]
        self.df_non_continuous = subset
        self.mst.fit(subset)
        assert self.mst

    def test_sample(self):
        """Fit, draw ``len(df)`` rows, and assert the sample has the input's shape."""
        columns = ['sex', 'educ', 'race', 'married']
        subset = df[columns]
        self.df_non_continuous = subset
        self.mst.fit(subset)
        # Request as many synthetic rows as the source frame has.
        n_rows = len(df)
        synth_data = self.mst.sample(n_rows)
        assert synth_data.shape == self.df_non_continuous.shape