forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of mixture distributions (pytorch#22742)
Summary: Addressing issue pytorch#18125 This implements a mixture distribution, where all components are from the same distribution family. Right now the implementation supports the ```mean, variance, sample, log_prob``` methods. cc: fritzo and neerajprad - [x] add import and `__all__` string in `torch/distributions/__init__.py` - [x] register docs in docs/source/distributions.rst ### Tests (all tests live in tests/distributions.py) - [x] add an `Example(MixtureSameFamily, [...])` to the `EXAMPLES` list, populating `[...]` with three examples: one with `Normal`, one with `Categorical`, and one with `MultivariateNormal` (to exercise, `FloatTensor`, `LongTensor`, and nontrivial `event_dim`) - [x] add a `test_mixture_same_family_shape()` to `TestDistributions`. It would be good to test this with both `Normal` and `MultivariateNormal` - [x] add a `test_mixture_same_family_log_prob()` to `TestDistributions`. - [x] add a `test_mixture_same_family_sample()` to `TestDistributions`. - [x] add a `test_mixture_same_family_shape()` to `TestDistributionShapes` ### Triaged for follow-up PR? - support batch shape - implement `.expand()` - implement `kl_divergence()` in torch/distributions/kl.py Pull Request resolved: pytorch#22742 Differential Revision: D19899726 Pulled By: ezyang fbshipit-source-id: 9c816e83a2ef104fe3ea3117c95680b51c7a2fa4
- Loading branch information
1 parent
7dde91b
commit 4bef344
Showing
4 changed files
with
306 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import torch | ||
from torch.distributions.distribution import Distribution | ||
from torch.distributions import Categorical | ||
from torch.distributions import constraints | ||
|
||
|
||
class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations
    of the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
                torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        # Construct a batch of 3 Gaussian Mixture Models in 2D each
        # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
                torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """
    arg_constraints = {}
    has_rsample = False  # gather over a sampled index is not reparameterizable

    def __init__(self,
                 mixture_distribution,
                 component_distribution,
                 validate_args=None):
        """Validate the two sub-distributions and derive batch/event shapes.

        Raises:
            ValueError: if `mixture_distribution` is not a `Categorical`,
                `component_distribution` is not a `Distribution`, the batch
                shapes are not broadcast-compatible, or the number of
                categories does not match the component count.
        """
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            # FIX: corrected misspelled "distribtutions" and stray spacing
            # in the original error message.
            raise ValueError("The Mixture distribution needs to be an "
                             "instance of torch.distributions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            # FIX: "need to be" -> "needs to be" (grammar).
            raise ValueError("The Component distribution needs to be an "
                             "instance of torch.distributions.Distribution")

        # The mixture's batch shape must broadcast against the component's
        # batch shape with the rightmost (component-index) dim stripped.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
                                 "compatible with `component_distribution."
                                 "batch_shape`({1})".format(mdbs, cdbs))

        # Both arguments must agree on the number of mixture components `k`.
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("`mixture_distribution component` ({0}) does not"
                             " equal `component_distribution.batch_shape[-1]`"
                             " ({1})".format(km, kc))
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
                                                event_shape=event_shape,
                                                validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new instance with batch dims expanded to `batch_shape`.

        The component distribution keeps its trailing component dim `k`.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = \
            self._component_distribution.expand(batch_shape_comp)
        new._mixture_distribution = \
            self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
                                               event_shape=event_shape,
                                               validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        # E[Y] = sum_k pi_k * mu_k, reduced over the component dim.
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(probs * self.component_distribution.mean,
                         dim=-1 - self._event_ndims)  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(probs * self.component_distribution.variance,
                                  dim=-1 - self._event_ndims)
        var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
                                           self._pad(self.mean)).pow(2.0),
                                  dim=-1 - self._event_ndims)
        return mean_cond_var + var_cond_mean

    def log_prob(self, x):
        """Return log p(x) = logsumexp_k(log pi_k + log p_k(x))."""
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
                                         dim=-1)  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples by picking a component index, then gathering the
        matching component sample along the component dim."""
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            # Component dim sits right after sample and batch dims.
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Expand indices over the event dims so gather selects whole
            # event slices along the k dimension.
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1)))
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es)

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        # Insert a singleton component dim so x broadcasts against
        # component-indexed tensors of shape [..., k] + event_shape.
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        # Reshape mixture weights of shape mix_batch + [k] so they broadcast
        # against component tensors of shape batch + [k] + event_shape.
        # FIX: count batch *dimensions* with len(), not numel(). numel() is
        # the product of the sizes, so any batch dim of size != 1 produced a
        # wrong pad width and broke mean/variance for broadcast batch shapes.
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
                      xs[-1:] + torch.Size(self._event_ndims * [1]))
        return x

    def __repr__(self):
        args_string = '\n {},\n {}'.format(self.mixture_distribution,
                                           self.component_distribution)
        return 'MixtureSameFamily' + '(' + args_string + ')'