Implementation of mixture distributions (pytorch#22742)
Summary:
Addressing issue pytorch#18125
This implements a mixture distribution where all components are from the same distribution family. Right now the implementation supports the `mean`, `variance`, `sample`, and `log_prob` methods.
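For illustration, a minimal usage sketch based on the API this PR adds (parameter values are arbitrary):

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

# Five equally weighted 1D Gaussians with random locations and scales
mix = Categorical(torch.ones(5))
comp = Normal(torch.randn(5), torch.rand(5))
gmm = MixtureSameFamily(mix, comp)

samples = gmm.sample((100,))   # shape: (100,)
log_p = gmm.log_prob(samples)  # mixture log-density via logsumexp, shape: (100,)
print(gmm.mean, gmm.variance)  # moments via the law of total variance
```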

cc: fritzo and neerajprad

- [x] add import and `__all__` string in `torch/distributions/__init__.py`
- [x] register docs in docs/source/distributions.rst

### Tests
(all tests live in test/test_distributions.py)
- [x] add an `Example(MixtureSameFamily, [...])` to the `EXAMPLES` list,
     populating `[...]` with three examples:
     one with `Normal`, one with `Categorical`, and one with `MultivariateNormal`
     (to exercise `FloatTensor`, `LongTensor`, and nontrivial `event_dim`)
- [x] add a `test_mixture_same_family_shape()` to `TestDistributions`. It would be good to test this with both `Normal` and `MultivariateNormal`
- [x] add a `test_mixture_same_family_log_prob()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_sample()` to `TestDistributions`.
- [x] add a `test_mixture_same_family_shape()` to `TestDistributionShapes`

### Triaged for follow-up PR?
- support batch shape
- implement `.expand()`
- implement `kl_divergence()` in torch/distributions/kl.py (see the sketch below)
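KL divergence between two mixtures has no closed form, so one hypothetical shape for that follow-up is a Monte Carlo estimator registered via `register_kl`; this is only a sketch under that assumption, not code in this PR:

```python
import torch
from torch.distributions import MixtureSameFamily
from torch.distributions.kl import register_kl

# Hypothetical sketch: estimate KL(p || q) = E_p[log p(x) - log q(x)]
# by sampling from p, since mixtures admit no closed-form KL.
@register_kl(MixtureSameFamily, MixtureSameFamily)
def _kl_mixture_mixture(p, q, n_samples=128):
    x = p.sample(torch.Size((n_samples,)))
    return (p.log_prob(x) - q.log_prob(x)).mean(0)
```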
Pull Request resolved: pytorch#22742

Differential Revision: D19899726

Pulled By: ezyang

fbshipit-source-id: 9c816e83a2ef104fe3ea3117c95680b51c7a2fa4
Nicki Skafte authored and facebook-github-bot committed Feb 14, 2020
1 parent 7dde91b commit 4bef344
Showing 4 changed files with 306 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/source/distributions.rst
@@ -185,6 +185,15 @@ Probability distributions - torch.distributions
:undoc-members:
:show-inheritance:

:hidden:`MixtureSameFamily`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.mixture_same_family
.. autoclass:: MixtureSameFamily
:members:
:undoc-members:
:show-inheritance:

:hidden:`Multinomial`
~~~~~~~~~~~~~~~~~~~~~~~

113 changes: 111 additions & 2 deletions test/test_distributions.py
@@ -46,7 +46,7 @@
HalfCauchy, HalfNormal,
Independent, Laplace, LogisticNormal,
LogNormal, LowRankMultivariateNormal,
Multinomial, MultivariateNormal,
MixtureSameFamily, Multinomial, MultivariateNormal,
NegativeBinomial, Normal, OneHotCategorical, Pareto,
Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
StudentT, TransformedDistribution, Uniform,
@@ -430,7 +430,20 @@ def is_all_nan(tensor):
'scale': torch.randn(5, 5).abs().requires_grad_(),
'concentration': torch.randn(1).abs().requires_grad_()
}
])
]),
Example(MixtureSameFamily, [
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
'component_distribution': Normal(torch.randn(5, requires_grad=True),
torch.rand(5, requires_grad=True)),
},
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
'component_distribution': MultivariateNormal(
loc=torch.randn(5, 2, requires_grad=True),
covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)),
},
])
]

BAD_EXAMPLES = [
@@ -1559,6 +1572,92 @@ def test_logisticnormal_sample(self):
'LogisticNormal(loc={}, scale={})'.format(mean_th, std_th),
multivariate=True)

def test_mixture_same_family_shape(self):
normal_case_1d = MixtureSameFamily(
Categorical(torch.rand(5)),
Normal(torch.randn(5), torch.rand(5)))
normal_case_1d_batch = MixtureSameFamily(
Categorical(torch.rand(3, 5)),
Normal(torch.randn(3, 5), torch.rand(3, 5)))
normal_case_1d_multi_batch = MixtureSameFamily(
Categorical(torch.rand(4, 3, 5)),
Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)))
normal_case_2d = MixtureSameFamily(
Categorical(torch.rand(5)),
Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1))
normal_case_2d_batch = MixtureSameFamily(
Categorical(torch.rand(3, 5)),
Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1))
normal_case_2d_multi_batch = MixtureSameFamily(
Categorical(torch.rand(4, 3, 5)),
Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1))

self.assertEqual(normal_case_1d.sample().size(), ())
self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
self.assertEqual(normal_case_1d_multi_batch.sample().size(), (4, 3))
self.assertEqual(normal_case_1d_multi_batch.sample((2,)).size(), (2, 4, 3))
self.assertEqual(normal_case_1d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3))

self.assertEqual(normal_case_2d.sample().size(), (2,))
self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
self.assertEqual(normal_case_2d_multi_batch.sample().size(), (4, 3, 2))
self.assertEqual(normal_case_2d_multi_batch.sample((2,)).size(), (2, 4, 3, 2))
self.assertEqual(normal_case_2d_multi_batch.sample((2, 7)).size(), (2, 7, 4, 3, 2))

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_log_prob(self):
probs = torch.rand(5, 5).softmax(dim=-1)
loc = torch.randn(5, 5)
scale = torch.rand(5, 5)

def ref_log_prob(idx, x, log_prob):
p = probs[idx].numpy()
m = loc[idx].numpy()
s = scale[idx].numpy()
mix = scipy.stats.multinomial(1, p)
comp = scipy.stats.norm(m, s)
expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
self.assertAlmostEqual(log_prob, expected, places=3)

self._check_log_prob(
MixtureSameFamily(Categorical(probs=probs),
Normal(loc, scale)), ref_log_prob)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_mixture_same_family_sample(self):
probs = torch.rand(5).softmax(dim=-1)
loc = torch.randn(5)
scale = torch.rand(5)

class ScipyMixtureNormal(object):
def __init__(self, probs, mu, std):
self.probs = probs
self.mu = mu
self.std = std

def rvs(self, n_sample):
comp_samples = [scipy.stats.norm(m, s).rvs(n_sample) for m, s
in zip(self.mu, self.std)]
mix_samples = scipy.stats.multinomial(1, self.probs).rvs(n_sample)
samples = []
for i in range(n_sample):
samples.append(comp_samples[mix_samples[i].argmax()][i])
return np.asarray(samples)

self._check_sampler_sampler(
MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
'''MixtureSameFamily(Categorical(probs={}),
Normal(loc={}, scale={}))'''.format(probs, loc, scale))

def test_normal(self):
loc = torch.randn(5, 5, requires_grad=True)
scale = torch.randn(5, 5).abs().requires_grad_()
@@ -2945,6 +3044,16 @@ def test_dirichlet_shape(self):
simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))

def test_mixture_same_family_shape(self):
dist = MixtureSameFamily(Categorical(torch.rand(5)),
Normal(torch.randn(5), torch.rand(5)))
self.assertEqual(dist._batch_shape, torch.Size())
self.assertEqual(dist._event_shape, torch.Size())
self.assertEqual(dist.sample().size(), torch.Size())
self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

def test_gamma_shape_scalar_params(self):
gamma = Gamma(1, 1)
self.assertEqual(gamma._batch_shape, torch.Size())
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -94,6 +94,7 @@
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
@@ -131,6 +132,7 @@
'LogNormal',
'LogisticNormal',
'LowRankMultivariateNormal',
'MixtureSameFamily',
'Multinomial',
'MultivariateNormal',
'NegativeBinomial',
184 changes: 184 additions & 0 deletions torch/distributions/mixture_same_family.py
@@ -0,0 +1,184 @@
import torch
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints


class MixtureSameFamily(Distribution):
r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations
    of the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.
    Examples::

        # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct a batch of 3 Gaussian Mixture Models in 2D each
        # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)
    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting a component.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Rightmost batch dimension indexes the components.
    """
arg_constraints = {}
has_rsample = False

def __init__(self,
mixture_distribution,
component_distribution,
validate_args=None):
self._mixture_distribution = mixture_distribution
self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError("The mixture distribution needs to be an "
                             "instance of torch.distributions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError("The component distribution needs to be an "
                             "instance of torch.distributions.Distribution")

        # Check that the batch shapes are compatible
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
                                 "compatible with `component_distribution."
                                 "batch_shape` ({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("`mixture_distribution` components ({0}) do not"
                             " equal `component_distribution.batch_shape[-1]`"
                             " ({1})".format(km, kc))
self._num_component = km

event_shape = self._component_distribution.event_shape
self._event_ndims = len(event_shape)
super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
event_shape=event_shape,
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
batch_shape_comp = batch_shape + (self._num_component,)
new = self._get_checked_instance(MixtureSameFamily, _instance)
new._component_distribution = \
self._component_distribution.expand(batch_shape_comp)
new._mixture_distribution = \
self._mixture_distribution.expand(batch_shape)
new._num_component = self._num_component
new._event_ndims = self._event_ndims
event_shape = new._component_distribution.event_shape
super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
event_shape=event_shape,
validate_args=False)
new._validate_args = self._validate_args
return new

@constraints.dependent_property
def support(self):
# FIXME this may have the wrong shape when support contains batched
# parameters
return self._component_distribution.support

@property
def mixture_distribution(self):
return self._mixture_distribution

@property
def component_distribution(self):
return self._component_distribution

@property
def mean(self):
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
return torch.sum(probs * self.component_distribution.mean,
dim=-1 - self._event_ndims) # [B, E]

@property
def variance(self):
# Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
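        # mean_cond_var below is E[Var(Y|X)], the mixture-weighted average of
        # component variances; var_cond_mean is Var(E[Y|X]), the weighted
        # spread of component means around the overall mixture mean.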
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
mean_cond_var = torch.sum(probs * self.component_distribution.variance,
dim=-1 - self._event_ndims)
var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
self._pad(self.mean)).pow(2.0),
dim=-1 - self._event_ndims)
return mean_cond_var + var_cond_mean

def log_prob(self, x):
x = self._pad(x)
log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
dim=-1) # [B, k]
return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B]

def sample(self, sample_shape=torch.Size()):
with torch.no_grad():
sample_len = len(sample_shape)
batch_len = len(self.batch_shape)
gather_dim = sample_len + batch_len
es = self.event_shape

# mixture samples [n, B]
mix_sample = self.mixture_distribution.sample(sample_shape)
mix_shape = mix_sample.shape

# component samples [n, B, k, E]
comp_samples = self.component_distribution.sample(sample_shape)

# Gather along the k dimension
mix_sample_r = mix_sample.reshape(
mix_shape + torch.Size([1] * (len(es) + 1)))
mix_sample_r = mix_sample_r.repeat(
torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es)
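            # mix_sample_r: [n, B, 1, E]; gather then selects, per batch
            # element and event coordinate, the draw of the chosen component.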

samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
return samples.squeeze(gather_dim)

def _pad(self, x):
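        # Insert a singleton dimension at the component axis so that x
        # broadcasts against the k mixture components (used by log_prob
        # and variance).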
return x.unsqueeze(-1 - self._event_ndims)

def _pad_mixture_dimensions(self, x):
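        # Reshape mixture probabilities [..., k] so they broadcast against
        # component statistics of shape [..., k, E] (E = event dims).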
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
xs = x.shape
x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
xs[-1:] + torch.Size(self._event_ndims * [1]))
return x

def __repr__(self):
args_string = '\n {},\n {}'.format(self.mixture_distribution,
self.component_distribution)
return 'MixtureSameFamily' + '(' + args_string + ')'
