Skip to content
This repository was archived by the owner on Jun 14, 2024. It is now read-only.

Changing variable broadcasting for factors #133

Merged
merged 12 commits into from
Nov 23, 2018
8 changes: 7 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ jobs:
include:
- os: osx
language: generic
env: PYTHON=3.6.6
env: PYTHON=3.4.9
- os: osx
language: generic
env: PYTHON=3.5.6
- os: osx
language: generic
env: PYTHON=3.6.7

- stage: release
python: '3.6' # Official supported Python dist.
Expand Down
26 changes: 19 additions & 7 deletions docs/design_documents/model_definition.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ In a probabilistic model, random variables relate to each other through
probabilistic distributions.

During model definition, the typical interface to generate a 2 dimensional
random variable *x* from a zero mean unit variance Gaussian distribution looks
like:
random variable ```m.x``` from a zero mean unit variance Gaussian distribution
looks like:

```python
m.x = Normal.define_variable(mean=0, variance=1, shape=(2,))
from mxnet.ndarray import array

m.x = Normal.define_variable(mean=array([0, 0]), variance=array([1, 1]), shape=(2,))
```

The two dimensions are
Expand All @@ -70,9 +72,11 @@ distribution. The parameters or shape of a distribution can also be variables, f
example:

```python
from mxnet.ndarray import array

m.mean = Variable(shape=(2,))
m.y_shape = Variable()
m.y = Normal.define_variable(mean=m.mean, variance=1, shape=m.y_shape)
m.y = Normal.define_variable(mean=m.mean, variance=array([1, 1]), shape=(m.y_shape,))
```

MXFusion also allows users to specify a prior distribution over pre-existing
Expand All @@ -83,12 +87,20 @@ distribution looks like:

```Python
m.x = Variable(shape=(2,))
m.x.set_prior(Gaussian(mean=0, variance=1))
m.x.set_prior(Gaussian(mean=array([0, 0]), variance=array([1, 1]))
```

The above code defines a variable *x* and sets the prior distribution of each
dimension of *x* to be a scalar unit Gaussian distribution.
The above code defines a variable ```m.x``` and sets the prior distribution of
each dimension of ```m.x``` to be a scalar unit Gaussian distribution.

In many cases, we apply the same prior distribution to multiple dimensions. In the above example, we simply want to set the individual dimensions of ```m.x``` to follow a zero-mean and unit-variance Gaussian. A more elegant way to define the above prior distribution is to make use of the broadcasting rule of multi-dimensional arrays:
```Python
from mxfusion.components.functions.operators import broadcast_to

m.x.set_prior(Gaussian(mean=broadcast_to(array([0]), m.x.shape),
variance=broadcast_to(array([1]), m.x.shape)))
```
Note that the shape of ```m.x``` may not always be available. In those cases, it is better to explicitly define the shape to be broadcasted to.

Because Models are FactorGraphs, it is common to want to know what ModelComponents come before or after a particular component in the graph. These are accessed through the ModelComponent properties ```successors``` and ```predecessors```.

Expand Down
221 changes: 60 additions & 161 deletions examples/notebooks/bnn_classification.ipynb

Large diffs are not rendered by default.

107 changes: 63 additions & 44 deletions examples/notebooks/bnn_regression.ipynb

Large diffs are not rendered by default.

43 changes: 25 additions & 18 deletions examples/notebooks/ppca_tutorial.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mxfusion/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# ==============================================================================


__version__ = '0.2.2'
__version__ = '0.3.0'
88 changes: 6 additions & 82 deletions mxfusion/components/distributions/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,84 +13,8 @@
# ==============================================================================


from ..variables import Variable
from .univariate import UnivariateDistribution
from .distribution import LogPDFDecorator, DrawSamplesDecorator
from ...util.customop import broadcast_to_w_samples
from ..variables import get_num_samples, array_has_samples
from ...common.config import get_default_MXNet_mode
from ...common.exceptions import InferenceError


class BernoulliLogPDFDecorator(LogPDFDecorator):

def _wrap_log_pdf_with_broadcast(self, func):
def log_pdf_broadcast(self, F, **kw):
"""
Computes the logarithm of the probability density/mass function (PDF/PMF) of the distribution.

:param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
:param kw: the dict of input and output variables of the distribution
:type kw: {name: MXNet NDArray or MXNet Symbol}
:returns: log pdf of the distribution
:rtypes: MXNet NDArray or MXNet Symbol
"""
variables = {name: kw[name] for name, _ in self.inputs}
variables['random_variable'] = kw['random_variable']
rv_shape = variables['random_variable'].shape[1:]

n_samples = max([get_num_samples(F, v) for v in variables.values()])
full_shape = (n_samples,) + rv_shape

variables = {
name: broadcast_to_w_samples(F, v, full_shape[:-1]+(v.shape[-1],)) for name, v in variables.items()}
res = func(self, F=F, **variables)
return res
return log_pdf_broadcast


class BernoulliDrawSamplesDecorator(DrawSamplesDecorator):

def _wrap_draw_samples_with_broadcast(self, func):
def draw_samples_broadcast(self, F, rv_shape, num_samples=1,
always_return_tuple=False, **kw):
"""
Draw a number of samples from the distribution.

:param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray)
:param rv_shape: the shape of each sample
:type rv_shape: tuple
:param num_samples: the number of drawn samples (default: one)
:int n_samples: int
:param always_return_tuple: Whether return a tuple even if there is only one variables in outputs.
:type always_return_tuple: boolean
:param kw: the dict of input variables of the distribution
:type kw: {name: MXNet NDArray or MXNet Symbol}
:returns: a set samples of the distribution
:rtypes: MXNet NDArray or MXNet Symbol or [MXNet NDArray or MXNet Symbol]
"""
rv_shape = list(rv_shape.values())[0]
variables = {name: kw[name] for name, _ in self.inputs}

is_samples = any([array_has_samples(F, v) for v in variables.values()])
if is_samples:
num_samples_inferred = max([get_num_samples(F, v) for v in variables.values()])
if num_samples_inferred != num_samples:
raise InferenceError("The number of samples in the n_samples argument of draw_samples of "
"Bernoulli has to be the same as the number of samples given "
"to the inputs. n_samples: {} the inferred number of samples from "
"inputs: {}.".format(num_samples, num_samples_inferred))
full_shape = (num_samples,) + rv_shape

variables = {
name: broadcast_to_w_samples(F, v, full_shape[:-1]+(v.shape[-1],)) for name, v in
variables.items()}
res = func(self, F=F, rv_shape=rv_shape, num_samples=num_samples,
**variables)
if always_return_tuple:
res = (res,)
return res
return draw_samples_broadcast


class Bernoulli(UnivariateDistribution):
Expand Down Expand Up @@ -134,8 +58,7 @@ def replicate_self(self, attribute_map=None):
replicant = super(Bernoulli, self).replicate_self(attribute_map=attribute_map)
return replicant

@BernoulliLogPDFDecorator()
def log_pdf(self, prob_true, random_variable, F=None):
def log_pdf_impl(self, prob_true, random_variable, F=None):
"""
Computes the logarithm of probabilistic mass function of the Bernoulli distribution.

Expand All @@ -153,8 +76,7 @@ def log_pdf(self, prob_true, random_variable, F=None):
logL = logL * self.log_pdf_scaling
return logL

@BernoulliDrawSamplesDecorator()
def draw_samples(self, prob_true, rv_shape, num_samples=1, F=None):
def draw_samples_impl(self, prob_true, rv_shape, num_samples=1, F=None):
"""
Draw a number of samples from the Bernoulli distribution.

Expand All @@ -169,7 +91,8 @@ def draw_samples(self, prob_true, rv_shape, num_samples=1, F=None):
:rtypes: MXNet NDArray or MXNet Symbol
"""
F = get_default_MXNet_mode() if F is None else F
return self._rand_gen.sample_bernoulli(prob_true, shape=(num_samples,) + rv_shape, dtype=self.dtype, F=F)
return self._rand_gen.sample_bernoulli(
prob_true, shape=(num_samples,) + rv_shape, dtype=self.dtype, F=F)

@staticmethod
def define_variable(prob_true, shape=None, rand_gen=None, dtype=None, ctx=None):
Expand All @@ -189,6 +112,7 @@ def define_variable(prob_true, shape=None, rand_gen=None, dtype=None, ctx=None):
:returns: RandomVariable drawn from the Bernoulli distribution.
:rtypes: Variable
"""
bernoulli = Bernoulli(prob_true=prob_true, rand_gen=rand_gen, dtype=dtype, ctx=ctx)
bernoulli = Bernoulli(prob_true=prob_true, rand_gen=rand_gen,
dtype=dtype, ctx=ctx)
bernoulli._generate_outputs(shape=shape)
return bernoulli.random_variable
61 changes: 32 additions & 29 deletions mxfusion/components/distributions/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@


from ...common.config import get_default_MXNet_mode
from ..variables import Variable
from .univariate import UnivariateDistribution, UnivariateLogPDFDecorator, UnivariateDrawSamplesDecorator
from .univariate import UnivariateDistribution


class Beta(UnivariateDistribution):
Expand All @@ -24,35 +23,34 @@ class Beta(UnivariateDistribution):
array of random variables. In case of an array of random variables, a and b are broadcasted to the
shape of the output random variable (array).

:param a: a parameter (alpha) of the beta distribution.
:type a: Variable
:param b: b parameter (beta) of the beta distribution.
:type b: Variable
:param alpha: a parameter (alpha) of the beta distribution.
:type alpha: Variable
:param beta: b parameter (beta) of the beta distribution.
:type beta: Variable
:param rand_gen: the random generator (default: MXNetRandomGenerator).
:type rand_gen: RandomGenerator
:param dtype: the data type for float point numbers.
:type dtype: numpy.float32 or numpy.float64
:param ctx: the mxnet context (default: None/current context).
:type ctx: None or mxnet.cpu or mxnet.gpu
"""
def __init__(self, a, b, rand_gen=None, dtype=None, ctx=None):
inputs = [('a', a), ('b', b)]
def __init__(self, alpha, beta, rand_gen=None, dtype=None, ctx=None):
inputs = [('alpha', alpha), ('beta', beta)]
input_names = [k for k, _ in inputs]
output_names = ['random_variable']
super(Beta, self).__init__(inputs=inputs, outputs=None,
input_names=input_names,
output_names=output_names,
rand_gen=rand_gen, dtype=dtype, ctx=ctx)

@UnivariateLogPDFDecorator()
def log_pdf(self, a, b, random_variable, F=None):
def log_pdf_impl(self, alpha, beta, random_variable, F=None):
"""
Computes the logarithm of the probability density function (PDF) of the beta distribution.

:param a: the a parameter (alpha) of the beta distribution.
:type a: MXNet NDArray or MXNet Symbol
:param b: the b parameter (beta) of the beta distributions.
:type b: MXNet NDArray or MXNet Symbol
:param alpha: the a parameter (alpha) of the beta distribution.
:type alpha: MXNet NDArray or MXNet Symbol
:param beta: the b parameter (beta) of the beta distributions.
:type beta: MXNet NDArray or MXNet Symbol
:param random_variable: the random variable of the beta distribution.
:type random_variable: MXNet NDArray or MXNet Symbol
:param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
Expand All @@ -63,23 +61,23 @@ def log_pdf(self, a, b, random_variable, F=None):

log_x = F.log(random_variable)
log_1_minus_x = F.log(1 - random_variable)
log_beta_ab = F.gammaln(a) + F.gammaln(b) - F.gammaln(a + b)
log_beta_ab = F.gammaln(alpha) + F.gammaln(beta) - \
F.gammaln(alpha + beta)

log_likelihood = F.broadcast_add((a - 1) * log_x, ((b - 1) * log_1_minus_x)) - log_beta_ab
log_likelihood = F.broadcast_add((alpha - 1) * log_x, ((beta - 1) * log_1_minus_x)) - log_beta_ab
return log_likelihood

@UnivariateDrawSamplesDecorator()
def draw_samples(self, a, b, rv_shape, num_samples=1, F=None):
def draw_samples_impl(self, alpha, beta, rv_shape, num_samples=1, F=None):
"""
Draw samples from the beta distribution.

If X and Y are independent, with $X \sim \Gamma(\alpha, \theta)$ and $Y \sim \Gamma(\beta, \theta)$ then
$\frac {X}{X+Y}}\sim \mathrm {B} (\alpha ,\beta ).}$

:param a: the a parameter (alpha) of the beta distribution.
:type a: MXNet NDArray or MXNet Symbol
:param b: the b parameter (beta) of the beta distributions.
:type b: MXNet NDArray or MXNet Symbol
:param alpha: the a parameter (alpha) of the beta distribution.
:type alpha: MXNet NDArray or MXNet Symbol
:param beta: the b parameter (beta) of the beta distributions.
:type beta: MXNet NDArray or MXNet Symbol
:param rv_shape: the shape of each sample.
:type rv_shape: tuple
:param num_samples: the number of drawn samples (default: one).
Expand All @@ -90,26 +88,30 @@ def draw_samples(self, a, b, rv_shape, num_samples=1, F=None):
"""
F = get_default_MXNet_mode() if F is None else F

if a.shape != (num_samples, ) + rv_shape:
if alpha.shape != (num_samples, ) + rv_shape:
raise ValueError("Shape mismatch between inputs {} and random variable {}".format(
a.shape, (num_samples, ) + rv_shape))
alpha.shape, (num_samples, ) + rv_shape))

# Note output shape is determined by input dimensions
out_shape = () # (num_samples,) + rv_shape

ones = F.ones_like(a)
ones = F.ones_like(alpha)

# Sample X from Gamma(a, 1)
x = self._rand_gen.sample_gamma(alpha=a, beta=ones, shape=out_shape, dtype=self.dtype, ctx=self.ctx, F=F)
x = self._rand_gen.sample_gamma(
alpha=alpha, beta=ones, shape=out_shape, dtype=self.dtype,
ctx=self.ctx, F=F)

# Sample Y from Gamma(b, 1)
y = self._rand_gen.sample_gamma(alpha=b, beta=ones, shape=out_shape, dtype=self.dtype, ctx=self.ctx, F=F)
y = self._rand_gen.sample_gamma(
alpha=beta, beta=ones, shape=out_shape, dtype=self.dtype,
ctx=self.ctx, F=F)

# Return X / (X + Y)
return F.broadcast_div(x, F.broadcast_add(x, y))

@staticmethod
def define_variable(a=1., b=1., shape=None, rand_gen=None,
def define_variable(alpha=1., beta=1., shape=None, rand_gen=None,
dtype=None, ctx=None):
"""
Creates and returns a random variable drawn from a beta distribution.
Expand All @@ -127,6 +129,7 @@ def define_variable(a=1., b=1., shape=None, rand_gen=None,
:returns: the random variables drawn from the beta distribution.
:rtypes: Variable
"""
beta = Beta(a=a, b=b, rand_gen=rand_gen, dtype=dtype, ctx=ctx)
beta = Beta(alpha=alpha, beta=beta, rand_gen=rand_gen, dtype=dtype,
ctx=ctx)
beta._generate_outputs(shape=shape)
return beta.random_variable
Loading