Skip to content

Commit db5eb4f

Browse files
committed
Use dict_to_dataset_drop_incompatible_coords everywhere
1 parent 9f653a6 commit db5eb4f

File tree

4 files changed

+30
-21
lines changed

4 files changed

+30
-21
lines changed

pymc/backends/arviz.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@
6161
def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
6262
safe_coords = coords
6363

64-
if not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
64+
if dims and not RAISE_ON_INCOMPATIBLE_COORD_LENGTHS:
6565
coords_lengths = {k: len(v) for k, v in coords.items()}
6666
for var_name, var in vars_dict.items():
6767
# Iterate in reversed because of chain/draw batch dimensions
@@ -70,9 +70,8 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
7070
if (coord_length is not None) and (coord_length != dim_length):
7171
warnings.warn(
7272
f"Incompatible coordinate length of {coord_length} for dimension '{dim}' of variable '{var_name}'.\n"
73-
"This usually happens when a sliced or concatenated variable is wrapped as a `pymc.dims.Deterministic`."
74-
"The originate coordinates for this dim will not be included in the returned dataset for any of the variables. "
75-
"Instead they will default to `np.arange(var_length)` and the shorter variables will be right-padded with nan.\n"
73+
"The original coordinates for this dim will not be included in the returned dataset for any of the variables. "
74+
"Instead they will default to `np.arange`, possibly right-padded with nan.\n"
7675
"To make this warning into an error set `pymc.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS` to `True`",
7776
UserWarning,
7877
)
@@ -303,14 +302,14 @@ def posterior_to_xarray(self):
303302
self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
304303
)
305304
return (
306-
dict_to_dataset(
305+
dict_to_dataset_drop_incompatible_coords(
307306
data,
308307
library=pymc,
309308
coords=self.coords,
310309
dims=self.dims,
311310
attrs=self.attrs,
312311
),
313-
dict_to_dataset(
312+
dict_to_dataset_drop_incompatible_coords(
314313
data_warmup,
315314
library=pymc,
316315
coords=self.coords,
@@ -345,14 +344,14 @@ def sample_stats_to_xarray(self):
345344
)
346345

347346
return (
348-
dict_to_dataset(
347+
dict_to_dataset_drop_incompatible_coords(
349348
data,
350349
library=pymc,
351350
dims=None,
352351
coords=self.coords,
353352
attrs=self.attrs,
354353
),
355-
dict_to_dataset(
354+
dict_to_dataset_drop_incompatible_coords(
356355
data_warmup,
357356
library=pymc,
358357
dims=None,
@@ -366,7 +365,7 @@ def posterior_predictive_to_xarray(self):
366365
"""Convert posterior_predictive samples to xarray."""
367366
data = self.posterior_predictive
368367
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
369-
return dict_to_dataset(
368+
return dict_to_dataset_drop_incompatible_coords(
370369
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
371370
)
372371

@@ -375,7 +374,7 @@ def predictions_to_xarray(self):
375374
"""Convert predictions (out of sample predictions) to xarray."""
376375
data = self.predictions
377376
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
378-
return dict_to_dataset(
377+
return dict_to_dataset_drop_incompatible_coords(
379378
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
380379
)
381380

@@ -412,7 +411,7 @@ def observed_data_to_xarray(self):
412411
"""Convert observed data to xarray."""
413412
if self.predictions:
414413
return None
415-
return dict_to_dataset(
414+
return dict_to_dataset_drop_incompatible_coords(
416415
self.observations,
417416
library=pymc,
418417
coords=self.coords,
@@ -427,7 +426,7 @@ def constant_data_to_xarray(self):
427426
if not constant_data:
428427
return None
429428

430-
xarray_dataset = dict_to_dataset(
429+
xarray_dataset = dict_to_dataset_drop_incompatible_coords(
431430
constant_data,
432431
library=pymc,
433432
coords=self.coords,
@@ -705,7 +704,7 @@ def apply_function_over_dataset(
705704
)
706705
)
707706

708-
return dict_to_dataset(
707+
return dict_to_dataset_drop_incompatible_coords(
709708
out_trace,
710709
library=pymc,
711710
dims=dims,

pymc/sampling/mcmc.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
import numpy as np
3434
import pytensor.gradient as tg
3535

36-
from arviz import InferenceData, dict_to_dataset
36+
from arviz import InferenceData
3737
from arviz.data.base import make_attrs
3838
from pytensor.graph.basic import Variable
3939
from rich.theme import Theme
@@ -45,6 +45,7 @@
4545
from pymc.backends import RunType, TraceOrBackend, init_traces
4646
from pymc.backends.arviz import (
4747
coords_and_dims_for_inferencedata,
48+
dict_to_dataset_drop_incompatible_coords,
4849
find_constants,
4950
find_observations,
5051
)
@@ -355,14 +356,14 @@ def _sample_external_nuts(
355356
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
356357
# gather observed and constant data as nutpie.sample() has no access to the PyMC model
357358
coords, dims = coords_and_dims_for_inferencedata(model)
358-
constant_data = dict_to_dataset(
359+
constant_data = dict_to_dataset_drop_incompatible_coords(
359360
find_constants(model),
360361
library=pm,
361362
coords=coords,
362363
dims=dims,
363364
default_dims=[],
364365
)
365-
observed_data = dict_to_dataset(
366+
observed_data = dict_to_dataset_drop_incompatible_coords(
366367
find_observations(model),
367368
library=pm,
368369
coords=coords,

pymc/smc/sampling.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333

3434
import pymc
3535

36-
from pymc.backends.arviz import dict_to_dataset, to_inference_data
36+
from pymc.backends.arviz import dict_to_dataset_drop_incompatible_coords, to_inference_data
3737
from pymc.backends.base import MultiTrace
3838
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
3939
from pymc.distributions.distribution import _support_point
@@ -264,7 +264,7 @@ def _save_sample_stats(
264264
else:
265265
sample_stats_dict[stat] = np.array(value)
266266

267-
sample_stats = dict_to_dataset(
267+
sample_stats = dict_to_dataset_drop_incompatible_coords(
268268
sample_stats_dict,
269269
attrs=sample_settings_dict,
270270
library=pymc,

tests/backends/test_arviz.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -851,7 +851,16 @@ def test_zero_size(self):
851851
assert pl[0]["x"].dtype == np.float64
852852

853853

854-
def test_incompatible_coordinate_lengths():
854+
@pytest.mark.parametrize(
855+
"sampling_method",
856+
(
857+
lambda: pm.sample_prior_predictive(draws=1).prior,
858+
lambda: pm.sample(
859+
chains=1, draws=1, tune=0, compute_convergence_checks=False, progressbar=False
860+
).posterior,
861+
),
862+
)
863+
def test_incompatible_coordinate_lengths(sampling_method):
855864
with pm.Model(coords={"a": [-1, -2, -3]}) as m:
856865
x = pm.Normal("x", dims="a")
857866
y = pm.Deterministic("y", x[1:], dims=("a",))
@@ -862,14 +871,14 @@ def test_incompatible_coordinate_lengths():
862871
"Incompatible coordinate length of 3 for dimension 'a' of variable 'y'"
863872
),
864873
):
865-
prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw"))
874+
prior = sampling_method().squeeze(("chain", "draw"))
866875
assert prior.x.dims == prior.y.dims == ("a",)
867876
assert prior.x.shape == prior.y.shape == (3,)
868877
assert np.isnan(prior.y.values[-1])
869878
assert list(prior.coords["a"]) == [0, 1, 2]
870879

871880
pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True
872881
with pytest.raises(ValueError):
873-
pm.sample_prior_predictive(draws=1)
882+
sampling_method()
874883

875884
pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = False

0 commit comments

Comments
 (0)