1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@
* Fix bug with the dimension order dependency ([2103](https://github.com/arviz-devs/arviz/pull/2103))
* Add testing module for labeller classes ([2095](https://github.com/arviz-devs/arviz/pull/2095))
* Skip compression for object dtype while creating a netcdf file ([2129](https://github.com/arviz-devs/arviz/pull/2129))
* Variables or coordinates with an `object` dtype are now dropped automatically when calling `InferenceData.to_netcdf()` or `InferenceData.to_zarr()`, and log messages report what was dropped ([2134](https://github.com/arviz-devs/arviz/pull/2134))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))
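A minimal sketch of the behavior described in the changelog entry above, using only the public `arviz.from_dict` and `InferenceData.to_netcdf` APIs; the variable name `notes` and the output filename are purely illustrative:

```python
import numpy as np
from arviz import from_dict

# Posterior with one numeric variable and one object-typed variable
# (the name "notes" is made up for illustration).
idata = from_dict(
    posterior={
        "mu": np.random.randn(2, 50),
        "notes": np.full((2, 50), None, dtype=object),
    }
)

# With this change, the object-typed "notes" variable is dropped instead of
# breaking the netCDF writer, and a WARNING-level log message reports the drop.
idata.to_netcdf("example.nc")
```

With a default logging setup this prints something like `Dropped object variables from 'posterior' group: ['notes']` to stderr, matching the format string added in `inference_data.py` below.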
1,429 changes: 954 additions & 475 deletions arviz/data/example_data/code/centered_eight/centered_eight.ipynb

Large diffs are not rendered by default.

1,556 changes: 1,017 additions & 539 deletions arviz/data/example_data/code/non_centered_eight/non_centered_eight.ipynb

Large diffs are not rendered by default.

Binary file modified arviz/data/example_data/data/centered_eight.nc
Binary file not shown.
Binary file modified arviz/data/example_data/data/non_centered_eight.nc
Binary file not shown.
73 changes: 67 additions & 6 deletions arviz/data/inference_data.py
@@ -1,5 +1,6 @@
# pylint: disable=too-many-lines,too-many-public-methods
"""Data structure for using netcdf groups with xarray."""
import logging
import sys
import uuid
import warnings
@@ -29,6 +30,7 @@
import xarray as xr
from packaging import version


from ..rcparams import rcParams
from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
@@ -50,6 +52,8 @@
# https://github.com/python/mypy/issues/1153
import json # type: ignore

_log = logging.getLogger("arviz")


SUPPORTED_GROUPS = [
"posterior",
@@ -418,15 +422,21 @@ def to_netcdf(
str
Location of netcdf file
"""
# We don't support saving of object-typed variables or coords.
# The helper function below removes them and logs what was dropped.
idata = drop_objects_from_inferencedata(self)

mode = "w" # overwrite first, then append
if self._groups_all: # check's whether a group is present or not.
# checks whether a group is present or not.
groups_all = idata._groups_all # pylint: disable=W0212
if groups_all:
if groups is None:
groups = self._groups_all
groups = groups_all
else:
groups = [group for group in self._groups_all if group in groups]
groups = [group for group in groups_all if group in groups]

for group in groups:
data = getattr(self, group)
data = getattr(idata, group)
kwargs = {}
if compress:
kwargs["encoding"] = {
@@ -675,6 +685,10 @@ def to_zarr(self, store=None):
except (ImportError, AssertionError) as err:
raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err

# We don't support saving of object-typed variables or coords.
# The helper function below removes them and logs what was dropped.
idata = drop_objects_from_inferencedata(self)

# Check store type and create store if necessary
if store is None:
store = zarr.storage.TempStore(suffix="arviz")
@@ -683,14 +697,14 @@
elif not isinstance(store, MutableMapping):
raise TypeError(f"No valid store found: {store}")

groups = self.groups()
groups = idata.groups()

if not groups:
raise TypeError("No valid groups found!")

for group in groups:
# Create zarr group in store with same group name
getattr(self, group).to_zarr(store=store, group=group, mode="w")
getattr(idata, group).to_zarr(store=store, group=group, mode="w")

return zarr.open(store) # Open store to get overarching group

@@ -2221,3 +2235,50 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
inference_data_dict["attrs"] = combined_attr

return None if inplace else InferenceData(**inference_data_dict)


def drop_objects_from_dataset(data: xr.Dataset) -> Tuple[xr.Dataset, List[str], List[str]]:
"""Returns a new ``Dataset`` without object variables and coords."""
vals = {}
vals_dropped = []
for vname, val in data.data_vars.items():
if val.dtype.hasobject:
vals_dropped.append(vname)
continue
vals[vname] = val

coords = {}
coords_dropped = []
for cname, cval in data.coords.items():
if cval.dtype.hasobject:
coords_dropped.append(cname)
continue
coords[cname] = cval

ndata = xr.Dataset(
data_vars=vals,
coords=coords,
attrs=data.attrs,
)
return ndata, vals_dropped, coords_dropped


def drop_objects_from_inferencedata(idata: InferenceData) -> InferenceData:
"""Returns a new InferenceData without object variables and coords.

All droppings are logged at WARNING level.
"""
nidata = InferenceData(attrs=idata.attrs)
for gname, group in idata.items():
ndata, vars_dropped, coords_dropped = drop_objects_from_dataset(group)
if "warning" in vars_dropped and "sample_stats" in gname:
vars_dropped.remove("warning")
_log.debug(
"Dropped 'warning' variable from '%s' group because it's dtyped object.", gname
)
if vars_dropped:
_log.warning("Dropped object variables from '%s' group: %s", gname, vars_dropped)
if coords_dropped:
_log.warning("Dropped object coords from '%s' group: %s", gname, coords_dropped)
nidata.add_groups({gname: ndata}, coords=ndata.coords, dims=ndata.dims)
return nidata
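For reference, a small usage sketch of the new helper on a plain ``xarray.Dataset``, assuming it stays importable from ``arviz.data.inference_data`` as in the test imports below; the data values are illustrative:

```python
import numpy as np
import xarray as xr

from arviz.data.inference_data import drop_objects_from_dataset

ds = xr.Dataset(
    data_vars={
        "a": ("t", np.ones(3)),                                  # numeric, kept
        "b": ("t", np.array([1, None, "text"], dtype=object)),   # object dtype, dropped
    },
    coords={"t": [0, 1, 2]},
)

clean, dropped_vars, dropped_coords = drop_objects_from_dataset(ds)
print(list(clean.data_vars))   # ['a']
print(dropped_vars)            # ['b']
print(dropped_coords)          # []
```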
93 changes: 93 additions & 0 deletions arviz/tests/base_tests/test_data.py
@@ -5,6 +5,7 @@
from collections import namedtuple
from copy import deepcopy
from html import escape
import logging
from typing import Dict
from urllib.parse import urlunsplit

@@ -14,6 +15,8 @@
from xarray.core.options import OPTIONS
from xarray.testing import assert_identical

from ...data.inference_data import drop_objects_from_dataset, drop_objects_from_inferencedata

from ... import (
InferenceData,
clear_data_home,
@@ -1469,3 +1472,93 @@ def test_subset_samples(self):
post = extract(idata, num_samples=10)
assert post.dims["sample"] == 10
assert post.attrs == idata.posterior.attrs


class TestDropObjectVariables:
"""Tests for helper functions that drop object-typed variables/coords."""

def test_drop_variable_from_dataset(self):
a = np.ones((3,))
b = np.array([1, None, "test"])
ds = xr.Dataset(
data_vars={
"a": xr.DataArray(a, coords=dict(t=[0, 1, 2]), dims=("t",)),
"b": xr.DataArray(b, coords=dict(t=[0, 1, 2]), dims=("t",)),
}
)
new, dropped_vars, dropped_coords = drop_objects_from_dataset(ds)
assert isinstance(new, xr.Dataset)
assert new is not ds
assert "a" in new
assert "b" not in new
assert dropped_vars == ["b"]
assert not dropped_coords

def test_drop_coord_from_dataset(self):
a = np.ones((3,))
ds = xr.Dataset(
data_vars={
"a": xr.DataArray(a, coords=dict(adim=["A", "B", "C"]), dims=("adim",)),
"b": xr.DataArray(a, coords=dict(bdim=[0, None, 2]), dims=("bdim",)),
}
)
new, dropped_vars, dropped_coords = drop_objects_from_dataset(ds)
assert isinstance(new, xr.Dataset)
assert new is not ds
assert "a" in new
assert "b" in new
assert not dropped_vars
assert dropped_coords == ["bdim"]

@pytest.mark.parametrize("vname", ["warning", "other"])
def test_drop_objects_from_inferencedata(self, vname, caplog):
idata = from_dict(
sample_stats={
"a": np.ones((2, 5, 4)),
vname: np.ones((2, 5, 3), dtype=object),
},
attrs=dict(version="0.1.2"),
coords={
"adim": [0, 1, None, 3],
"vdim": list("ABC"),
},
dims={"a": ["adim"], vname: ["vdim"]},
)

# Capture the log messages emitted for the dropped variables and coords
with caplog.at_level(logging.DEBUG, logger="arviz"):
new = drop_objects_from_inferencedata(idata)

assert new is not idata
assert new.attrs.get("version") == "0.1.2"

ss = new.get("sample_stats", None)
assert isinstance(ss, xr.Dataset)
assert "a" in ss
assert vname not in ss
assert caplog.records

# Check the logging about the dropped variable
rec = caplog.records[0]
if vname == "warning":
# DEBUG level for the 'warning' stat, which in PyMC
# is object-typed and very commonly present.
assert rec.levelno == logging.DEBUG
assert "Dropped 'warning' variable from 'sample_stats'" in rec.message
else:
# WARNING level otherwise
assert rec.levelno == logging.WARNING
assert f"from 'sample_stats' group: ['{vname}']" in rec.message

# And the log message about the dropped coord
rec = caplog.records[1]
assert rec.levelno == logging.WARNING
assert "object coords from 'sample_stats' group: ['adim']" in rec.message

@pytest.mark.parametrize("name", LOCAL_DATASETS.keys())
def test_no_objects_in_example_datasets(self, name):
idata = load_arviz_data(name)
for group in idata.values():
_, dropped_vars, dropped_coords = drop_objects_from_dataset(group)
assert not dropped_vars
assert not dropped_coords