Skip to content

Commit

Permalink
Permit dist mixture syntax using or for nominal-continuous [fix #83].
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed May 17, 2020
1 parent 8232e7c commit f5628d9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
35 changes: 18 additions & 17 deletions src/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
import scipy.stats
import sympy

class NominalDistribution():
class Distribution():
def __rmul__(self, x):
from .sym_util import sympify_number
try:
x_val = sympify_number(x)
if not 0 < x_val < 1:
raise ValueError('invalid weight %s' % (str(x),))
return DistributionMix([self], [x])
except TypeError:
return NotImplemented

class NominalDistribution(Distribution):
def __init__(self, dist):
self.dist = dict(dist)
def __call__(self, symbol):
Expand All @@ -13,9 +24,9 @@ def __call__(self, symbol):

choice = NominalDistribution

# pylint: disable=not-callable
# pylint: disable=multiple-statements
class RealDistribution():
class RealDistribution(Distribution):
# pylint: disable=not-callable
# pylint: disable=multiple-statements
dist = None
constructor = None
def __init__(self, *args, **kwargs):
Expand All @@ -27,17 +38,7 @@ def __call__(self, symbol):
def get_domain(self, **kwargs):
raise NotImplementedError()

def __rmul__(self, x):
from .sym_util import sympify_number
try:
x_val = sympify_number(x)
if not 0 < x_val < 1:
raise ValueError('invalid weight %s' % (str(x),))
return RealDistributionMix([self], [x])
except TypeError:
return NotImplemented

class RealDistributionMix():
class DistributionMix():
"""Weighted mixture of SPNs that do not yet sum to unity."""
def __init__(self, distributions, weights):
self.distributions = distributions
Expand All @@ -50,13 +51,13 @@ def __call__(self, symbol):
return SumSPN(distributions, weights)

def __or__(self, x):
if not isinstance(x, RealDistributionMix):
if not isinstance(x, DistributionMix):
return NotImplemented
weights = self.weights + x.weights
cumsum = float(sum(weights))
assert 0 < cumsum <= 1
distributions = self.distributions + x.distributions
return RealDistributionMix(distributions, weights)
return DistributionMix(distributions, weights)

# ==============================================================================
# ContinuousReal
Expand Down
17 changes: 14 additions & 3 deletions tests/test_parse_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@

import pytest

from spn.distributions import RealDistributionMix
from spn.distributions import DistributionMix
from spn.distributions import bernoulli
from spn.distributions import choice
from spn.distributions import norm
from spn.distributions import poisson
from spn.math_util import allclose
from spn.spn import ContinuousLeaf
from spn.spn import DiscreteLeaf
from spn.spn import NominalLeaf
from spn.spn import SumSPN
from spn.transforms import Id

X = Id('X')

def test_simple_parse():
assert isinstance(.3*bernoulli(p=.1), RealDistributionMix)
def test_simple_parse_real():
assert isinstance(.3*bernoulli(p=.1), DistributionMix)
a = .3*bernoulli(p=.1) | .5 * norm() | .2*poisson(mu=7)
spn = a(X)
assert isinstance(spn, SumSPN)
Expand All @@ -27,6 +29,15 @@ def test_simple_parse():
assert isinstance(spn.children[1], ContinuousLeaf)
assert isinstance(spn.children[2], DiscreteLeaf)

def test_simple_parse_nominal():
assert isinstance(.7 * choice({'a': .1, 'b': .9}), DistributionMix)
a = .3*bernoulli(p=.1) | .7*choice({'a': .1, 'b': .9})
spn = a(X)
assert isinstance(spn, SumSPN)
assert allclose(spn.weights, [log(.3), log(.7)])
assert isinstance(spn.children[0], DiscreteLeaf)
assert isinstance(spn.children[1], NominalLeaf)

def test_error():
with pytest.raises(TypeError):
'a'*bernoulli(p=.1)
Expand Down

0 comments on commit f5628d9

Please sign in to comment.