Add log_prob method to Stable (same one that already exists in StableWithLogProb) (#3370)
BenZickel authored Jun 1, 2024
1 parent 0678b35 commit 55750ed
Showing 7 changed files with 84 additions and 63 deletions.
62 changes: 40 additions & 22 deletions pyro/distributions/stable.py
@@ -7,7 +7,7 @@
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.stable_log_prob import StableLogProb
from pyro.distributions.stable_log_prob import _stable_log_prob
from pyro.distributions.torch_distribution import TorchDistribution


@@ -105,9 +105,12 @@ class Stable(TorchDistribution):
pass ``coords="S"``, but BEWARE this is discontinuous at ``stability=1``
and has poor geometry for inference.
This implements a reparametrized sampler :meth:`rsample` , but does not
implement :meth:`log_prob` . Inference can be performed using either
likelihood-free algorithms such as
This implements a reparametrized sampler :meth:`rsample` and a relatively
expensive :meth:`log_prob` calculation by numerical integration, which makes
inference slow compared to other distributions but gives better
convergence properties, especially for skewed :math:`\alpha`-stable
distributions (see the ``skew`` parameter below). Faster
inference can be performed using either likelihood-free algorithms such as
:class:`~pyro.infer.energy_distance.EnergyDistance`, or reparameterization
via the :func:`~pyro.poutine.handlers.reparam` handler with one of the
reparameterizers :class:`~pyro.infer.reparam.stable.LatentStableReparam` ,
@@ -176,7 +179,32 @@ def expand(self, batch_shape, _instance=None):
return new

def log_prob(self, value):
raise NotImplementedError("Stable.log_prob() is not implemented")
r"""Implemented by numerical integration that is based on the algorithm
proposed by Chambers, Mallows and Stuck (CMS) for simulating the
Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
nonlinear transformation of two independent random variables into
one stable random variable. The first random variable is uniformly
distributed while the second is exponentially distributed. The numerical
integration is performed over the first uniformly distributed random
variable.
"""
if self._validate_args:
self._validate_sample(value)

# Undo shift and scale
value = (value - self.loc) / self.scale
value_dtype = value.dtype

# Use double precision math
alpha = self.stability.double()
beta = self.skew.double()
value = value.double()

alpha, beta, value = broadcast_all(alpha, beta, value)

log_prob = _stable_log_prob(alpha, beta, value, self.coords)

return log_prob.to(dtype=value_dtype) - self.scale.log()

def rsample(self, sample_shape=torch.Size()):
# Draw parameter-free noise.
@@ -207,22 +235,12 @@ def variance(self):
return var.mul(2).masked_fill(self.stability < 2, math.inf)


class StableWithLogProb(StableLogProb, Stable):
class StableWithLogProb(Stable):
r"""
Levy :math:`\alpha`-stable distribution that is based on
:class:`Stable` but with an added method for calculating the
log probability density using numerical integration.
This should be used in cases where reparameterization does not work
like when trying to estimate the skew :math:`\beta` parameter. Running
times are slower than with reparameterization.
The numerical integration implementation is based on the algorithm
proposed by Chambers, Mallows and Stuck (CMS) for simulating the
Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
nonlinear transformation of two independent random variables into
one stable random variable. The first random variable is uniformly
distributed while the second is exponentially distributed. The numerical
integration is performed over the first uniformly distributed random
variable.
Same as :class:`Stable` but will not undergo reparameterization by
:class:`~pyro.infer.reparam.strategies.MinimalReparam` and will fail
reparametrization by
:class:`~pyro.infer.reparam.stable.LatentStableReparam` ,
:class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , or
:class:`~pyro.infer.reparam.stable.StableReparam`.
"""
16 changes: 0 additions & 16 deletions pyro/distributions/stable_log_prob.py
@@ -49,22 +49,6 @@ def integrate(*args, **kwargs):
return integrate(*args, **kwargs)


class StableLogProb:
def log_prob(self, value):
# Undo shift and scale
value = (value - self.loc) / self.scale
value_dtype = value.dtype

# Use double precision math
alpha = self.stability.double()
beta = self.skew.double()
value = value.double()

log_prob = _stable_log_prob(alpha, beta, value, self.coords)

return log_prob.to(dtype=value_dtype) - self.scale.log()


def _stable_log_prob(alpha, beta, value, coords):
# Convert to Nolan's parametrization S^0 where samples depend
# continuously on (alpha,beta), allowing interpolation around the hole at
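For orientation, a hedged sketch of the relationship between the "S" and Nolan's "S0" coordinates that the conversion above relies on; the shift formula is stated here as an assumption (Nolan's convention, alpha != 1), not taken from the diff:

```python
import math
import torch
import pyro.distributions as dist

alpha, beta, scale, loc_S = 1.5, 0.8, 2.0, 0.0

# Assumed relation (Nolan, alpha != 1): the "S" location maps to the
# continuous "S0" location via loc_S0 = loc_S + skew * scale * tan(pi * alpha / 2).
loc_S0 = loc_S + beta * scale * math.tan(math.pi * alpha / 2)

d_S = dist.Stable(alpha, beta, scale, loc_S, coords="S")
d_S0 = dist.Stable(alpha, beta, scale, loc_S0, coords="S0")

# Both objects should describe the same law, so their log densities
# should agree up to the accuracy of the numerical integration.
x = torch.linspace(-5.0, 5.0, 5)
print(d_S.log_prob(x) - d_S0.log_prob(x))
```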
12 changes: 10 additions & 2 deletions pyro/infer/reparam/stable.py
@@ -44,7 +44,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0"
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)
if is_observed:
raise NotImplementedError(
f"At pyro.sample({repr(name)},...), "
@@ -101,7 +105,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0"
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)
if is_validation_enabled():
if not (fn.skew == 0).all():
raise ValueError("SymmetricStableReparam found nonzero skew")
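To see where these assertions matter, here is a hedged sketch (model, site names, and data are illustrative) of reparameterizing a plain `Stable` likelihood; applying the same handler to a `StableWithLogProb` site would now trip the added assertion rather than silently rewriting a distribution that already has a usable `log_prob`:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam.stable import StableReparam

def model(data):
    skew = pyro.sample("skew", dist.Uniform(-1.0, 1.0)).unsqueeze(-1)
    # Plain Stable likelihood over the whole series, in the style of the tutorial below.
    pyro.sample("obs",
                dist.Stable(1.5, skew).expand(data.shape).to_event(1),
                obs=data)

# StableReparam rewrites the observed Stable site into auxiliary latent
# variables plus a Gaussian observation, so no Stable.log_prob call is needed.
reparam_model = poutine.reparam(model, {"obs": StableReparam()})
poutine.trace(reparam_model).get_trace(torch.randn(10))
```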
27 changes: 22 additions & 5 deletions tests/distributions/conftest.py
@@ -500,16 +500,33 @@ def __init__(self, von_loc, von_conc, skewness):
),
Fixture(
pyro_dist=dist.Stable,
scipy_dist=sp.levy_stable,
examples=[
{"stability": [1.5], "skew": 0.1, "test_data": [-10.0]},
{
"stability": [1.5],
"skew": 0.1,
# Skew is zero as the default parameterization of the scipy
# implementation is S and cannot be changed via initialization
# arguments (pyro's default parameterization is S0 which
# gives different results with non-zero skew).
# Testing with non-zero skew is done in
# tests.distributions.test_stable_log_prob and
# tests.distributions.test_stable
{"stability": [1.5], "skew": 0.0, "test_data": [-10.0]},
{
"stability": [1.5, 0.5],
"skew": 0.0,
"scale": 2.0,
"loc": -2.0,
"test_data": [10.0],
"test_data": [10.0, -10.0],
},
],
scipy_arg_fn=lambda stability, skew, scale, loc: (
(),
{
"alpha": np.array(stability),
"beta": np.array(skew),
"scale": np.array(scale),
"loc": np.array(loc),
},
),
),
Fixture(
pyro_dist=dist.MultivariateStudentT,
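In line with the comment above, a cross-check against scipy is only meaningful at zero skew, where the S and S0 parameterizations coincide; a hedged sketch (the tolerance is illustrative):

```python
import numpy as np
import scipy.stats as sp
import torch
import pyro.distributions as dist

alpha, scale, loc = 1.5, 2.0, -2.0
x = torch.tensor([-10.0, 0.0, 10.0])

pyro_lp = dist.Stable(alpha, 0.0, scale, loc).log_prob(x)
scipy_lp = sp.levy_stable.logpdf(x.numpy(), alpha, 0.0, loc=loc, scale=scale)

# At skew=0 the parameterizations agree, so the values should be close.
print(np.allclose(pyro_lp.numpy(), scipy_lp, atol=1e-3))
```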
7 changes: 1 addition & 6 deletions tests/distributions/test_distributions.py
@@ -171,6 +171,7 @@ def test_mean(continuous_dist):
"SineBivariateVonMises",
"VonMises",
"ProjectedNormal",
"Stable",
]:
pytest.xfail(reason="Euclidean mean is not defined")
for i in range(continuous_dist.get_num_test_data()):
@@ -310,8 +311,6 @@ def test_expand_by(dist, sample_shape, shape_type):
small = dist.pyro_dist(**dist.get_dist_params(idx))
large = small.expand_by(shape_type(sample_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == "Stable":
pytest.skip("Stable does not implement a log_prob method.")
check_sample_shapes(small, large)


@@ -329,8 +328,6 @@ def test_expand_new_dim(dist, sample_shape, shape_type, default):
with xfail_if_not_implemented():
large = small.expand(shape_type(sample_shape + small.batch_shape))
assert large.batch_shape == sample_shape + small.batch_shape
if dist.get_test_distribution_name() == "Stable":
pytest.skip("Stable does not implement a log_prob method.")
check_sample_shapes(small, large)


@@ -351,8 +348,6 @@ def test_expand_existing_dim(dist, shape_type, default):
with xfail_if_not_implemented():
large = small.expand(shape_type(batch_shape))
assert large.batch_shape == batch_shape
if dist.get_test_distribution_name() == "Stable":
pytest.skip("Stable does not implement a log_prob method.")
check_sample_shapes(small, large)


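The skips removed above are unnecessary now that expanded `Stable` instances also support `log_prob`; a quick sketch:

```python
import torch
import pyro.distributions as dist

small = dist.Stable(torch.tensor(1.7), torch.tensor(0.3))
large = small.expand_by((3, 2))   # batch_shape becomes (3, 2)
x = large.sample()
print(large.log_prob(x).shape)    # torch.Size([3, 2]); previously raised NotImplementedError
```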
7 changes: 3 additions & 4 deletions tests/distributions/test_stable_log_prob.py
@@ -10,8 +10,7 @@
import pyro
import pyro.distributions
import pyro.distributions.stable_log_prob
from pyro.distributions import StableWithLogProb as Stable
from pyro.distributions import constraints
from pyro.distributions import Stable, constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from tests.common import assert_close
@@ -41,7 +40,7 @@ def test_stable_gof(stability, skew):
# Check goodness of fit of samples to scipy's implementation of the log-probability calculation.
logging.info(
f"Calculating log-probability of (stablity={stability}, "
"skew={skew}) for {len(samples_scipy)} samples with scipy"
f"skew={skew}) for {len(samples_scipy)} samples with scipy"
)
probs_scipy = torch.Tensor(dist_scipy.pdf(samples_scipy))
gof_scipy = auto_goodness_of_fit(samples_scipy, probs_scipy)
@@ -53,7 +52,7 @@
# Check goodness of fit of pyro's implementation of the log-probability calculation to generated samples.
logging.info(
f"Calculating log-probability of (stablity={stability}, "
"skew={skew}) for {len(samples)} samples with pyro"
f"skew={skew}) for {len(samples)} samples with pyro"
)
probs = dist.log_prob(samples).exp()
gof = auto_goodness_of_fit(samples, probs)
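The practical payoff of the goodness-of-fit checks above is that the skew parameter can now be estimated by maximizing the numerically integrated likelihood directly; a minimal SVI sketch (the model, guide, and hyperparameters are illustrative, and each step is slow because `log_prob` integrates numerically):

```python
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(data):
    skew = pyro.sample("skew", dist.Uniform(-1.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Stable(1.5, skew), obs=data)

data = dist.Stable(1.5, 0.5).sample((1000,))  # synthetic data with true skew 0.5
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.05}), Trace_ELBO())
for _ in range(100):
    svi.step(data)
print(guide.median()["skew"])  # should move toward the true skew of 0.5
```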
16 changes: 8 additions & 8 deletions tutorial/source/stable.ipynb
@@ -8,12 +8,12 @@
"\n",
"This tutorial demonstrates inference using the Levy [Stable](http://docs.pyro.ai/en/stable/distributions.html#stable) distribution through a motivating example of a non-Gaussian stochastic volatilty model.\n",
"\n",
"Inference with stable distribution is tricky because the density `Stable.log_prob()` is not defined. In this tutorial we demonstrate two approaches to inference: (i) using the [poutine.reparam](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.handlers.reparam) effect to transform models in to a tractable form, (ii) using the likelihood-free loss [EnergyDistance](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.energy_distance.EnergyDistance) with SVI, and (iii) using the `StableWithLogProb` distribution which has a numerically integrated log-probability calculation.\n",
"Inference with stable distribution is tricky because the density `Stable.log_prob()` is very expensive. In this tutorial we demonstrate three approaches to inference: (i) using the [poutine.reparam](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.handlers.reparam) effect to transform models in to a tractable form, (ii) using the likelihood-free loss [EnergyDistance](http://docs.pyro.ai/en/latest/inference_algos.html#pyro.infer.energy_distance.EnergyDistance) with SVI, and (iii) using `Stable.log_prob()` which has a numerically integrated log-probability calculation.\n",
"\n",
"\n",
"#### Summary\n",
"\n",
"- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is undefined.\n",
"- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is very expensive.\n",
"- Stable inference requires either reparameterization or a likelihood-free loss.\n",
"- Reparameterization:\n",
" - The [poutine.reparam()](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.handlers.reparam) handler can transform models using various [strategies](http://docs.pyro.ai/en/latest/infer.reparam.html).\n",
@@ -29,7 +29,7 @@
"- [Fitting a single distribution to log returns](#fitting) using `EnergyDistance`\n",
"- [Modeling stochastic volatility](#modeling) using:\n",
" - [Reparameterization](#reparam) with `poutine.reparam`\n",
" - [Numerically integrated log-probability](#numeric) of [StableWithLogProb](http://docs.pyro.ai/en/stable/distributions.html#stablewithlogprob)"
" - [Numerically integrated log-probability](#numeric) with `Stable.log_prob()`"
]
},
{
@@ -337,7 +337,7 @@
"metadata": {},
"outputs": [],
"source": [
"def model(data, r_dist=dist.Stable):\n",
"def model(data):\n",
" # Note we avoid plates because we'll later reparameterize along the time axis using\n",
" # DiscreteCosineReparam, breaking independence. This requires .unsqueeze()ing scalars.\n",
" h_0 = pyro.sample(\"h_0\", dist.Normal(0, 1)).unsqueeze(-1)\n",
Expand All @@ -350,7 +350,7 @@
" r_loc = pyro.sample(\"r_loc\", dist.Normal(0, 1e-2)).unsqueeze(-1)\n",
" r_skew = pyro.sample(\"r_skew\", dist.Uniform(-1, 1)).unsqueeze(-1)\n",
" r_stability = pyro.sample(\"r_stability\", dist.Uniform(0, 2)).unsqueeze(-1)\n",
" pyro.sample(\"r\", r_dist(r_stability, r_skew, sqrt_h, r_loc * sqrt_h).to_event(1),\n",
" pyro.sample(\"r\", dist.Stable(r_stability, r_skew, sqrt_h, r_loc * sqrt_h).to_event(1),\n",
" obs=data)"
]
},
@@ -360,7 +360,7 @@
"source": [
"### Fitting a Model with Reparameterization <a class=\"anchor\" id=\"reparam\"></a>\n",
"\n",
"We use two reparameterizers: [StableReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.stable.StableReparam) to handle the `Stable` likelihood (since `Stable.log_prob()` is undefined), and [DiscreteCosineReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.discrete_cosine.DiscreteCosineReparam) to improve geometry of the latent Gaussian process for `v`. We'll then use `reparam_model` for both inference and prediction."
"We use two reparameterizers: [StableReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.stable.StableReparam) to handle the `Stable` likelihood (since `Stable.log_prob()` is very expensive), and [DiscreteCosineReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.discrete_cosine.DiscreteCosineReparam) to improve geometry of the latent Gaussian process for `v`. We'll then use `reparam_model` for both inference and prediction."
]
},
{
@@ -533,7 +533,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We now create a model with the `Stable` distirbution replaced by `StableWithLogProb` which has a method for calculating the log-probability density."
"We now create a model without reparameterization of the `Stable` distirbution. This model will use the `Stable.log_prob()` method in order to calculate the log-probability density."
]
},
{
@@ -543,7 +543,7 @@
"outputs": [],
"source": [
"from functools import partial\n",
"model_with_log_prob = poutine.reparam(partial(model, r_dist=dist.StableWithLogProb), {\"v\": DiscreteCosineReparam()})"
"model_with_log_prob = poutine.reparam(model, {\"v\": DiscreteCosineReparam()})"
]
},
{
