Skip to content

Commit

Permalink
Add helper to convert model coords and dims into format accepted by I…
Browse files Browse the repository at this point in the history
…nferenceData
  • Loading branch information
jaharvey8 authored and ricardoV94 committed Sep 14, 2023
1 parent f77372c commit ed631a2
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 50 deletions.
30 changes: 17 additions & 13 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def is_data(name, var, model) -> bool:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}
dims = {dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()}

return coords, dims


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Expand Down Expand Up @@ -216,19 +228,11 @@ def __init__(
" one of trace, prior, posterior_predictive or predictions."
)

# Make coord types more rigid
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
if coords:
untyped_coords.update(coords)
self.coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in untyped_coords.items()
if cvals is not None
}

self.dims = {} if dims is None else dims
model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}
self.dims = {**model_dims, **self.dims}
user_coords = {} if coords is None else coords
user_dims = {} if dims is None else dims
model_coords, model_dims = coords_and_dims_for_inferencedata(model)
self.coords = {**model_coords, **user_coords}
self.dims = {**model_dims, **user_dims}

if sample_dims is None:
sample_dims = ["chain", "draw"]
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
"""
warnings.warn(
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
"Model.RV_dims is deprecated. Use Model.named_vars_to_dims instead.",
FutureWarning,
)
return self.named_vars_to_dims
Expand Down
36 changes: 9 additions & 27 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from pytensor.tensor.random.type import RandomType

from pymc import Model, modelcontext
from pymc.backends.arviz import find_constants, find_observations
from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.distributions.multivariate import PosDefMatrix
from pymc.initial_point import StartDict
from pymc.logprob.utils import CheckParameterValue
Expand Down Expand Up @@ -392,17 +396,6 @@ def sample_blackjax_nuts(

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.named_vars_to_dims.items()
}

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
Expand Down Expand Up @@ -485,7 +478,7 @@ def sample_blackjax_nuts(
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
coords, dims = coords_and_dims_for_inferencedata(model)
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
Expand All @@ -500,7 +493,7 @@ def sample_blackjax_nuts(
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)

return az_trace

Expand Down Expand Up @@ -613,17 +606,6 @@ def sample_numpyro_nuts(

vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))

coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
if cvals is not None
}

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.named_vars_to_dims.items()
}

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

tic1 = datetime.now()
Expand Down Expand Up @@ -715,7 +697,7 @@ def sample_numpyro_nuts(
"sampling_time": (tic3 - tic2).total_seconds(),
}

posterior = mcmc_samples
coords, dims = coords_and_dims_for_inferencedata(model)
# Update 'coords' and 'dims' extracted from the model with user 'idata_kwargs'
# and drop keys 'coords' and 'dims' from 'idata_kwargs' if present.
_update_coords_and_dims(coords=coords, dims=dims, idata_kwargs=idata_kwargs)
Expand All @@ -730,5 +712,5 @@ def sample_numpyro_nuts(
dims=dims,
attrs=make_attrs(attrs, library=numpyro),
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
return az_trace
12 changes: 4 additions & 8 deletions pymc/stats/log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.
from typing import Optional, Sequence, cast

import numpy as np

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar

import pymc

from pymc.backends.arviz import _DefaultTrace
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
from pymc.model import Model, modelcontext
from pymc.pytensorf import PointFunc
from pymc.util import dataset_to_point_list
Expand Down Expand Up @@ -113,14 +111,12 @@ def compute_log_likelihood(
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
)

coords, dims = coords_and_dims_for_inferencedata(model)
loglike_dataset = dict_to_dataset(
loglike_trace,
library=pymc,
dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
coords={
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in model.coords.items()
},
dims=dims,
coords=coords,
default_dims=list(sample_dims),
skip_event_dims=True,
)
Expand Down
17 changes: 16 additions & 1 deletion tests/stats/test_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import scipy.stats as st

from arviz import InferenceData, dict_to_dataset
from arviz import InferenceData, dict_to_dataset, from_dict

from pymc.distributions import Dirichlet, Normal
from pymc.distributions.transforms import log
Expand Down Expand Up @@ -117,3 +117,18 @@ def test_invalid_var_names(self):
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
with pytest.raises(ValueError, match="var_names must refer to observed_RVs"):
compute_log_likelihood(idata, var_names=["x"])

def test_dims_without_coords(self):
# Issues #6820
with Model() as m:
x = Normal("x")
y = Normal("y", x, observed=[0, 0, 0], shape=(3,), dims="obs")

trace = from_dict({"x": [[0, 1]]})
llike = compute_log_likelihood(trace)

assert len(llike.log_likelihood["obs"]) == 3
np.testing.assert_allclose(
llike.log_likelihood["y"].values,
st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]),
)

0 comments on commit ed631a2

Please sign in to comment.