Skip to content

Commit

Permalink
Numerical(Real)/Ordinal(Real) are now ContinuousReal/DiscreteReal (fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed Feb 21, 2020
1 parent efe3072 commit fbf9249
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 221 deletions.
291 changes: 191 additions & 100 deletions src/numerical.py → src/distributions.py

Large diffs are not rendered by default.

86 changes: 0 additions & 86 deletions src/ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,89 +4,3 @@
import scipy.stats
import sympy

from .spn import OrdinalDistribution

from .sym_util import Integers
from .sym_util import IntegersPos
from .sym_util import IntegersPos0

def Bernoulli(symbol, **kwargs):
"""A Bernoulli discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.bernoulli(**kwargs),
sympy.Range(0, 2))

def Betabinom(symbol, **kwargs):
"""A beta-binomial discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.betabinom(**kwargs),
sympy.Range(0, kwargs['n']+1))

def Binom(symbol, **kwargs):
"""A binomial discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.binom(**kwargs),
sympy.Range(0, kwargs['n']+1))

def Boltzmann(symbol, **kwargs):
"""A Boltzmann (Truncated Discrete Exponential) random variable."""
return OrdinalDistribution(symbol, scipy.stats.boltzmann(**kwargs),
sympy.Range(0, kwargs['N']+1))

def Dlaplace(symbol, **kwargs):
"""A Laplacian discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.dlaplace(**kwargs),
Integers)

def Geom(symbol, **kwargs):
"""A geometric discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.geom(**kwargs),
Integers)

def Hypergeom(symbol, **kwargs):
"""A hypergeometric discrete random variable."""
low = max(0, kwargs['N'], kwargs['N']-kwargs['M']+kwargs['n'])
high = min(kwargs['n'], kwargs['N'])
return OrdinalDistribution(symbol, scipy.stats.hypergeom(**kwargs),
sympy.Range(low, high+1))

def Logser(symbol, **kwargs):
"""A Logarithmic (Log-Series, Series) discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.logser(**kwargs),
IntegersPos)

def Nbinom(symbol, **kwargs):
"""A negative binomial discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.nbinom(**kwargs),
IntegersPos0)

def Planck(symbol, **kwargs):
"""A Planck discrete exponential random variable."""
return OrdinalDistribution(symbol, scipy.stats.planck(**kwargs),
IntegersPos0)

def Poisson(symbol, **kwargs):
"""A Poisson discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.poisson(**kwargs),
IntegersPos0)

def Randint(symbol, **kwargs):
"""A uniform discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.randint(**kwargs),
sympy.Range(kwargs['low'], kwargs['high']))

def Skellam(symbol, **kwargs):
"""A Skellam discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.skellam(**kwargs),
Integers)

def Zipf(symbol, **kwargs):
"""A Zipf discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.zipf(**kwargs),
IntegersPos)

def Yulesimon(symbol, **kwargs):
"""A Yule-Simon discrete random variable."""
return OrdinalDistribution(symbol, scipy.stats.yulesimon(**kwargs),
IntegersPos)

def Atomic(symbol, **kwargs):
"""A Yule-Simon discrete random variable."""
return Randint(symbol, low=kwargs['loc'], high=kwargs['loc']+1)
8 changes: 4 additions & 4 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def condition(self, event):
assert False, 'Unknown set type: %s' % (values,)

# ==============================================================================
# Numerical distribution.
# Continuous RealDistribution.

class NumericalDistribution(RealDistribution):
class ContinuousReal(RealDistribution):
"""Non-atomic distribution with a cumulative distribution function."""
def __init__(self, symbol, dist, support, conditioned=None):
super().__init__(symbol, dist, support, conditioned)
Expand Down Expand Up @@ -517,9 +517,9 @@ def logprob_interval(self, values):
return logdiffexp(logFu, logFl)

# ==============================================================================
# Ordinal distribution.
# Discrete RealDistribution.

class OrdinalDistribution(RealDistribution):
class DiscreteReal(RealDistribution):
"""Atomic distribution with a cumulative distribution function."""

def __init__(self, symbol, dist, support, conditioned=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_indian_gpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# See LICENSE.txt

from spn.math_util import allclose
from spn.numerical import Uniform
from spn.ordinal import Atomic
from spn.distributions import Uniform
from spn.distributions import Atomic
from spn.transforms import Identity

X = Identity('X')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mutual_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy

from spn.math_util import allclose
from spn.numerical import Norm
from spn.distributions import Norm
from spn.spn import ProductSPN
from spn.spn import SumSPN
from spn.transforms import Identity
Expand Down
12 changes: 6 additions & 6 deletions tests/test_numerical_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from spn.math_util import allclose
from spn.math_util import isinf_neg
from spn.math_util import logdiffexp
from spn.numerical import Gamma
from spn.numerical import Norm
from spn.spn import NumericalDistribution
from spn.distributions import Gamma
from spn.distributions import Norm
from spn.spn import ContinuousReal
from spn.spn import SumSPN
from spn.sym_util import Reals
from spn.transforms import Identity
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_numeric_distribution_normal():

for event in [(X<-10), (X>3)]:
spn_condition_c = spn.condition(event)
assert isinstance(spn_condition_c, NumericalDistribution)
assert isinstance(spn_condition_c, ContinuousReal)
assert isinf_neg(spn_condition_c.logprob((-1 < X) < 1))
samples = spn_condition_c.sample(100, rng)
assert all(s[X] in event.values for s in samples)
Expand All @@ -79,13 +79,13 @@ def test_numeric_distribution_gamma():

# Intentionally set Reals as the domain to exercise an important
# code path in dist.condition (Union case with zero weights).
spn = NumericalDistribution(X, scipy.stats.gamma(a=1, scale=1), Reals)
spn = ContinuousReal(X, scipy.stats.gamma(a=1, scale=1), Reals)
assert isinf_neg(spn.logprob((X << {1, 2}) | (X < 0)))
with pytest.raises(ValueError):
spn.condition((X << {1, 2}) | (X < 0))

spn_condition = spn.condition((X << {1,2} | (X <= 3)))
assert isinstance(spn_condition, NumericalDistribution)
assert isinstance(spn_condition, ContinuousReal)
assert spn_condition.conditioned
assert spn_condition.support == sympy.Interval(-sympy.oo, 3)
assert allclose(
Expand Down
16 changes: 8 additions & 8 deletions tests/test_ordinal_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
import numpy
import sympy

from spn.spn import OrdinalDistribution
from spn.spn import SumSPN
from spn.distributions import Poisson
from spn.distributions import Randint
from spn.math_util import allclose
from spn.math_util import logdiffexp
from spn.math_util import logsumexp
from spn.ordinal import Poisson
from spn.ordinal import Randint
from spn.spn import DiscreteReal
from spn.spn import SumSPN
from spn.transforms import Identity

rng = numpy.random.RandomState(1)

def test_ordinal_distribution_poisson():
def test_poisson():
X = Identity('X')
spn = Poisson(X, mu=5)

Expand Down Expand Up @@ -46,7 +46,7 @@ def test_ordinal_distribution_poisson():
# Unify X = 5 with left interval to make one distribution.
event = ((1 <= X) < 5) | ((3*X + 1) << {16})
spn_condition = spn.condition(event)
assert isinstance(spn_condition, OrdinalDistribution)
assert isinstance(spn_condition, DiscreteReal)
assert spn_condition.conditioned
assert spn_condition.xl == 1
assert spn_condition.xu == 5
Expand All @@ -56,7 +56,7 @@ def test_ordinal_distribution_poisson():

# Ignore X = 14/3 as a probability zero condition.
spn_condition = spn.condition(((1 <= X) < 5) | (3*X + 1) << {15})
assert isinstance(spn_condition, OrdinalDistribution)
assert isinstance(spn_condition, DiscreteReal)
assert spn_condition.conditioned
assert spn_condition.xl == 1
assert spn_condition.xu == 4
Expand All @@ -78,7 +78,7 @@ def test_ordinal_distribution_poisson():
with pytest.raises(ValueError):
spn.condition(((-3 <= X) < 0) | (3*X + 1) << {20})

def test_ordinal_randint():
def test_randint():
X = Identity('X')
spn = Randint(X, low=0, high=5)
assert spn.logprob(X < 5) == spn.logprob(X <= 4) == 0
Expand Down
10 changes: 5 additions & 5 deletions tests/test_parse_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

import pytest

from spn.spn import NumericalDistribution
from spn.spn import ContinuousReal
from spn.spn import PartialSumSPN
from spn.spn import ProductSPN
from spn.spn import SumSPN

from spn.numerical import Gamma
from spn.numerical import Norm
from spn.distributions import Gamma
from spn.distributions import Norm
from spn.transforms import Identity

from spn.math_util import allclose
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_sum_of_sums():
assert isinstance(a, SumSPN)
assert isinstance(a.children[0], SumSPN)
assert isinstance(a.children[1], SumSPN)
assert isinstance(a.children[2], NumericalDistribution)
assert isinstance(a.children[2], ContinuousReal)

# Wrong symbol.
with pytest.raises(ValueError):
Expand All @@ -127,4 +127,4 @@ def test_or_and():
a = (0.3*Norm(X) | 0.7*Gamma(X, a=1)) & Norm(Z)
assert isinstance(a, ProductSPN)
assert isinstance(a.children[0], SumSPN)
assert isinstance(a.children[1], NumericalDistribution)
assert isinstance(a.children[1], ContinuousReal)
4 changes: 2 additions & 2 deletions tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from spn.math_util import isinf_neg
from spn.math_util import logdiffexp
from spn.math_util import logsumexp
from spn.numerical import Gamma
from spn.numerical import Norm
from spn.distributions import Gamma
from spn.distributions import Norm
from spn.spn import ProductSPN
from spn.spn import SumSPN
from spn.transforms import ExpNat as Exp
Expand Down
14 changes: 7 additions & 7 deletions tests/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from spn.math_util import isinf_neg
from spn.math_util import logdiffexp
from spn.math_util import logsumexp
from spn.numerical import Gamma
from spn.numerical import Norm
from spn.distributions import Gamma
from spn.distributions import Norm
from spn.spn import ExposedSumSPN
from spn.spn import NominalDistribution
from spn.spn import NumericalDistribution
from spn.spn import ContinuousReal
from spn.spn import ProductSPN
from spn.spn import SumSPN
from spn.sym_util import NominalSet
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_sum_normal_gamma():
spn.sample_func(lambda Y: abs(X**3), 100, rng)

spn_condition = spn.condition(X < 0)
assert isinstance(spn_condition, NumericalDistribution)
assert isinstance(spn_condition, ContinuousReal)
assert spn_condition.conditioned
assert spn_condition.logprob(X < 0) == 0
samples = spn_condition.sample(100, rng)
Expand Down Expand Up @@ -91,10 +91,10 @@ def test_sum_normal_gamma_exposed():
spn_condition = spn.condition((W << {'1'}))
assert isinstance(spn_condition, ProductSPN)
assert isinstance(spn_condition.children[0], NominalDistribution)
assert isinstance(spn_condition.children[1], NumericalDistribution)
assert isinstance(spn_condition.children[1], ContinuousReal)
assert spn_condition.logprob(X < 5) == children[1].logprob(X < 5)

def test_sum_numerical_nominal():
def test_sum_normal_nominal():
X = Identity('X')
children = [
Norm(X, loc=0, scale=1),
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_sum_numerical_nominal():

with pytest.raises(ValueError):
spn_condition = spn.condition(X**2 < 9)
assert isinstance(spn_condition, NumericalDistribution)
assert isinstance(spn_condition, ContinuousReal)
assert spn_condition.support == sympy.Interval.open(-3, 3)

with pytest.raises(ValueError):
Expand Down

0 comments on commit fbf9249

Please sign in to comment.