diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index bfd3008a3f1..5a8174ff7aa 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -100,6 +100,18 @@ def is_data(name, var, model) -> bool:
     return constant_data
 
 
+def coords_and_dims_to_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """Parse PyMC model coords and dims format to one accepted by InferenceData."""
+    coords = {
+        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
+        for cname, cvals in model.coords.items()
+        if cvals is not None
+    }
+    dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}
+
+    return coords, dims
+
+
 class _DefaultTrace:
     """
     Utility for collecting samples into a dictionary.
@@ -216,19 +228,11 @@ def __init__(
                 " one of trace, prior, posterior_predictive or predictions."
             )
 
-        # Make coord types more rigid
-        untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
-        if coords:
-            untyped_coords.update(coords)
-        self.coords = {
-            cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
-            for cname, cvals in untyped_coords.items()
-            if cvals is not None
-        }
-
-        self.dims = {} if dims is None else dims
-        model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}
-        self.dims = {**model_dims, **self.dims}
+        user_coords = {} if coords is None else coords
+        user_dims = {} if dims is None else dims
+        model_coords, model_dims = coords_and_dims_to_inferencedata(model)
+        self.coords = {**model_coords, **user_coords}
+        self.dims = {**model_dims, **user_dims}
 
         if sample_dims is None:
             sample_dims = ["chain", "draw"]
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 2edcb183615..01a655679a5 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -942,7 +942,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
         Entries in the tuples may be ``None``, if the RV dimension was not given a name.
         """
         warnings.warn(
-            "Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
+            "Model.RV_dims is deprecated. Use Model.named_vars_to_dims instead.",
             FutureWarning,
         )
         return self.named_vars_to_dims
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 5016b0897d6..627e98b064c 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -37,7 +37,11 @@
 from pytensor.tensor.random.type import RandomType
 
 from pymc import Model, modelcontext
-from pymc.backends.arviz import find_constants, find_observations
+from pymc.backends.arviz import (
+    coords_and_dims_to_inferencedata,
+    find_constants,
+    find_observations,
+)
 from pymc.distributions.multivariate import PosDefMatrix
 from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
@@ -392,17 +396,6 @@ def sample_blackjax_nuts(
 
     vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
 
-    coords = {
-        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
-        for cname, cvals in model.coords.items()
-        if cvals is not None
-    }
-
-    dims = {
-        var_name: [dim for dim in dims if dim is not None]
-        for var_name, dims in model.named_vars_to_dims.items()
-    }
-
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     tic1 = datetime.now()
@@ -485,7 +478,7 @@ def sample_blackjax_nuts(
         "sampling_time": (tic3 - tic2).total_seconds(),
     }
 
-    posterior = mcmc_samples
+    coords, dims = coords_and_dims_to_inferencedata(model)
     # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
     # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
     _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
@@ -500,7 +493,7 @@ def sample_blackjax_nuts(
         dims=dims,
         attrs=make_attrs(attrs, library=blackjax),
     )
-    az_trace = to_trace(posterior=posterior, **idata_kwargs)
+    az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
     return az_trace
 
 
@@ -613,17 +606,6 @@ def sample_numpyro_nuts(
 
     vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
 
-    coords = {
-        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
-        for cname, cvals in model.coords.items()
-        if cvals is not None
-    }
-
-    dims = {
-        var_name: [dim for dim in dims if dim is not None]
-        for var_name, dims in model.named_vars_to_dims.items()
-    }
-
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     tic1 = datetime.now()
@@ -715,7 +697,7 @@ def sample_numpyro_nuts(
         "sampling_time": (tic3 - tic2).total_seconds(),
     }
 
-    posterior = mcmc_samples
+    coords, dims = coords_and_dims_to_inferencedata(model)
     # Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
     # and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
     _update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
@@ -730,5 +712,5 @@ def sample_numpyro_nuts(
         dims=dims,
         attrs=make_attrs(attrs, library=numpyro),
     )
-    az_trace = to_trace(posterior=posterior, **idata_kwargs)
+    az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
     return az_trace
diff --git a/pymc/stats/log_likelihood.py b/pymc/stats/log_likelihood.py
index 51bbd59d609..2883d3b9f31 100644
--- a/pymc/stats/log_likelihood.py
+++ b/pymc/stats/log_likelihood.py
@@ -13,14 +13,12 @@
 # limitations under the License.
 from typing import Optional, Sequence, cast
 
-import numpy as np
-
 from arviz import InferenceData, dict_to_dataset
 from fastprogress import progress_bar
 
 import pymc
 
-from pymc.backends.arviz import _DefaultTrace
+from pymc.backends.arviz import _DefaultTrace, coords_and_dims_to_inferencedata
 from pymc.model import Model, modelcontext
 from pymc.pytensorf import PointFunc
 from pymc.util import dataset_to_point_list
@@ -113,14 +111,12 @@ def compute_log_likelihood(
                 (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
             )
 
+    coords, dims = coords_and_dims_to_inferencedata(model)
     loglike_dataset = dict_to_dataset(
         loglike_trace,
         library=pymc,
-        dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
-        coords={
-            cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
-            for cname, cvals in model.coords.items()
-        },
+        dims=dims,
+        coords=coords,
         default_dims=list(sample_dims),
         skip_event_dims=True,
     )
diff --git a/tests/stats/test_log_likelihood.py b/tests/stats/test_log_likelihood.py
index e7201b761bf..fde4d69c3e6 100644
--- a/tests/stats/test_log_likelihood.py
+++ b/tests/stats/test_log_likelihood.py
@@ -15,7 +15,7 @@
 import pytest
 import scipy.stats as st
 
-from arviz import InferenceData, dict_to_dataset
+from arviz import InferenceData, dict_to_dataset, from_dict
 
 from pymc.distributions import Dirichlet, Normal
 from pymc.distributions.transforms import log
@@ -117,3 +117,18 @@ def test_invalid_var_names(self):
         idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
         with pytest.raises(ValueError, match="var_names must refer to observed_RVs"):
             compute_log_likelihood(idata, var_names=["x"])
+
+    def test_dims_without_coords(self):
+        # Issues #6820
+        with Model() as m:
+            x = Normal("x")
+            y = Normal("y", x, observed=[0, 0, 0], shape=(3,), dims="obs")
+
+        trace = from_dict({"x": [[0, 1]]})
+        llike = compute_log_likelihood(trace)
+
+        assert len(llike.log_likelihood["obs"]) == 3
+        np.testing.assert_allclose(
+            llike.log_likelihood["y"].values,
+            st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
+        )