Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in compute_log_likelihood when variable has dims without coords #6882

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def is_data(name, var, model) -> bool:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}
dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}

return coords, dims


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Expand Down Expand Up @@ -216,19 +228,11 @@ def __init__(
" one of trace, prior, posterior_predictive or predictions."
)

# Make coord types more rigid
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
if coords:
untyped_coords.update(coords)
self.coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in untyped_coords.items()
if cvals is not None
}

self.dims = {} if dims is None else dims
model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}
self.dims = {**model_dims, **self.dims}
user_coords = {} if coords is None else coords
user_dims = {} if dims is None else dims
model_coords, model_dims = coords_and_dims_for_inferencedata(self.model)
self.coords = {**model_coords, **user_coords}
self.dims = {**model_dims, **user_dims}

if sample_dims is None:
sample_dims = ["chain", "draw"]
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
"""
warnings.warn(
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
"Model.RV_dims is deprecated. Use Model.named_vars_to_dims instead.",
FutureWarning,
)
return self.named_vars_to_dims
Expand Down
36 changes: 9 additions & 27 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from pytensor.tensor.random.type import RandomType

from pymc import Model, modelcontext
from pymc.backends.arviz import find_constants, find_observations
from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.distributions.multivariate import PosDefMatrix
from pymc.initial_point import StartDict
from pymc.logprob.utils import CheckParameterValue
Expand Down Expand Up @@ -392,17 +396,6 @@ def sample_blackjax_nuts(

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.named_vars_to_dims.items()
}

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
Expand Down Expand Up @@ -485,7 +478,7 @@ def sample_blackjax_nuts(
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
coords, dims = coords_and_dims_for_inferencedata(model)
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
Expand All @@ -500,7 +493,7 @@ def sample_blackjax_nuts(
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)

return az_trace

Expand Down Expand Up @@ -613,17 +606,6 @@ def sample_numpyro_nuts(

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.named_vars_to_dims.items()
}

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
Expand Down Expand Up @@ -715,7 +697,7 @@ def sample_numpyro_nuts(
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
coords, dims = coords_and_dims_for_inferencedata(model)
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
Expand All @@ -730,5 +712,5 @@ def sample_numpyro_nuts(
dims=dims,
attrs=make_attrs(attrs, library=numpyro),
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
return az_trace
12 changes: 4 additions & 8 deletions pymc/stats/log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.
from typing import Optional, Sequence, cast

import numpy as np

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar

import pymc

from pymc.backends.arviz import _DefaultTrace
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
from pymc.model import Model, modelcontext
from pymc.pytensorf import PointFunc
from pymc.util import dataset_to_point_list
Expand Down Expand Up @@ -113,14 +111,12 @@ def compute_log_likelihood(
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
)

coords, dims = coords_and_dims_for_inferencedata(model)
loglike_dataset = dict_to_dataset(
loglike_trace,
library=pymc,
dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
coords={
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
},
dims=dims,
coords=coords,
default_dims=list(sample_dims),
skip_event_dims=True,
)
Expand Down
17 changes: 16 additions & 1 deletion tests/stats/test_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import scipy.stats as st

from arviz import InferenceData, dict_to_dataset
from arviz import InferenceData, dict_to_dataset, from_dict

from pymc.distributions import Dirichlet, Normal
from pymc.distributions.transforms import log
Expand Down Expand Up @@ -117,3 +117,18 @@ def test_invalid_var_names(self):
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
with pytest.raises(ValueError, match="var_names must refer to observed_RVs"):
compute_log_likelihood(idata, var_names=["x"])

def test_dims_without_coords(self):
# Issues #6820
with Model() as m:
x = Normal("x")
y = Normal("y", x, observed=[0, 0, 0], shape=(3,), dims="obs")

trace = from_dict({"x": [[0, 1]]})
llike = compute_log_likelihood(trace)

assert len(llike.log_likelihood["obs"]) == 3
np.testing.assert_allclose(
llike.log_likelihood["y"].values,
st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
)