Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/test_model_utils/test_binary_ks_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
import numpy as np
from dython.model_utils import _binary_ks_curve


def test_binary_ks_curve_basic():
    """Smoke-test the six-value result of ``_binary_ks_curve`` on a small binary sample.

    Checks that the threshold series is non-empty, that both percentage
    series have the same length as the thresholds, that the KS statistic
    lies in [0, 1], and that exactly two classes are reported.
    """
    labels = np.array([0, 1, 0, 1, 1, 0])
    scores = np.array([0.1, 0.9, 0.3, 0.8, 0.7, 0.2])

    curve = _binary_ks_curve(labels, scores)
    thresholds, pct1, pct2, ks_stat, _max_dist, classes = curve

    n_points = len(thresholds)
    assert n_points > 0
    # Both percentage series must line up point-for-point with the thresholds.
    assert len(pct1) == n_points
    assert len(pct2) == n_points
    assert 0 <= ks_stat <= 1
    assert len(classes) == 2


def test_binary_ks_curve_multiclass_error():
    """Labels with three distinct values must be rejected with ``ValueError``."""
    three_class_labels = np.array([0, 1, 2, 0, 1, 2])
    scores = np.array([0.1, 0.5, 0.9, 0.2, 0.6, 0.8])

    with pytest.raises(ValueError):
        _binary_ks_curve(three_class_labels, scores)


def test_binary_ks_curve_thresholds_start_with_zero():
    """The first returned threshold is 0.0 even though no input score equals 0."""
    labels = np.array([0, 1, 0, 1])
    scores = np.array([0.2, 0.8, 0.3, 0.9])

    curve = _binary_ks_curve(labels, scores)
    # Only the threshold series matters here; the other five values are ignored.
    thresholds, _, _, _, _, _ = curve

    first_threshold = thresholds[0]
    assert first_threshold == 0.0


def test_binary_ks_curve_thresholds_end_with_one():
    """The last returned threshold is 1.0 even though no input score equals 1."""
    labels = np.array([0, 1, 0, 1])
    scores = np.array([0.2, 0.8, 0.3, 0.7])

    curve = _binary_ks_curve(labels, scores)
    # Only the threshold series matters here; the other five values are ignored.
    thresholds, _, _, _, _, _ = curve

    last_threshold = thresholds[-1]
    assert last_threshold == 1.0


def test_binary_ks_curve_with_edge_probabilities():
    """Scores that already contain 0.0 and 1.0 still yield thresholds bounded by 0 and 1."""
    labels = np.array([0, 1, 0, 1])
    scores = np.array([0.0, 1.0, 0.1, 0.9])

    curve = _binary_ks_curve(labels, scores)
    thresholds, _pct1, _pct2, _ks_stat, _max_dist, _ = curve

    assert len(thresholds) > 0
    lowest, highest = thresholds[0], thresholds[-1]
    assert lowest == 0.0
    assert highest == 1.0


def test_binary_ks_curve_data1_exhausted_first():
    """Both cumulative series end at 1.0 when every label-0 score is below every label-1 score.

    NOTE(review): the test name refers to "data1"; presumably that is the
    label-0 score group inside ``_binary_ks_curve`` -- confirm against the
    implementation.
    """
    labels = np.array([0, 0, 1, 1, 1])
    scores = np.array([0.1, 0.2, 0.6, 0.7, 0.8])

    curve = _binary_ks_curve(labels, scores)
    thresholds, pct1, pct2, _ks_stat, _max_dist, _ = curve

    assert len(thresholds) > 0
    # Each cumulative percentage series must finish at exactly 1.0.
    for series in (pct1, pct2):
        assert series[-1] == 1.0


def test_binary_ks_curve_data2_exhausted_first():
    """Both cumulative series end at 1.0 when every label-1 score is below every label-0 score.

    NOTE(review): the test name refers to "data2"; presumably that is the
    label-1 score group inside ``_binary_ks_curve`` -- confirm against the
    implementation.
    """
    labels = np.array([0, 0, 0, 1, 1])
    scores = np.array([0.6, 0.7, 0.8, 0.1, 0.2])

    curve = _binary_ks_curve(labels, scores)
    thresholds, pct1, pct2, _ks_stat, _max_dist, _ = curve

    assert len(thresholds) > 0
    # Each cumulative percentage series must finish at exactly 1.0.
    for series in (pct1, pct2):
        assert series[-1] == 1.0


def test_binary_ks_curve_equal_values():
    """All-identical scores across both classes still produce a non-empty threshold series."""
    labels = np.array([0, 0, 1, 1])
    # Every sample shares the same score, so the two classes are fully tied.
    scores = np.array([0.5, 0.5, 0.5, 0.5])

    curve = _binary_ks_curve(labels, scores)
    thresholds, _pct1, _pct2, _ks_stat, _max_dist, _ = curve

    assert len(thresholds) > 0


def test_binary_ks_curve_interleaved_values():
    """Alternating labels over strictly increasing scores give a KS statistic in [0, 1]."""
    labels = np.array([0, 1, 0, 1, 0, 1])
    scores = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

    curve = _binary_ks_curve(labels, scores)
    thresholds, _pct1, _pct2, ks_stat, _max_dist, _ = curve

    assert len(thresholds) > 0
    assert 0 <= ks_stat <= 1

121 changes: 121 additions & 0 deletions tests/test_model_utils/test_ks_abc_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import numpy as np
import matplotlib.pyplot as plt
from dython.model_utils import ks_abc


def test_ks_abc_basic():
    """``ks_abc`` on plain lists returns a mapping exposing the four expected keys."""
    labels = [0, 1, 0, 1, 1, 0]
    scores = [0.1, 0.9, 0.3, 0.8, 0.7, 0.2]

    result = ks_abc(labels, scores, plot=False)

    # Checked in the same order as the original one-per-line assertions.
    for key in ('abc', 'ks_stat', 'eopt', 'ax'):
        assert key in result


def test_ks_abc_mismatched_shapes():
    """Three labels paired with four predictions must raise ``ValueError``."""
    short_labels = [0, 1, 0]
    longer_scores = [0.1, 0.9, 0.3, 0.8]

    with pytest.raises(ValueError):
        ks_abc(short_labels, longer_scores, plot=False)


def test_ks_abc_2d_binary():
    """Two-column label and score arrays are accepted and yield an 'abc' entry."""
    labels_2d = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
    scores_2d = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])

    outcome = ks_abc(labels_2d, scores_2d, plot=False)

    assert 'abc' in outcome


def test_ks_abc_single_column():
    """Column-vector inputs of shape (4, 1) are accepted and yield an 'abc' entry."""
    labels_col = np.array([[1], [0], [1], [0]])
    scores_col = np.array([[0.9], [0.2], [0.7], [0.3]])

    outcome = ks_abc(labels_col, scores_col, plot=False)

    assert 'abc' in outcome


def test_ks_abc_multiclass_error():
    """Three-column label and score arrays must be rejected with ``ValueError``."""
    one_hot_labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    class_scores = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]])

    with pytest.raises(ValueError):
        ks_abc(one_hot_labels, class_scores, plot=False)


def test_ks_abc_with_ax():
    """``ks_abc`` must hand back the very Axes object the caller passed in.

    The figure created for the test is closed in a ``finally`` clause so it
    is released even when the call or the assertion fails; otherwise a
    failing run would leave the figure open for the rest of the session.
    """
    y_true = [0, 1, 0, 1]
    y_pred = [0.1, 0.9, 0.3, 0.8]

    fig, ax = plt.subplots()
    try:
        result = ks_abc(y_true, y_pred, ax=ax, plot=False)
        # Identity, not equality: the same Axes instance must be returned.
        assert result['ax'] is ax
    finally:
        plt.close(fig)


def test_ks_abc_with_custom_params():
    """``ks_abc`` accepts a full set of visual overrides and still returns 'abc'."""
    labels = [0, 1, 0, 1, 1, 0]
    scores = [0.1, 0.9, 0.3, 0.8, 0.7, 0.2]

    # Keyword overrides collected up front, then forwarded in one call.
    overrides = {
        'colors': ('red', 'blue'),
        'title': 'Custom KS Title',
        'xlim': (0, 0.5),
        'ylim': (0, 0.5),
        'fmt': '.3f',
        'lw': 3,
        'legend': 'upper left',
        'plot': False,
    }
    outcome = ks_abc(labels, scores, **overrides)

    assert 'abc' in outcome


def test_ks_abc_no_legend():
    """Passing ``legend=None`` is accepted and the result still carries an 'ax' entry."""
    labels = [0, 1, 0, 1]
    scores = [0.1, 0.9, 0.3, 0.8]

    outcome = ks_abc(labels, scores, legend=None, plot=False)

    assert 'ax' in outcome


def test_ks_abc_with_filename(tmp_path):
    """Passing ``filename`` to ``ks_abc`` must leave a file at that path.

    Args:
        tmp_path: pytest's per-test temporary directory fixture; the output
            file is written inside it.

    Open figures are closed in a ``finally`` clause so they are released
    even when the call or the assertion fails.
    """
    y_true = [0, 1, 0, 1]
    y_pred = [0.1, 0.9, 0.3, 0.8]

    filename = tmp_path / "ks_plot.png"
    try:
        # The return value is not inspected; only the file on disk matters.
        ks_abc(y_true, y_pred, filename=str(filename), plot=False)
        assert filename.exists()
    finally:
        plt.close('all')


def test_ks_abc_abc_value_range():
    """The 'abc' value returned for an eight-sample binary input lies in [0, 1]."""
    labels = [0, 1, 0, 1, 1, 0, 0, 1]
    scores = [0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.15, 0.85]

    outcome = ks_abc(labels, scores, plot=False)

    abc_value = outcome['abc']
    assert 0 <= abc_value <= 1


def test_ks_abc_ks_stat_value_range():
    """The 'ks_stat' value returned for an eight-sample binary input lies in [0, 1]."""
    labels = [0, 1, 0, 1, 1, 0, 0, 1]
    scores = [0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.15, 0.85]

    outcome = ks_abc(labels, scores, plot=False)

    ks_value = outcome['ks_stat']
    assert 0 <= ks_value <= 1

Loading