forked from pyro-ppl/pyro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add GroupedNormalNormal distribution (pyro-ppl#3163)
- Loading branch information
1 parent
77a67ff
commit 3422c3a
Showing
4 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import math | ||
|
||
import torch | ||
from torch.distributions.utils import broadcast_all | ||
|
||
from pyro.distributions import Normal, constraints | ||
from pyro.distributions.torch_distribution import TorchDistribution | ||
|
||
LOG_ROOT_TWO_PI = 0.5 * math.log(2.0 * math.pi) | ||
|
||
|
||
class GroupedNormalNormal(TorchDistribution):
    r"""
    This likelihood, which operates on groups of real-valued scalar observations, is obtained by
    integrating out a latent mean for each group. Both the prior on each latent mean as well as the
    observation likelihood for each data point are univariate Normal distributions.
    The prior means are controlled by `prior_loc` and `prior_scale`. The observation noise of the
    Normal likelihood is controlled by `obs_scale`, which is allowed to vary from observation to
    observation. The tensor of indices `group_idx` connects each observation to one of the groups
    specified by `prior_loc` and `prior_scale`.

    See e.g. Eqn. (55) in ref. [1] for relevant expressions in a simpler case with scalar `obs_scale`.

    Example:

    >>> num_groups = 3
    >>> num_data = 4
    >>> prior_loc = torch.randn(num_groups)
    >>> prior_scale = torch.rand(num_groups)
    >>> obs_scale = torch.rand(num_data)
    >>> group_idx = torch.tensor([1, 0, 2, 1]).long()
    >>> values = torch.randn(num_data)
    >>> gnn = GroupedNormalNormal(prior_loc, prior_scale, obs_scale, group_idx)
    >>> assert gnn.log_prob(values).shape == ()

    References:
    [1] "Conjugate Bayesian analysis of the Gaussian distribution," Kevin P. Murphy.

    :param torch.Tensor prior_loc: Tensor of shape `(num_groups,)` specifying the prior mean of the latent
        of each group. A 0-dimensional tensor is broadcast against `prior_scale`.
    :param torch.Tensor prior_scale: Tensor of shape `(num_groups,)` specifying the prior scale of the latent
        of each group. A 0-dimensional tensor is broadcast against `prior_loc`.
    :param torch.Tensor obs_scale: Tensor of shape `(num_data,)` specifying the scale of the observation noise
        of each observation. A 0-dimensional tensor is broadcast against `group_idx`.
    :param torch.LongTensor group_idx: Tensor of indices of shape `(num_data,)` linking each observation to one
        of the `num_groups` groups that are specified in `prior_loc` and `prior_scale`.
    :raises ValueError: If any argument has an unsupported number of dimensions, if `group_idx`
        is not of dtype `torch.long`, or if any index falls outside `[0, num_groups - 1]`.
    """

    arg_constraints = {
        "prior_loc": constraints.real,
        "prior_scale": constraints.positive,
        "obs_scale": constraints.positive,
    }
    support = constraints.real

    def __init__(
        self, prior_loc, prior_scale, obs_scale, group_idx, validate_args=None
    ):
        if prior_loc.ndim not in [0, 1] or prior_scale.ndim not in [0, 1]:
            raise ValueError(
                "prior_loc and prior_scale must be broadcastable to 1D tensors of the same shape."
            )

        if obs_scale.ndim not in [0, 1]:
            raise ValueError(
                "obs_scale must be broadcastable to a 1-dimensional tensor."
            )

        # Check the dtype rather than `isinstance(..., torch.LongTensor)`: the latter only
        # matches CPU tensors and would reject valid int64 indices living on another device.
        if group_idx.ndim != 1 or group_idx.dtype != torch.long:
            raise ValueError("group_idx must be a 1-dimensional tensor of indices.")

        prior_loc, prior_scale = broadcast_all(prior_loc, prior_scale)
        # If both prior arguments are 0-dimensional the broadcast result is still a scalar;
        # flatten so that a single group is represented as a tensor of shape (1,).
        prior_loc = prior_loc.reshape(-1)
        prior_scale = prior_scale.reshape(-1)
        obs_scale, group_idx = broadcast_all(obs_scale, group_idx)

        self.prior_loc = prior_loc
        self.prior_scale = prior_scale
        self.obs_scale = obs_scale
        self.group_idx = group_idx
        # The dimensionality checks above guarantee 1-D parameters, so the batch shape
        # is always trivial.
        batch_shape = torch.Size()

        self.num_groups = prior_loc.size(0)
        # `min()`/`max()` are undefined on empty tensors, so only range-check when there
        # is at least one observation.
        if group_idx.numel() > 0 and (
            group_idx.min().item() < 0 or group_idx.max().item() >= self.num_groups
        ):
            raise ValueError(
                "Each index in group_idx must be an integer in the inclusive range [0, prior_loc.size(0) - 1]."
            )

        # Number of observations assigned to each group (shape `(num_groups,)`).
        self.num_data_per_batch = prior_loc.new_zeros(self.num_groups).scatter_add(
            0, self.group_idx, prior_loc.new_ones(self.group_idx.shape)
        )
        super().__init__(batch_shape, validate_args=validate_args)

    def _group_sum(self, src):
        """
        Sum the per-observation tensor `src` within each group.

        :param torch.Tensor src: Tensor with the same shape as `group_idx`.
        :returns: Tensor of shape `(num_groups,)` whose `g`-th entry is the sum of the
            entries of `src` whose index in `group_idx` equals `g`.
        """
        return src.new_zeros(self.num_groups).scatter_add(0, self.group_idx, src)

    def expand(self, batch_shape, _instance=None):
        """Not supported: this distribution only has a trivial batch shape."""
        raise NotImplementedError

    def sample(self, sample_shape=()):
        """Not supported: sampling is not implemented for this distribution."""
        raise NotImplementedError

    def get_posterior(self, value):
        """
        Get a `pyro.distributions.Normal` distribution that encodes the posterior distribution
        over the vector of latents specified by `prior_loc` and `prior_scale` conditioned on the
        observed data specified by `value`.

        :param torch.Tensor value: Observations, with the same shape as `group_idx`.
        :returns: A `Normal` distribution whose `loc` and `scale` have shape `(num_groups,)`.
        :raises ValueError: If `value` does not have the same shape as `group_idx`.
        """
        if value.shape != self.group_idx.shape:
            raise ValueError(
                "GroupedNormalNormal.get_posterior only supports values that have the same shape as group_idx."
            )

        obs_scale_sq_inv = self.obs_scale.pow(-2)
        prior_scale_sq_inv = self.prior_scale.pow(-2)

        # Posterior precision = prior precision + sum of observation precisions in the group.
        precision = prior_scale_sq_inv + self._group_sum(obs_scale_sq_inv)
        # Precision-weighted sum of the observations in each group.
        scaled_value_sum = self._group_sum(value * obs_scale_sq_inv)

        loc = (scaled_value_sum + self.prior_loc * prior_scale_sq_inv) / precision
        scale = precision.rsqrt()

        return Normal(loc=loc, scale=scale)

    def log_prob(self, value):
        """
        Compute the joint log marginal likelihood of all observations, with the latent
        mean of every group integrated out.

        :param torch.Tensor value: Observations, with the same shape as `group_idx`.
        :returns: A scalar tensor holding the total log probability.
        :raises ValueError: If `value` does not have the same shape as `group_idx`.
        """
        if self._validate_args:
            self._validate_sample(value)

        group_idx = self.group_idx

        if value.shape != group_idx.shape:
            raise ValueError(
                "GroupedNormalNormal.log_prob only supports values that have the same shape as group_idx."
            )

        prior_scale_sq = self.prior_scale.pow(2.0)
        obs_scale_sq_inv = self.obs_scale.pow(-2)
        obs_scale_sq_inv_sum = self._group_sum(obs_scale_sq_inv)

        # Per group: prior variance times the summed observation precisions.
        scale_ratio = prior_scale_sq * obs_scale_sq_inv_sum
        delta = value - self.prior_loc[group_idx]
        scaled_delta = delta * obs_scale_sq_inv
        scaled_delta_sum = self._group_sum(scaled_delta)

        # Normalization constant: one log(sqrt(2 * pi)) per observation.
        result1 = -(self.num_data_per_batch * LOG_ROOT_TWO_PI).sum()
        # Log-determinant of each group's marginal covariance (diagonal noise plus a
        # rank-one term from the shared latent mean).
        result2 = -0.5 * torch.log1p(scale_ratio).sum() - self.obs_scale.log().sum()
        # Quadratic form under the diagonal observation noise alone ...
        result3 = -0.5 * torch.dot(delta, scaled_delta)
        # ... plus the rank-one correction that accounts for the shared latent mean.
        numerator = prior_scale_sq * scaled_delta_sum.pow(2)
        result4 = 0.5 * (numerator / (1.0 + scale_ratio)).sum()

        return result1 + result2 + result3 + result4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import math | ||
|
||
import torch | ||
|
||
from pyro.distributions import GroupedNormalNormal, Normal | ||
from tests.common import assert_close | ||
|
||
|
||
def test_grouped_normal_normal(num_groups=3, num_samples=10**5):
    """
    Check the output shapes of `GroupedNormalNormal` and compare its `log_prob`
    against a Monte Carlo estimate of the marginal likelihood for a single group.
    """
    # --- Part 1: shape checks on a multi-group problem -------------------------------
    scales = torch.rand(num_groups)
    locs = torch.randn(num_groups)
    every_group = torch.arange(num_groups)
    # Each group is observed twice; group 0 receives two additional observations.
    idx = torch.cat([every_group, every_group, torch.zeros(2).long()])
    data = torch.randn(idx.shape)
    noise = torch.rand(idx.shape)

    dist = GroupedNormalNormal(locs, scales, noise, idx)
    assert dist.log_prob(data).shape == ()
    post = dist.get_posterior(data)
    assert post.loc.shape == post.scale.shape == (num_groups,)

    # --- Part 2: log_prob correctness on one group with two observations -------------
    scales = 1 + torch.rand(1).double()
    locs = torch.randn(1).double()
    idx = torch.zeros(2).long()
    data = torch.randn(idx.shape)
    noise = 0.5 + torch.rand(idx.shape).double()

    dist = GroupedNormalNormal(locs, scales, noise, idx)
    analytic = dist.log_prob(data).item()

    # Draw zero-mean offsets from the prior and use each one with both signs
    # around the prior mean.
    half = num_samples // 2
    offsets = Normal(0.0, scales).sample(sample_shape=(half,))
    latents = torch.cat([locs + offsets, locs - offsets])
    per_sample_ll = Normal(latents, noise).log_prob(data).sum(-1)
    # Log of the average likelihood across the sampled latent means.
    estimate = torch.logsumexp(per_sample_ll, dim=-1).item() - math.log(num_samples)

    assert_close(analytic, estimate, atol=0.001)