Skip to content

Commit

Permalink
Unittest for bank and fixed linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
joosjegoedhart committed Mar 27, 2023
1 parent 46dec10 commit 9d392d2
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions tests/test_standard_datasets.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,62 @@
""" Tests for standard dataset classes """

from unittest.mock import patch
import numpy as np
import pandas as pd

from aif360.datasets import AdultDataset
from aif360.datasets import BankDataset
from aif360.datasets import CompasDataset
from aif360.datasets import GermanDataset
from aif360.metrics import BinaryLabelDatasetMetric

# Cap how much pandas prints to the console (rows, columns, line width).
# Configured once, after all imports, so no import follows executable code.
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 10)
pd.set_option('display.width', 200)

def test_compas():
    """Test default loading for compas."""
    # Only checks that default loading raises no errors and that the
    # loaded dataset passes its own validation.
    compas_dataset = CompasDataset()
    compas_dataset.validate_dataset()

def test_german():
    """Test default loading for german."""
    german_dataset = GermanDataset()
    bldm = BinaryLabelDatasetMetric(german_dataset)
    # The default-loaded dataset is expected to hold exactly 1000 instances.
    assert bldm.num_instances() == 1000

def test_adult_test_set():
    """Test default loading for adult, test set."""
    adult_dataset = AdultDataset()
    # Only the first partition (split at index 15060) is inspected; the
    # remainder is discarded.
    test, _ = adult_dataset.split([15060])
    # At least one label in the inspected partition must be non-zero.
    assert np.any(test.labels)

def test_adult():
    """Test default loading for adult: label mean and instance count."""
    adult_dataset = AdultDataset()
    # Mean label value under default loading, compared with a small tolerance.
    assert np.isclose(adult_dataset.labels.mean(), 0.2478, atol=5e-5)
    bldm = BinaryLabelDatasetMetric(adult_dataset)
    assert bldm.num_instances() == 45222

def test_adult_no_drop():
    """Test adult loading restricted to two features, number of instances."""
    # Load with 'sex' as the only protected attribute, no categorical
    # features, and only 'age' and 'education-num' kept.
    adult_dataset = AdultDataset(
        protected_attribute_names=['sex'],
        privileged_classes=[['Male']],
        categorical_features=[],
        features_to_keep=['age', 'education-num'],
    )
    bldm = BinaryLabelDatasetMetric(adult_dataset)
    # NOTE(review): presumably no rows are dropped with this feature
    # selection, hence the larger count than default loading -- confirm.
    assert bldm.num_instances() == 48842

def test_bank():
    """Test for errors during default loading."""
    # Load once with default arguments and let the dataset validate itself.
    bank_dataset = BankDataset()
    bank_dataset.validate_dataset()

@patch("pandas.read_csv")
def test_bank_priviliged_attributes(mock_read_csv):
    """Test if privileged attributes are correctly transformed."""
    # Stand in for the CSV read with a tiny in-memory frame, so the test
    # controls exactly which ages the dataset sees.
    mock_read_csv.return_value = pd.DataFrame({
        'y': ['yes', 'no', 'no', 'yes'],
        'age': [43, 18, 89, 25],
    })
    bank_dataset = BankDataset(categorical_features=[])
    converted_frame = bank_dataset.convert_to_dataframe()[0]
    # Ages 43 and 25 are expected to map to 1.0; ages 18 and 89 to 0.0.
    expected_age = [1.0, 0.0, 0.0, 1.0]
    assert converted_frame["age"].tolist() == expected_age

0 comments on commit 9d392d2

Please sign in to comment.