Skip to content

Commit

Permalink
Update gini tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eugeneyan committed Sep 4, 2020
1 parent 614e703 commit 672d7d1
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions tests/tree/test_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@

@pytest.fixture
def dummy_feats_and_labels():
feats = np.array([[3.6216, 8.6661, -2.8073, -0.44699],
[4.5459, 8.1674, -2.4586, -1.4621],
[3.866, -2.6383, 1.9242, 0.10645],
[3.4566, 9.5228, -4.0112, -3.5944],
[0.32924, -4.4552, 4.5718, -0.9888],
[0.40614, 1.3492, -1.4501, -0.55949],
[-1.3887, -4.8773, 6.4774, 0.34179],
[-3.7503, -13.4586, 17.5932, -2.7771],
[-3.5637, -8.3827, 12.393, -1.2823],
[-2.5419, -0.65804, 2.6842, 1.1952]
feats = np.array([[0.7057, -5.4981, 8.3368, -2.8715],
[2.4391, 6.4417, -0.80743, -0.69139],
[-0.2062, 9.2207, -3.7044, -6.8103],
[4.2586, 11.2962, -4.0943, -4.3457],
[-2.343, 12.9516, 3.3285, -5.9426],
[-2.0545, -10.8679, 9.4926, -1.4116],
[2.2279, 4.0951, -4.8037, -2.1112],
[-6.1632, 8.7096, -0.21621, -3.6345],
[0.52374, 3.644, -4.0746, -1.9909],
[1.5077, 1.9596, -3.0584, -0.12243]
])
labels = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
return feats, labels
Expand Down Expand Up @@ -78,14 +78,19 @@ def dummy_titanic_dt(dummy_titanic):


def test_gini_impurity():
assert round(gini_impurity([1, 1, 1, 0, 0, 0]), 3) == 0.500
assert round(gini_impurity([1, 1, 1, 1, 1, 1]), 3) == 0
assert round(gini_impurity([1, 1, 0, 0, 0, 0]), 3) == round(4 / 9, 3)
assert round(gini_impurity([1, 1, 1, 1, 1, 1, 1, 1]), 3) == 0
assert round(gini_impurity([1, 1, 1, 1, 1, 1, 1, 0]), 3) == 0.219
assert round(gini_impurity([1, 1, 1, 1, 1, 1, 0, 0]), 3) == 0.375
assert round(gini_impurity([1, 1, 1, 1, 1, 0, 0, 0]), 3) == 0.469
assert round(gini_impurity([1, 1, 1, 1, 0, 0, 0, 0]), 3) == 0.500
assert round(gini_impurity([1, 1, 0, 0, 0, 0, 0, 0]), 3) == 0.375


def test_gini_gain():
assert round(gini_gain([1, 1, 1, 0, 0, 0], [[1, 1, 1], [0, 0, 0]]), 3) == 0.5
assert round(gini_gain([1, 1, 1, 0, 0, 0], [[1, 1, 0], [1, 0, 0]]), 3) == 0.056
assert round(gini_gain([1, 1, 1, 1, 0, 0, 0, 0], [[1, 1, 1, 1], [0, 0, 0, 0]]), 3) == 0.5
assert round(gini_gain([1, 1, 1, 1, 0, 0, 0, 0], [[1, 1, 1, 0], [0, 0, 0, 1]]), 3) == 0.125
assert round(gini_gain([1, 1, 1, 1, 0, 0, 0, 0], [[1, 0, 0, 0], [0, 1, 1, 1]]), 3) == 0.125
assert round(gini_gain([1, 1, 1, 1, 0, 0, 0, 0], [[1, 1, 0, 0], [0, 0, 1, 1]]), 3) == 0.0


# Check model prediction to ensure: (i) same shape as labels, (ii) ranges from 0 to 1 inclusive
Expand Down Expand Up @@ -353,4 +358,4 @@ def test_dt_latency(dummy_titanic):

latency_array = np.array([predict_with_time(dt, X_test)[1] for i in range(500)])
latency_p99 = np.quantile(latency_array, 0.99)
assert latency_p99 < 0.003, 'Latency at 99th percentile should be < 0.003 sec'
assert latency_p99 < 0.004, 'Latency at 99th percentile should be < 0.004 sec'

0 comments on commit 672d7d1

Please sign in to comment.