Description
Tell us about it
There are many situations in which it is very convenient to use a pandas.MultiIndex as coordinates of an xarray.DataArray. The problem is that, at the moment, xarray doesn't provide a built-in way to save these indexes in netCDF format. Take for example:
from arviz.tests.helpers import create_model
idata = create_model()
idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
idata.to_netcdf("test.nc")
This raises a NotImplementedError with the following traceback:
NotImplementedError Traceback (most recent call last)
<ipython-input-15-43e455b97609> in <module>
3 idata = create_model()
4 idata.posterior = idata.posterior.stack(sample=["chain", "draw"])
----> 5 idata.to_netcdf("test.nc")
~/anaconda3/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
442 if _compressible_dtype(values.dtype)
443 }
--> 444 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
445 data.close()
446 mode = "a"
~/anaconda3/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
1898 from ..backends.api import to_netcdf
1899
-> 1900 return to_netcdf(
1901 self,
1902 path,
~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
1070 # TODO: allow this work (setting up the file for writing array data)
1071 # to be parallelized with dask
-> 1072 dump_to_store(
1073 dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
1074 )
~/anaconda3/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
1117 variables, attrs = encoder(variables, attrs)
1118
-> 1119 store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
1120
1121
~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
259 writer = ArrayWriter()
260
--> 261 variables, attributes = self.encode(variables, attributes)
262
263 self.set_attributes(attributes)
~/anaconda3/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
348 # All NetCDF files get CF encoded by default, without this attempting
349 # to write times, for example, would fail.
--> 350 variables, attributes = cf_encoder(variables, attributes)
351 variables = {k: self.encode_variable(v) for k, v in variables.items()}
352 attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
857 _update_bounds_encoding(variables)
858
--> 859 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
860
861 # Remove attrs from bounds variables (issue #2921)
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
857 _update_bounds_encoding(variables)
858
--> 859 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
860
861 # Remove attrs from bounds variables (issue #2921)
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
262 A variable which has been encoded as described above.
263 """
--> 264 ensure_not_multiindex(var, name=name)
265
266 for coder in [
~/anaconda3/lib/python3.9/site-packages/xarray/conventions.py in ensure_not_multiindex(var, name)
177 def ensure_not_multiindex(var, name=None):
178 if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
--> 179 raise NotImplementedError(
180 "variable {!r} is a MultiIndex, which cannot yet be "
181 "serialized to netCDF files "
NotImplementedError: variable 'sample' is a MultiIndex, which cannot yet be serialized to netCDF files (https://github.com/pydata/xarray/issues/1077). Use reset_index() to convert MultiIndex levels into coordinate variables instead.
Thoughts on implementation
I had a look at the xarray issue mentioned in the error message, and the approach suggested by @dcherian works (at least in the scenario that I had to work with a month ago). I think it would be good to incorporate something like that into arviz.from_netcdf and InferenceData.to_netcdf. The basic idea is to replace the MultiIndex with a plain array of integers (the codes of the MultiIndex) and to add an attribute stating that the dimension/coordinate was originally a MultiIndex. This attribute is also used to keep track of the level values and names of the original MultiIndex. The modified data structure can then be serialized to netCDF without any problems. The only thing to be aware of is that, when the netCDF file is loaded, some work has to happen to rebuild the MultiIndex from the stored coordinates. I think this small overhead is worth the benefit of bringing MultiIndex support to arviz.
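To make the idea concrete, here is a minimal sketch of the round trip using only pandas and numpy. It is not the code from the xarray issue and not an existing arviz API: the helper names `encode_multiindex` and `decode_multiindex` and the metadata keys are hypothetical. It also assumes the MultiIndex has no missing entries (no `-1` codes) and leaves open how the metadata would actually be laid out in the netCDF file.

```python
import numpy as np
import pandas as pd


def encode_multiindex(index):
    """Flatten a MultiIndex into integer codes plus plain metadata.

    Hypothetical helper; assumes no missing entries (codes of -1).
    """
    shape = [len(level) for level in index.levels]
    codes = tuple(np.asarray(c, dtype=np.intp) for c in index.codes)
    # One integer per entry: its position in the full grid of level values.
    flat_codes = np.ravel_multi_index(codes, shape)
    metadata = {
        "is_multiindex": True,  # marks the coordinate as a former MultiIndex
        "level_names": list(index.names),
        "level_values": [level.to_numpy().tolist() for level in index.levels],
    }
    return flat_codes, metadata


def decode_multiindex(flat_codes, metadata):
    """Rebuild the MultiIndex from the flat codes and the stored metadata."""
    shape = [len(values) for values in metadata["level_values"]]
    codes = np.unravel_index(np.asarray(flat_codes), shape)
    return pd.MultiIndex(
        levels=metadata["level_values"],
        codes=list(codes),
        names=metadata["level_names"],
    )


# Round trip on an index shaped like stack(sample=["chain", "draw"]).
index = pd.MultiIndex.from_product([[0, 1], [0, 1, 2]], names=["chain", "draw"])
flat_codes, metadata = encode_multiindex(index)
restored = decode_multiindex(flat_codes, metadata)
```

In an actual implementation the flat codes would replace the `sample` coordinate before writing, and the decoding step would run after the file is read back.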
If you all agree that this would be valuable, I can write a PR.