diff --git a/src/distributions.py b/src/distributions.py index a741ea3..c142f4e 100644 --- a/src/distributions.py +++ b/src/distributions.py @@ -711,3 +711,13 @@ def __init__(self, *args, **kwargs): pk = [1./len(xk)] * len(xk) kwargs['values'] = (xk, pk) super().__init__(*args, **kwargs) + +class discrete(rv_discrete): + def __init__(self, *args, **kwargs): + assert len(args) == 1 + assert not kwargs + values = args[0] + xk = tuple(values.keys()) + pk = tuple(values.values()) + kwargs['values'] = (xk, pk) + super().__init__(**kwargs) diff --git a/tests/test_parse_distributions.py b/tests/test_parse_distributions.py index 82b942c..9d234bd 100644 --- a/tests/test_parse_distributions.py +++ b/tests/test_parse_distributions.py @@ -8,6 +8,7 @@ from spn.distributions import DistributionMix from spn.distributions import bernoulli from spn.distributions import choice +from spn.distributions import discrete from spn.distributions import norm from spn.distributions import poisson from spn.distributions import rv_discrete @@ -48,12 +49,15 @@ def test_error(): a(X) def test_parse_rv_discrete(): - dist = rv_discrete(values=((1, 2, 10), (.3, .5, .2))) - spn = dist(X) - assert allclose(spn.prob(X<<{1}), .3) - assert allclose(spn.prob(X<<{2}), .5) - assert allclose(spn.prob(X<<{10}), .2) - assert allclose(spn.prob(X<=10), 1) + for dist in [ + rv_discrete(values=((1, 2, 10), (.3, .5, .2))), + discrete({1: .3, 2: .5, 10: .2}) + ]: + spn = dist(X) + assert allclose(spn.prob(X<<{1}), .3) + assert allclose(spn.prob(X<<{2}), .5) + assert allclose(spn.prob(X<<{10}), .2) + assert allclose(spn.prob(X<=10), 1) dist = uniformd(values=((1, 2, 10, 0))) spn = dist(X)