
Support saving as netcdf InferenceData that has MultiIndex coordinates #2165

@lucianopaz

Description

Tell us about it

There are many situations in which it is very convenient to use a pandas.MultiIndex as coordinates of an xarray.DataArray. The problem is that xarray currently provides no built-in way to save these indexes in netCDF format. Take, for example:

from arviz.tests.helpers import create_model

idata = create_model()
idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
idata.to_netcdf("test.nc")

This raises a NotImplementedError with the following traceback:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-15-43e455b97609> in <module>
      3 idata = create_model()
      4 idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
----> 5 idata.to_netcdf("test.nc")

~/anaconda3/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    442                         if _compressible_dtype(values.dtype)
    443                     }
--> 444                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    445                 data.close()
    446                 mode = "a"

~/anaconda3/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1898         from ..backends.api import to_netcdf
   1899 
-> 1900         return to_netcdf(
   1901             self,
   1902             path,

~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1070         # TODO: allow this work (setting up the file for writing array data)
   1071         # to be parallelized with dask
-> 1072         dump_to_store(
   1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )

~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118 
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120 
   1121 

~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    259             writer = ArrayWriter()
    260 
--> 261         variables, attributes = self.encode(variables, attributes)
    262 
    263         self.set_attributes(attributes)

~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    348         # All NetCDF files get CF encoded by default, without this attempting
    349         # to write times, for example, would fail.
--> 350         variables, attributes = cf_encoder(variables, attributes)
    351         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    352         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    857     _update_bounds_encoding(variables)
    858 
--> 859     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    860 
    861     # Remove attrs from bounds variables (issue #2921)

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
    857     _update_bounds_encoding(variables)
    858 
--> 859     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    860 
    861     # Remove attrs from bounds variables (issue #2921)

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    262         A variable which has been encoded as described above.
    263     """
--> 264     ensure_not_multiindex(var, name=name)
    265 
    266     for coder in [

~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in ensure_not_multiindex(var, name)
    177 def ensure_not_multiindex(var, name=None):
    178     if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
--> 179         raise NotImplementedError(
    180             "variable {!r} is a MultiIndex, which cannot yet be "
    181             "serialized to netCDF files "

NotImplementedError: variable 'sample' is a MultiIndex, which cannot yet be serialized to netCDF files (https://github.com/pydata/xarray/issues/1077). Use reset_index() to convert MultiIndex levels into coordinate variables instead.
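As the error message suggests, `reset_index()` is the immediate workaround: it demotes the MultiIndex levels back to ordinary coordinate variables, which serialize without problems (at the cost of losing the stacked index on reload). A minimal sketch with plain xarray, mirroring the `idata.posterior` case above:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Build a small stacked Dataset, analogous to idata.posterior after .stack()
ds = xr.Dataset(
    {"mu": (("chain", "draw"), np.random.randn(2, 3))},
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)
stacked = ds.stack(sample=["chain", "draw"])

# Demote the MultiIndex levels to plain coordinate variables; after this,
# to_netcdf() no longer raises NotImplementedError.
flat = stacked.reset_index("sample")
assert not any(isinstance(ix, pd.MultiIndex) for ix in flat.indexes.values())
```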

Thoughts on implementation

I had a look at the xarray issue mentioned in the error message, and the approach suggested by @dcherian works (at least in the scenario I had to deal with a month ago). I think it would be good to incorporate something like it into arviz.from_netcdf and InferenceData.to_netcdf. The basic idea is to replace the MultiIndex with a plain array of integers (the codes of the MultiIndex) and to add an attribute recording that the dimension/coordinates were originally a MultiIndex; this attribute also keeps track of the level values and names of the original MultiIndex. The modified data structure can be serialized to netCDF without any problems. The only thing to be aware of is that when the netCDF file is loaded, some work has to happen to rebuild the MultiIndex from the stored coordinates. I think this small overhead is worth the benefit of bringing MultiIndex support to arviz.
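A rough sketch of what such an encode/decode pair could look like (the names `encode_multiindex`/`decode_multiindex` are placeholders, not existing arviz or xarray API; the `compress` attribute name follows the trick discussed in the xarray issue, and the fallback branch assumes an older xarray without `xr.Coordinates.from_pandas_multiindex`):

```python
import numpy as np
import pandas as pd
import xarray as xr

def encode_multiindex(ds, name):
    """Flatten MultiIndex coord `name` into integer codes plus level tables."""
    midx = ds.indexes[name]
    encoded = ds.reset_index(name).drop_vars(midx.names, errors="ignore")
    # Store each level's unique values under its own (new) dimension.
    for level_name, level in zip(midx.names, midx.levels):
        encoded[level_name] = (f"{level_name}_", np.asarray(level))
    # Collapse the per-element level codes into a single integer per sample.
    shape = [len(level) for level in midx.levels]
    encoded[name] = (name, np.ravel_multi_index(midx.codes, shape))
    # Record which variables hold the levels, so decoding knows what to rebuild.
    encoded[name].attrs["compress"] = " ".join(midx.names)
    return encoded

def decode_multiindex(ds, name):
    """Rebuild the MultiIndex from the stored codes and level tables."""
    names = ds[name].attrs["compress"].split(" ")
    shape = [ds[n].size for n in names]
    codes = np.unravel_index(ds[name].values, shape)
    arrays = [ds[n].values[c] for n, c in zip(names, codes)]
    midx = pd.MultiIndex.from_arrays(arrays, names=names)
    decoded = ds.drop_vars(names + [name])
    try:
        return decoded.assign_coords(
            xr.Coordinates.from_pandas_multiindex(midx, name)
        )
    except AttributeError:  # older xarray without Coordinates.from_pandas_multiindex
        decoded.coords[name] = midx
        return decoded
```

With helpers like these, `InferenceData.to_netcdf` could run the encoder on any group with MultiIndex coordinates before writing, and `from_netcdf` could look for the `compress` attribute and run the decoder on load.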

If you all agree that this would be valuable, I can write a PR.
