Skip to content

Commit

Permalink
Add a discrete syntactic sugar for rv_discrete.
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed May 29, 2020
1 parent 545dfe7 commit 1943895
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
10 changes: 10 additions & 0 deletions src/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,13 @@ def __init__(self, *args, **kwargs):
pk = [1./len(xk)] * len(xk)
kwargs['values'] = (xk, pk)
super().__init__(*args, **kwargs)

class discrete(rv_discrete):
def __init__(self, *args, **kwargs):
assert len(args) == 1
assert not kwargs
values = args[0]
xk = tuple(values.keys())
pk = tuple(values.values())
kwargs['values'] = (xk, pk)
super().__init__(**kwargs)
16 changes: 10 additions & 6 deletions tests/test_parse_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from spn.distributions import DistributionMix
from spn.distributions import bernoulli
from spn.distributions import choice
from spn.distributions import discrete
from spn.distributions import norm
from spn.distributions import poisson
from spn.distributions import rv_discrete
Expand Down Expand Up @@ -48,12 +49,15 @@ def test_error():
a(X)

def test_parse_rv_discrete():
dist = rv_discrete(values=((1, 2, 10), (.3, .5, .2)))
spn = dist(X)
assert allclose(spn.prob(X<<{1}), .3)
assert allclose(spn.prob(X<<{2}), .5)
assert allclose(spn.prob(X<<{10}), .2)
assert allclose(spn.prob(X<=10), 1)
for dist in [
rv_discrete(values=((1, 2, 10), (.3, .5, .2))),
discrete({1: .3, 2: .5, 10: .2})
]:
spn = dist(X)
assert allclose(spn.prob(X<<{1}), .3)
assert allclose(spn.prob(X<<{2}), .5)
assert allclose(spn.prob(X<<{10}), .2)
assert allclose(spn.prob(X<=10), 1)

dist = uniformd(values=((1, 2, 10, 0)))
spn = dist(X)
Expand Down

0 comments on commit 1943895

Please sign in to comment.