Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
elena-pascal committed Sep 22, 2022
1 parent 5f6ba99 commit 3f8e033
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
8 changes: 5 additions & 3 deletions fit_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,18 @@ def set_plot(rows: int = 2, cols: int = 2):
return axes.ravel()


distributions = {'discrete uniform': stats.randint,
'beta binomial': stats.betabinom,
'zipfian': stats.zipf}


if __name__ == '__main__':
data = read_file('test_data.txt')

axs = set_plot()

show_data(data, ax=axs[0], title='Input data')

distributions = {'discrete uniform': stats.randint,
'beta binomial': stats.betabinom,
'zipfian': stats.zipf}
bounds = guess_bounds(data)

for i, (key, dist) in enumerate(distributions.items()):
Expand Down
30 changes: 30 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from scipy import stats
from fit_discrete import read_file, distributions, fit_distribution, guess_bounds


@pytest.mark.parametrize("key, expected", [('discrete uniform', 6.591673732008659),
('beta binomial', 6.068425588244196),
('zipfian', 7.86098064513523)])
def test_results(key, expected):
data = read_file('test_data.txt')
bounds = guess_bounds(data)

res = fit_distribution(data, distributions[key], bounds[key])
assert res.success
assert pytest.approx(expected, res.nllf())


def test_bounds():
data = stats.randint.rvs(low=0, high=10, size=100)
bounds = guess_bounds(data)

res = fit_distribution(data, stats.randint, bounds['discrete uniform'])
assert res.success

res = fit_distribution(data, stats.betabinom, bounds['beta binomial'])
assert res.success

res = fit_distribution(data, stats.zipf, bounds['zipfian'])
assert res.success

0 comments on commit 3f8e033

Please sign in to comment.