diff --git a/fit_discrete.py b/fit_discrete.py index fc0c8cb..f764c0c 100644 --- a/fit_discrete.py +++ b/fit_discrete.py @@ -70,6 +70,11 @@ def set_plot(rows: int = 2, cols: int = 2): return axes.ravel() +distributions = {'discrete uniform': stats.randint, + 'beta binomial': stats.betabinom, + 'zipfian': stats.zipf} + + if __name__ == '__main__': data = read_file('test_data.txt') @@ -77,9 +82,6 @@ def set_plot(rows: int = 2, cols: int = 2): show_data(data, ax=axs[0], title='Input data') - distributions = {'discrete uniform': stats.randint, - 'beta binomial': stats.betabinom, - 'zipfian': stats.zipf} bounds = guess_bounds(data) for i, (key, dist) in enumerate(distributions.items()): diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..1d999ea --- /dev/null +++ b/tests.py @@ -0,0 +1,30 @@ +import pytest +from scipy import stats +from fit_discrete import read_file, distributions, fit_distribution, guess_bounds + + +@pytest.mark.parametrize("key, expected", [('discrete uniform', 6.591673732008659), + ('beta binomial', 6.068425588244196), + ('zipfian', 7.86098064513523)]) +def test_results(key, expected): + data = read_file('test_data.txt') + bounds = guess_bounds(data) + + res = fit_distribution(data, distributions[key], bounds[key]) + assert res.success + assert pytest.approx(expected, res.nllf()) + + +def test_bounds(): + data = stats.randint.rvs(low=0, high=10, size=100) + bounds = guess_bounds(data) + + res = fit_distribution(data, stats.randint, bounds['discrete uniform']) + assert res.success + + res = fit_distribution(data, stats.betabinom, bounds['beta binomial']) + assert res.success + + res = fit_distribution(data, stats.zipf, bounds['zipfian']) + assert res.success +