Skip to content

Commit a21c3a9

Browse files
Drop object variables/coords before saving to_netcdf or to_zarr
1 parent e1903c0 commit a21c3a9

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126))
1717
* Fix bug with the dimension order dependency ([2103](https://github.com/arviz-devs/arviz/pull/2103))
1818
* Add testing module for labeller classes ([2095](https://github.com/arviz-devs/arviz/pull/2095))
19+
* Variables or coordinates with an `object` dtype are now dropped automatically when calling `InferenceData.to_netcdf()` or `InferenceData.to_zarr()`. Every drop is reported through a log message. ([2134](https://github.com/arviz-devs/arviz/pull/2134))
1920

2021
### Deprecation
2122
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))

arviz/data/inference_data.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=too-many-lines,too-many-public-methods
22
"""Data structure for using netcdf groups with xarray."""
3+
import logging
34
import sys
45
import uuid
56
import warnings
@@ -29,6 +30,7 @@
2930
import xarray as xr
3031
from packaging import version
3132

33+
3234
from ..rcparams import rcParams
3335
from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
3436
from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
@@ -50,6 +52,8 @@
5052
# https://github.com/python/mypy/issues/1153
5153
import json # type: ignore
5254

55+
_log = logging.getLogger("arviz")
56+
5357

5458
SUPPORTED_GROUPS = [
5559
"posterior",
@@ -411,15 +415,21 @@ def to_netcdf(
411415
str
412416
Location of netcdf file
413417
"""
418+
# We don't support saving of object-typed variables or coords.
419+
# The below helper function removes them while logging it.
420+
idata = drop_objects_from_inferencedata(self)
421+
414422
mode = "w" # overwrite first, then append
415-
if self._groups_all: # check's whether a group is present or not.
423+
# checks whether a group is present or not.
424+
groups_all = idata._groups_all # pylint: disable=W0212
425+
if groups_all:
416426
if groups is None:
417-
groups = self._groups_all
427+
groups = groups_all
418428
else:
419-
groups = [group for group in self._groups_all if group in groups]
429+
groups = [group for group in groups_all if group in groups]
420430

421431
for group in groups:
422-
data = getattr(self, group)
432+
data = getattr(idata, group)
423433
kwargs = {}
424434
if compress:
425435
kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
@@ -664,6 +674,10 @@ def to_zarr(self, store=None):
664674
except (ImportError, AssertionError) as err:
665675
raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
666676

677+
# We don't support saving of object-typed variables or coords.
678+
# The below helper function removes them while logging it.
679+
idata = drop_objects_from_inferencedata(self)
680+
667681
# Check store type and create store if necessary
668682
if store is None:
669683
store = zarr.storage.TempStore(suffix="arviz")
@@ -672,14 +686,14 @@ def to_zarr(self, store=None):
672686
elif not isinstance(store, MutableMapping):
673687
raise TypeError(f"No valid store found: {store}")
674688

675-
groups = self.groups()
689+
groups = idata.groups()
676690

677691
if not groups:
678692
raise TypeError("No valid groups found!")
679693

680694
for group in groups:
681695
# Create zarr group in store with same group name
682-
getattr(self, group).to_zarr(store=store, group=group, mode="w")
696+
getattr(idata, group).to_zarr(store=store, group=group, mode="w")
683697

684698
return zarr.open(store) # Open store to get overarching group
685699

@@ -2210,3 +2224,50 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
22102224
inference_data_dict["attrs"] = combined_attr
22112225

22122226
return None if inplace else InferenceData(**inference_data_dict)
2227+
2228+
2229+
def drop_objects_from_dataset(data: xr.Dataset) -> Tuple[xr.Dataset, List[str], List[str]]:
    """Return a new ``Dataset`` without object-typed variables and coords.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset to filter. It is not modified.

    Returns
    -------
    ndata : xarray.Dataset
        New dataset without the data variables and coordinates whose dtype
        contains Python objects.
    vals_dropped : list of str
        Names of the data variables that were removed.
    coords_dropped : list of str
        Names of the coordinates that were removed.
    """
    vals_dropped = [vname for vname, val in data.data_vars.items() if val.dtype.hasobject]
    coords_dropped = [cname for cname, cval in data.coords.items() if cval.dtype.hasobject]

    # Drop by name instead of rebuilding the dataset from its DataArrays: each
    # DataArray carries its own coordinates, so ``xr.Dataset(data_vars=...)`` would
    # merge the object coords of the retained variables right back in.
    # NOTE(review): a pandas MultiIndex coordinate may need all of its levels
    # dropped together in recent xarray versions - confirm if that case matters.
    ndata = data.drop_vars(vals_dropped + coords_dropped)
    return ndata, vals_dropped, coords_dropped
2253+
2254+
2255+
def drop_objects_from_inferencedata(idata: InferenceData) -> InferenceData:
    """Build a copy of ``idata`` whose groups hold no object-typed variables or coords.

    Every group is passed through ``drop_objects_from_dataset`` and the result is
    added to a new ``InferenceData`` created with the attributes of ``idata``.

    Dropped variables and coords are logged at WARNING level, with one exception:
    a ``warning`` variable in a group whose name contains ``sample_stats`` is only
    logged at DEBUG level.
    """
    stripped = InferenceData(attrs=idata.attrs)
    for group_name, dataset in idata.items():
        cleaned, dropped_vars, dropped_coords = drop_objects_from_dataset(dataset)

        # The 'warning' stat gets its own, quieter log message and is taken out
        # of the list that is reported at WARNING level below.
        if "warning" in dropped_vars and "sample_stats" in group_name:
            _log.debug(
                "Dropped 'warning' variable from '%s' group because it's dtyped object.",
                group_name,
            )
            dropped_vars.remove("warning")

        reports = (
            ("Dropped object variables from '%s' group: %s", dropped_vars),
            ("Dropped object coords from '%s' group: %s", dropped_coords),
        )
        for message, names in reports:
            if names:
                _log.warning(message, group_name, names)

        stripped.add_groups({group_name: cleaned}, coords=cleaned.coords, dims=cleaned.dims)
    return stripped

arviz/tests/base_tests/test_data.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import namedtuple
66
from copy import deepcopy
77
from html import escape
8+
import logging
89
from typing import Dict
910
from urllib.parse import urlunsplit
1011

@@ -14,6 +15,8 @@
1415
from xarray.core.options import OPTIONS
1516
from xarray.testing import assert_identical
1617

18+
from ...data.inference_data import drop_objects_from_dataset, drop_objects_from_inferencedata
19+
1720
from ... import (
1821
InferenceData,
1922
clear_data_home,
@@ -1466,3 +1469,85 @@ def test_subset_samples(self):
14661469
post = extract(idata, num_samples=10)
14671470
assert post.dims["sample"] == 10
14681471
assert post.attrs == idata.posterior.attrs
1472+
1473+
1474+
class TestDropObjectVariables:
    """Checks for the helpers that strip object-typed variables and coords."""

    def test_drop_variable_from_dataset(self):
        """An object-typed data variable is removed and reported; coords are kept."""
        index = {"t": [0, 1, 2]}
        numeric = xr.DataArray(np.ones((3,)), coords=index, dims=("t",))
        objects = xr.DataArray(np.array([1, None, "test"]), coords=index, dims=("t",))
        ds = xr.Dataset(data_vars={"a": numeric, "b": objects})

        new, dropped_vars, dropped_coords = drop_objects_from_dataset(ds)

        assert isinstance(new, xr.Dataset)
        assert new is not ds
        assert "a" in new
        assert "b" not in new
        assert dropped_vars == ["b"]
        assert not dropped_coords

    def test_drop_coord_from_dataset(self):
        """An object-typed coord is reported while both data variables survive."""
        values = np.ones((3,))
        with_str_coord = xr.DataArray(values, coords={"adim": ["A", "B", "C"]}, dims=("adim",))
        with_obj_coord = xr.DataArray(values, coords={"bdim": [0, None, 2]}, dims=("bdim",))
        ds = xr.Dataset(data_vars={"a": with_str_coord, "b": with_obj_coord})

        new, dropped_vars, dropped_coords = drop_objects_from_dataset(ds)

        assert isinstance(new, xr.Dataset)
        assert new is not ds
        assert "a" in new
        assert "b" in new
        assert not dropped_vars
        assert dropped_coords == ["bdim"]

    @pytest.mark.parametrize("vname", ["warning", "other"])
    def test_drop_objects_from_inferencedata(self, vname, caplog):
        """Object variables/coords are removed from the group and the drops are logged."""
        sample_stats = {
            "a": np.ones((2, 5, 4)),
            vname: np.ones((2, 5, 3), dtype=object),
        }
        idata = from_dict(
            sample_stats=sample_stats,
            attrs={"version": "0.1.2"},
            coords={"adim": [0, 1, None, 3], "vdim": list("ABC")},
            dims={"a": ["adim"], vname: ["vdim"]},
        )

        # Record everything the "arviz" logger emits while dropping
        with caplog.at_level(logging.DEBUG, logger="arviz"):
            new = drop_objects_from_inferencedata(idata)

        assert new is not idata
        assert new.attrs.get("version") == "0.1.2"

        ss = new.get("sample_stats", None)
        assert isinstance(ss, xr.Dataset)
        assert "a" in ss
        assert vname not in ss
        assert caplog.records

        # First record: the dropped variable.
        # The 'warning' stat, which in PyMC is an object and very commonly present,
        # is reported at DEBUG level; anything else at WARNING level.
        if vname == "warning":
            expected_level = logging.DEBUG
            expected_text = "Dropped 'warning' variable from 'sample_stats'"
        else:
            expected_level = logging.WARNING
            expected_text = f"from 'sample_stats' group: ['{vname}']"
        var_record = caplog.records[0]
        assert var_record.levelno == expected_level
        assert expected_text in var_record.message

        # Second record: the dropped coord
        coord_record = caplog.records[1]
        assert coord_record.levelno == logging.WARNING
        assert "object coords from 'sample_stats' group: ['adim']" in coord_record.message

0 commit comments

Comments
 (0)