
Implement compute_log_prior utility #7149


Merged 4 commits on Feb 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -73,7 +73,7 @@ jobs:
tests/sampling/test_forward.py
tests/sampling/test_population.py
tests/stats/test_convergence.py
tests/stats/test_log_likelihood.py
tests/stats/test_log_density.py
tests/distributions/test_distribution.py
tests/distributions/test_discrete.py

1 change: 1 addition & 0 deletions docs/source/api/misc.rst
@@ -7,5 +7,6 @@ Other utils
:toctree: generated/

compute_log_likelihood
compute_log_prior
find_constrained_prior
DictToArrayBijection
24 changes: 21 additions & 3 deletions pymc/backends/arviz.py
@@ -174,6 +174,7 @@ def __init__(
prior=None,
posterior_predictive=None,
log_likelihood=False,
log_prior=False,
predictions=None,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
@@ -215,6 +216,7 @@ def __init__(
self.prior = prior
self.posterior_predictive = posterior_predictive
self.log_likelihood = log_likelihood
self.log_prior = log_prior
self.predictions = predictions

if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
@@ -436,7 +438,7 @@ def to_inference_data(self):
id_dict["constant_data"] = self.constant_data_to_xarray()
idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
if self.log_likelihood:
from pymc.stats.log_likelihood import compute_log_likelihood
from pymc.stats.log_density import compute_log_likelihood

idata = compute_log_likelihood(
idata,
@@ -446,6 +448,17 @@ def to_inference_data(self):
sample_dims=self.sample_dims,
progressbar=False,
)
if self.log_prior:
from pymc.stats.log_density import compute_log_prior

idata = compute_log_prior(
idata,
var_names=None if self.log_prior is True else self.log_prior,
extend_inferencedata=True,
model=self.model,
sample_dims=self.sample_dims,
progressbar=False,
)
return idata


@@ -455,6 +468,7 @@ def to_inference_data(
prior: Optional[Mapping[str, Any]] = None,
posterior_predictive: Optional[Mapping[str, Any]] = None,
log_likelihood: Union[bool, Iterable[str]] = False,
log_prior: Union[bool, Iterable[str]] = False,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
sample_dims: Optional[list] = None,
@@ -481,8 +495,11 @@
Dictionary with the variable names as keys, and values numpy arrays
containing posterior predictive samples.
log_likelihood : bool or array_like of str, optional
List of variables to calculate `log_likelihood`. Defaults to True which calculates
`log_likelihood` for all observed variables. If set to False, log_likelihood is skipped.
List of variables to calculate `log_likelihood`. Defaults to False.
If set to True, computes `log_likelihood` for all observed variables.
log_prior : bool or array_like of str, optional
List of variables to calculate `log_prior`. Defaults to False.
If set to True, computes `log_prior` for all unobserved variables.
coords : dict of {str: array-like}, optional
Map of coordinate names to coordinate values
dims : dict of {str: list of str}, optional
@@ -509,6 +526,7 @@
prior=prior,
posterior_predictive=posterior_predictive,
log_likelihood=log_likelihood,
log_prior=log_prior,
coords=coords,
dims=dims,
sample_dims=sample_dims,
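
For orientation, here is a minimal sketch of how the new `log_prior` flag plugs into this conversion path. The model, variable names, and data below are illustrative, not part of the PR:

import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)  # free variable -> gets a log_prior entry
    pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))
    trace = pm.sample(return_inferencedata=False)
    # log_prior defaults to False; opt in explicitly, mirroring log_likelihood
    idata = pm.to_inference_data(trace, log_prior=True)

print(idata.log_prior["mu"])  # elemwise prior log-density over (chain, draw)
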
4 changes: 2 additions & 2 deletions pymc/stats/__init__.py
@@ -27,6 +27,6 @@
if not attr.startswith("__"):
setattr(sys.modules[__name__], attr, obj)

from pymc.stats.log_likelihood import compute_log_likelihood
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior

__all__ = ("compute_log_likelihood", *az.stats.__all__)
__all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__)
116 changes: 95 additions & 21 deletions pymc/stats/log_likelihood.py → pymc/stats/log_density.py
@@ -24,7 +24,7 @@
from pymc.pytensorf import PointFunc
from pymc.util import dataset_to_point_list

__all__ = ("compute_log_likelihood",)
__all__ = ("compute_log_likelihood", "compute_log_prior")


def compute_log_likelihood(
@@ -43,7 +43,8 @@ def compute_log_likelihood(
idata : InferenceData
InferenceData with posterior group
var_names : sequence of str, optional
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
List of Observed variable names for which to compute log_likelihood.
Defaults to all observed variables.
extend_inferencedata : bool, default True
Whether to extend the original InferenceData or return a new one
model : Model, optional
@@ -54,20 +55,92 @@ def compute_log_likelihood(
-------
idata : InferenceData
InferenceData with log_likelihood group
"""
return compute_log_density(
idata=idata,
var_names=var_names,
extend_inferencedata=extend_inferencedata,
model=model,
kind="likelihood",
sample_dims=sample_dims,
progressbar=progressbar,
)


def compute_log_prior(
idata: InferenceData,
var_names: Optional[Sequence[str]] = None,
extend_inferencedata: bool = True,
model: Optional[Model] = None,
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
):
"""Compute elemwise log_prior of model given InferenceData with posterior group

Parameters
----------
idata : InferenceData
InferenceData with posterior group
var_names : sequence of str, optional
List of unobserved variable names for which to compute log_prior.
Defaults to all free variables.
extend_inferencedata : bool, default True
Whether to extend the original InferenceData or return a new one
model : Model, optional
sample_dims : sequence of str, default ("chain", "draw")
progressbar : bool, default True

Returns
-------
idata : InferenceData
InferenceData with log_prior group
"""
return compute_log_density(
idata=idata,
var_names=var_names,
extend_inferencedata=extend_inferencedata,
model=model,
kind="prior",
sample_dims=sample_dims,
progressbar=progressbar,
)


def compute_log_density(
idata: InferenceData,
*,
var_names: Optional[Sequence[str]] = None,
extend_inferencedata: bool = True,
model: Optional[Model] = None,
kind="likelihood",
sample_dims: Sequence[str] = ("chain", "draw"),
progressbar=True,
):
"""
Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
"""

posterior = idata["posterior"]

model = modelcontext(model)

if kind not in ("likelihood", "prior"):
raise ValueError("kind must be either 'likelihood' or 'prior'")

if kind == "likelihood":
target_rvs = model.observed_RVs
target_str = "observed_RVs"
else:
target_rvs = model.unobserved_RVs
target_str = "free_RVs"

if var_names is None:
observed_vars = model.observed_RVs
var_names = tuple(rv.name for rv in observed_vars)
vars = target_rvs
var_names = tuple(rv.name for rv in vars)
else:
observed_vars = [model.named_vars[name] for name in var_names]
if not set(observed_vars).issubset(model.observed_RVs):
raise ValueError(f"var_names must refer to observed_RVs in the model. Got: {var_names}")
vars = [model.named_vars[name] for name in var_names]
if not set(vars).issubset(target_rvs):
raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")

# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
try:
@@ -80,39 +153,40 @@ def compute_log_likelihood(
}
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

elemwise_loglike_fn = model.compile_fn(
elemwise_logdens_fn = model.compile_fn(
inputs=model.value_vars,
outs=model.logp(vars=observed_vars, sum=False),
outs=model.logp(vars=vars, sum=False),
on_unused_input="ignore",
)
elemwise_loglike_fn = cast(PointFunc, elemwise_loglike_fn)
elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn)
finally:
model.rvs_to_values = original_rvs_to_values
model.rvs_to_transforms = original_rvs_to_transforms

# Ignore Deterministics
posterior_values = posterior[[rv.name for rv in model.free_RVs]]
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

n_pts = len(posterior_pts)
loglike_dict = _DefaultTrace(n_pts)
logdens_dict = _DefaultTrace(n_pts)
indices = range(n_pts)
if progressbar:
indices = progress_bar(indices, total=n_pts, display=progressbar)

for idx in indices:
loglikes_pts = elemwise_loglike_fn(posterior_pts[idx])
for rv_name, rv_loglike in zip(var_names, loglikes_pts):
loglike_dict.insert(rv_name, rv_loglike, idx)
logdenss_pts = elemwise_logdens_fn(posterior_pts[idx])
for rv_name, rv_logdens in zip(var_names, logdenss_pts):
logdens_dict.insert(rv_name, rv_logdens, idx)

loglike_trace = loglike_dict.trace_dict
for key, array in loglike_trace.items():
loglike_trace[key] = array.reshape(
logdens_trace = logdens_dict.trace_dict
for key, array in logdens_trace.items():
logdens_trace[key] = array.reshape(
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
)

coords, dims = coords_and_dims_for_inferencedata(model)
loglike_dataset = dict_to_dataset(
loglike_trace,
logdens_dataset = dict_to_dataset(
logdens_trace,
library=pymc,
dims=dims,
coords=coords,
Expand All @@ -121,7 +195,7 @@ def compute_log_likelihood(
)

if extend_inferencedata:
idata.add_groups(dict(log_likelihood=loglike_dataset))
idata.add_groups({f"log_{kind}": logdens_dataset})
return idata
else:
return loglike_dataset
return logdens_dataset
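
As a post-hoc usage sketch (names and data illustrative), the new helper can also be applied to an already-sampled InferenceData, assuming the `pymc.stats` re-export above makes it available as `pm.compute_log_prior`:

import numpy as np
import pymc as pm

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", 0.0, sigma, observed=np.array([0.5, 1.5, -0.7]))
    idata = pm.sample()  # posterior group only; no log_prior yet

# Adds a log_prior group in place (extend_inferencedata defaults to True).
# Internally this goes through compute_log_density(kind="prior"); transforms
# are temporarily disabled because the posterior group stores untransformed
# draws (sigma itself, not log(sigma)).
idata = pm.compute_log_prior(idata, var_names=["sigma"], model=model)
print(idata.log_prior["sigma"].dims)  # ("chain", "draw")
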
tests/stats/test_log_likelihood.py → tests/stats/test_log_density.py
@@ -20,7 +20,7 @@
from pymc.distributions import Dirichlet, Normal
from pymc.distributions.transforms import log
from pymc.model import Model
from pymc.stats.log_likelihood import compute_log_likelihood
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior
from tests.distributions.test_multivariate import dirichlet_logpdf


@@ -132,3 +132,26 @@ def test_dims_without_coords(self):
llike.log_likelihood["y"].values,
st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
)

@pytest.mark.parametrize("transform", (False, True))
def test_basic_log_prior(self, transform):
transform = log if transform else None
with Model() as m:
x = Normal("x", transform=transform)
x_value_var = m.rvs_to_values[x]
Normal("y", x, observed=[0, 1, 2])

idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
res = compute_log_prior(idata)

# Check we didn't erase the original mappings
assert m.rvs_to_values[x] is x_value_var
assert m.rvs_to_transforms[x] is transform

assert res is idata
assert res.log_prior.dims == {"chain": 4, "draw": 25}

np.testing.assert_allclose(
res.log_prior["x"].values,
st.norm(0, 1).logpdf(idata.posterior["x"].values),
)
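
Finally, an end-to-end sketch under the assumption that `pm.sample` forwards `idata_kwargs` to `to_inference_data`, as it already does for `log_likelihood`:

import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.array([0.0, 1.0, 2.0]))
    # Assumption: idata_kwargs is passed through to to_inference_data,
    # so the new flag should be usable at sampling time as well.
    idata = pm.sample(idata_kwargs={"log_prior": True})

assert "log_prior" in idata.groups()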