Skip to content

BUG: Setting None coord is problematic #6485

Open
@ferrine

Description

@ferrine

Describe the issue:

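When one of the `dims` of a `pm.Data` variable is set to `None`, `pm.sample` fails in pymc 5.0.2 while converting the trace to InferenceData (see the traceback below).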

Reproducible code example:

# This is good
import pymc as pm
import numpy as np
with pm.Model(coords=dict(d1=range(2), d2=range(6))) as model:
    pm.Data("a", np.random.randn(2, 6), dims=("d1", "d2"), mutable=True)
    pm.Normal("b", 10)
    t = pm.sample(1, tune=1)

# This is broken: the first dim is None instead of "d1"
import pymc as pm
import numpy as np
with pm.Model(coords=dict(d1=range(2), d2=range(6))) as model:
    pm.Data("a", np.random.randn(2, 6), dims=(None, "d2"), mutable=True)
    pm.Normal("b", 10)
    t = pm.sample(1, tune=1)

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[14], line 6
      4 pm.Data("a", np.random.randn(2, 6), dims=(None, "d2"), mutable=True)
      5 pm.Normal("b", 10)
----> 6 t = pm.sample(1, tune=1)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/sampling/mcmc.py:612, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, keep_warning_stat, idata_kwargs, mp_ctx, **kwargs)
    610 if idata_kwargs:
    611     ikwargs.update(idata_kwargs)
--> 612 idata = pm.to_inference_data(mtrace, **ikwargs)
    614 if compute_convergence_checks:
    615     if draws - tune < 100:

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:485, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
    482 if isinstance(trace, InferenceData):
    483     return trace
--> 485 return InferenceDataConverter(
    486     trace=trace,
    487     prior=prior,
    488     posterior_predictive=posterior_predictive,
    489     log_likelihood=log_likelihood,
    490     coords=coords,
    491     dims=dims,
    492     sample_dims=sample_dims,
    493     model=model,
    494     save_warmup=save_warmup,
    495     include_transformed=include_transformed,
    496 ).to_inference_data()

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:414, in InferenceDataConverter.to_inference_data(self)
    412     id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
    413 else:
--> 414     id_dict["constant_data"] = self.constant_data_to_xarray()
    415 idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
    416 if self.log_likelihood:

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:65, in requires.__call__.<locals>.wrapped(cls)
     63     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     64         return None
---> 65 return func(cls)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/pymc/backends/arviz.py:388, in InferenceDataConverter.constant_data_to_xarray(self)
    385 if not constant_data:
    386     return None
--> 388 return dict_to_dataset(
    389     constant_data,
    390     library=pymc,
    391     coords=self.coords,
    392     dims=self.dims,
    393     default_dims=[],
    394 )

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:306, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    303 if dims is None:
    304     dims = {}
--> 306 data_vars = {
    307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:307, in <dictcomp>(.0)
    303 if dims is None:
    304     dims = {}
    306 data_vars = {
--> 307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/arviz/data/base.py:255, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    253 # filter coords based on the dims
    254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 255 return xr.DataArray(ary, coords=coords, dims=dims)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/xarray/core/dataarray.py:419, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
    417 data = _check_data_shape(data, coords, dims)
    418 data = as_compatible_data(data)
--> 419 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
    420 variable = Variable(dims, data, attrs, fastpath=True)
    421 indexes, coords = _create_indexes_from_coords(coords)

File ~/micromamba/envs/gp_mmm/lib/python3.9/site-packages/xarray/core/dataarray.py:164, in _infer_coords_and_dims(shape, coords, dims)
    162 for d, s in zip(v.dims, v.shape):
    163     if s != sizes[d]:
--> 164         raise ValueError(
    165             f"conflicting sizes for dimension {d!r}: "
    166             f"length {sizes[d]} on the data but length {s} on "
    167             f"coordinate {k!r}"
    168         )
    170 if k in sizes and v.shape != (sizes[k],):
    171     raise ValueError(
    172         f"coordinate {k!r} is a DataArray dimension, but "
    173         f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} "
    174         "matching the dimension size"
    175     )

ValueError: conflicting sizes for dimension 'd2': length 2 on the data but length 6 on coordinate 'd2'

PyMC version information:

pymc 5.0.2

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions