Skip to content

Commit 16f29a3

Browse files
committed
Use dict_to_dataset_drop_incompatible_coords everywhere
1 parent 9f653a6 commit 16f29a3

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

pymc/backends/arviz.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,14 @@ def posterior_to_xarray(self):
303303
self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
304304
)
305305
return (
306-
dict_to_dataset(
306+
dict_to_dataset_drop_incompatible_coords(
307307
data,
308308
library=pymc,
309309
coords=self.coords,
310310
dims=self.dims,
311311
attrs=self.attrs,
312312
),
313-
dict_to_dataset(
313+
dict_to_dataset_drop_incompatible_coords(
314314
data_warmup,
315315
library=pymc,
316316
coords=self.coords,
@@ -345,14 +345,14 @@ def sample_stats_to_xarray(self):
345345
)
346346

347347
return (
348-
dict_to_dataset(
348+
dict_to_dataset_drop_incompatible_coords(
349349
data,
350350
library=pymc,
351351
dims=None,
352352
coords=self.coords,
353353
attrs=self.attrs,
354354
),
355-
dict_to_dataset(
355+
dict_to_dataset_drop_incompatible_coords(
356356
data_warmup,
357357
library=pymc,
358358
dims=None,
@@ -366,7 +366,7 @@ def posterior_predictive_to_xarray(self):
366366
"""Convert posterior_predictive samples to xarray."""
367367
data = self.posterior_predictive
368368
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
369-
return dict_to_dataset(
369+
return dict_to_dataset_drop_incompatible_coords(
370370
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
371371
)
372372

@@ -375,7 +375,7 @@ def predictions_to_xarray(self):
375375
"""Convert predictions (out of sample predictions) to xarray."""
376376
data = self.predictions
377377
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
378-
return dict_to_dataset(
378+
return dict_to_dataset_drop_incompatible_coords(
379379
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
380380
)
381381

@@ -412,7 +412,7 @@ def observed_data_to_xarray(self):
412412
"""Convert observed data to xarray."""
413413
if self.predictions:
414414
return None
415-
return dict_to_dataset(
415+
return dict_to_dataset_drop_incompatible_coords(
416416
self.observations,
417417
library=pymc,
418418
coords=self.coords,
@@ -427,7 +427,7 @@ def constant_data_to_xarray(self):
427427
if not constant_data:
428428
return None
429429

430-
xarray_dataset = dict_to_dataset(
430+
xarray_dataset = dict_to_dataset_drop_incompatible_coords(
431431
constant_data,
432432
library=pymc,
433433
coords=self.coords,
@@ -705,7 +705,7 @@ def apply_function_over_dataset(
705705
)
706706
)
707707

708-
return dict_to_dataset(
708+
return dict_to_dataset_drop_incompatible_coords(
709709
out_trace,
710710
library=pymc,
711711
dims=dims,

pymc/sampling/mcmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import numpy as np
3434
import pytensor.gradient as tg
3535

36-
from arviz import InferenceData, dict_to_dataset
36+
from arviz import InferenceData
3737
from arviz.data.base import make_attrs
3838
from pytensor.graph.basic import Variable
3939
from rich.theme import Theme
@@ -45,6 +45,7 @@
4545
from pymc.backends import RunType, TraceOrBackend, init_traces
4646
from pymc.backends.arviz import (
4747
coords_and_dims_for_inferencedata,
48+
dict_to_dataset_drop_incompatible_coords,
4849
find_constants,
4950
find_observations,
5051
)
@@ -355,14 +356,14 @@ def _sample_external_nuts(
355356
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
356357
# gather observed and constant data as nutpie.sample() has no access to the PyMC model
357358
coords, dims = coords_and_dims_for_inferencedata(model)
358-
constant_data = dict_to_dataset(
359+
constant_data = dict_to_dataset_drop_incompatible_coords(
359360
find_constants(model),
360361
library=pm,
361362
coords=coords,
362363
dims=dims,
363364
default_dims=[],
364365
)
365-
observed_data = dict_to_dataset(
366+
observed_data = dict_to_dataset_drop_incompatible_coords(
366367
find_observations(model),
367368
library=pm,
368369
coords=coords,

pymc/smc/sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import pymc
3535

36-
from pymc.backends.arviz import dict_to_dataset, to_inference_data
36+
from pymc.backends.arviz import dict_to_dataset_drop_incompatible_coords, to_inference_data
3737
from pymc.backends.base import MultiTrace
3838
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
3939
from pymc.distributions.distribution import _support_point
@@ -264,7 +264,7 @@ def _save_sample_stats(
264264
else:
265265
sample_stats_dict[stat] = np.array(value)
266266

267-
sample_stats = dict_to_dataset(
267+
sample_stats = dict_to_dataset_drop_incompatible_coords(
268268
sample_stats_dict,
269269
attrs=sample_settings_dict,
270270
library=pymc,

0 commit comments

Comments
 (0)