Skip to content

Commit eb777f8

Browse files
Drop object variables/coords before saving to_netcdf or to_zarr
1 parent e1903c0 commit eb777f8

File tree

3 files changed

+151
-6
lines changed

3 files changed

+151
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126))
1717
* Fix bug with the dimension order dependency ([2103](https://github.com/arviz-devs/arviz/pull/2103))
1818
* Add testing module for labeller classes ([2095](https://github.com/arviz-devs/arviz/pull/2095))
19+
* Variables or coordinates with an `object` dtype are now dropped automatically when calling `InferenceData.to_netcdf()` or `InferenceData.to_zarr()`. A log message is emitted for every group from which variables or coordinates were dropped. ([2134](https://github.com/arviz-devs/arviz/pull/2134))
1920

2021
### Deprecation
2122
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))

arviz/data/inference_data.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=too-many-lines,too-many-public-methods
22
"""Data structure for using netcdf groups with xarray."""
3+
import logging
34
import sys
45
import uuid
56
import warnings
@@ -29,6 +30,7 @@
2930
import xarray as xr
3031
from packaging import version
3132

33+
3234
from ..rcparams import rcParams
3335
from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
3436
from .base import _extend_xr_method, _make_json_serializable, dict_to_dataset
@@ -50,6 +52,8 @@
5052
# https://github.com/python/mypy/issues/1153
5153
import json # type: ignore
5254

55+
_log = logging.getLogger("arviz")
56+
5357

5458
SUPPORTED_GROUPS = [
5559
"posterior",
@@ -411,15 +415,19 @@ def to_netcdf(
411415
str
412416
Location of netcdf file
413417
"""
418+
# We don't support saving of object-typed variables or coords.
419+
# The below helper function removes them while logging it.
420+
idata = drop_objects_from_inferencedata(self)
421+
414422
mode = "w" # overwrite first, then append
415-
if self._groups_all: # check's whether a group is present or not.
423+
if idata._groups_all: # check's whether a group is present or not.
416424
if groups is None:
417-
groups = self._groups_all
425+
groups = idata._groups_all
418426
else:
419-
groups = [group for group in self._groups_all if group in groups]
427+
groups = [group for group in idata._groups_all if group in groups]
420428

421429
for group in groups:
422-
data = getattr(self, group)
430+
data = getattr(idata, group)
423431
kwargs = {}
424432
if compress:
425433
kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
@@ -664,6 +672,10 @@ def to_zarr(self, store=None):
664672
except (ImportError, AssertionError) as err:
665673
raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
666674

675+
# We don't support saving of object-typed variables or coords.
676+
# The below helper function removes them while logging it.
677+
idata = drop_objects_from_inferencedata(self)
678+
667679
# Check store type and create store if necessary
668680
if store is None:
669681
store = zarr.storage.TempStore(suffix="arviz")
@@ -672,14 +684,14 @@ def to_zarr(self, store=None):
672684
elif not isinstance(store, MutableMapping):
673685
raise TypeError(f"No valid store found: {store}")
674686

675-
groups = self.groups()
687+
groups = idata.groups()
676688

677689
if not groups:
678690
raise TypeError("No valid groups found!")
679691

680692
for group in groups:
681693
# Create zarr group in store with same group name
682-
getattr(self, group).to_zarr(store=store, group=group, mode="w")
694+
getattr(idata, group).to_zarr(store=store, group=group, mode="w")
683695

684696
return zarr.open(store) # Open store to get overarching group
685697

@@ -2210,3 +2222,50 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
22102222
inference_data_dict["attrs"] = combined_attr
22112223

22122224
return None if inplace else InferenceData(**inference_data_dict)
2225+
2226+
2227+
def drop_objects_from_dataset(data: xr.Dataset) -> Tuple[xr.Dataset, List[str], List[str]]:
    """Return a new ``Dataset`` without object-typed variables and coords.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset to filter. It is left untouched.

    Returns
    -------
    ndata : xarray.Dataset
        New dataset in which every data variable and coordinate whose dtype
        contains Python objects has been removed.
    vars_dropped : list of str
        Names of the data variables that were removed.
    coords_dropped : list of str
        Names of the coordinates that were removed.
    """
    vars_dropped = [name for name, var in data.data_vars.items() if var.dtype.hasobject]
    coords_dropped = [name for name, coord in data.coords.items() if coord.dtype.hasobject]

    # Use ``drop_vars`` instead of rebuilding the dataset from its DataArrays:
    # each DataArray taken from ``data.data_vars`` carries its own coordinates,
    # so a rebuilt Dataset would merge the object coords back in whenever a
    # kept variable uses them. ``drop_vars`` always returns a new object, even
    # when there is nothing to drop.
    ndata = data.drop_vars(vars_dropped + coords_dropped)
    return ndata, vars_dropped, coords_dropped
2251+
2252+
2253+
def drop_objects_from_inferencedata(idata: InferenceData) -> InferenceData:
    """Return a new InferenceData without object variables and coords.

    Each group of ``idata`` is filtered with ``drop_objects_from_dataset`` and
    added to a fresh ``InferenceData`` that reuses ``idata.attrs``.

    Drops are reported through the module logger at WARNING level, with one
    exception: a ``warning`` variable in a group whose name contains
    ``sample_stats`` is only reported at DEBUG level.
    """
    result = InferenceData(attrs=idata.attrs)
    for group_name, dataset in idata.items():
        cleaned, dropped_vars, dropped_coords = drop_objects_from_dataset(dataset)

        # The 'warning' stat gets its own, quieter message and is excluded
        # from the WARNING-level report below.
        if "sample_stats" in group_name and "warning" in dropped_vars:
            dropped_vars = [name for name in dropped_vars if name != "warning"]
            _log.debug(
                "Dropped 'warning' variable from '%s' group because it's dtyped object.",
                group_name,
            )

        if len(dropped_vars) > 0:
            _log.warning("Dropped object variables from '%s' group: %s", group_name, dropped_vars)
        if len(dropped_coords) > 0:
            _log.warning("Dropped object coords from '%s' group: %s", group_name, dropped_coords)

        result.add_groups({group_name: cleaned}, coords=cleaned.coords, dims=cleaned.dims)
    return result

arviz/tests/base_tests/test_data.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import namedtuple
66
from copy import deepcopy
77
from html import escape
8+
import logging
89
from typing import Dict
910
from urllib.parse import urlunsplit
1011

@@ -14,6 +15,8 @@
1415
from xarray.core.options import OPTIONS
1516
from xarray.testing import assert_identical
1617

18+
from arviz.data.inference_data import drop_objects_from_dataset, drop_objects_from_inferencedata
19+
1720
from ... import (
1821
InferenceData,
1922
clear_data_home,
@@ -1466,3 +1469,85 @@ def test_subset_samples(self):
14661469
post = extract(idata, num_samples=10)
14671470
assert post.dims["sample"] == 10
14681471
assert post.attrs == idata.posterior.attrs
1472+
1473+
1474+
class TestDropObjectVariables:
    """Exercise the helpers that strip object-typed variables and coords."""

    def test_drop_variable_from_dataset(self):
        """An object-typed data variable is removed and reported."""
        numeric = xr.DataArray(np.ones((3,)), coords={"t": [0, 1, 2]}, dims=("t",))
        mixed = xr.DataArray(np.array([1, None, "test"]), coords={"t": [0, 1, 2]}, dims=("t",))
        ds = xr.Dataset(data_vars={"a": numeric, "b": mixed})

        result = drop_objects_from_dataset(ds)
        new, dropped_vars, dropped_coords = result

        assert isinstance(new, xr.Dataset)
        assert new is not ds
        assert "a" in new
        assert "b" not in new
        assert dropped_vars == ["b"]
        assert not dropped_coords

    def test_drop_coord_from_dataset(self):
        """An object-typed coordinate is reported while its variable is kept."""
        values = np.ones((3,))
        with_int_coord = xr.DataArray(values, coords={"adim": [0, 1, 2]}, dims=("adim",))
        with_obj_coord = xr.DataArray(values, coords={"bdim": [0, None, 2]}, dims=("bdim",))
        ds = xr.Dataset(data_vars={"a": with_int_coord, "b": with_obj_coord})

        result = drop_objects_from_dataset(ds)
        new, dropped_vars, dropped_coords = result

        assert isinstance(new, xr.Dataset)
        assert new is not ds
        assert "a" in new
        assert "b" in new
        assert not dropped_vars
        assert dropped_coords == ["bdim"]

    @pytest.mark.parametrize("vname", ["warning", "other"])
    def test_drop_objects_from_inferencedata(self, vname, caplog):
        """Object variables/coords are dropped from a group and the drop is logged."""
        stats = {
            "a": np.ones((2, 5, 4)),
            vname: np.ones((2, 5, 3), dtype=object),
        }
        idata = from_dict(
            sample_stats=stats,
            attrs={"version": "0.1.2"},
            coords={"adim": [0, 1, None, 3], "vdim": ["A", "B", "C"]},
            dims={"a": ["adim"], vname: ["vdim"]},
        )

        # Record everything the "arviz" logger emits while dropping.
        with caplog.at_level(logging.DEBUG, logger="arviz"):
            new = drop_objects_from_inferencedata(idata)

        assert new is not idata
        assert new.attrs.get("version") == "0.1.2"

        ss = new.get("sample_stats", None)
        assert isinstance(ss, xr.Dataset)
        assert "a" in ss
        assert vname not in ss
        assert caplog.records

        # First record: the dropped variable.
        if vname == "warning":
            # DEBUG level for the 'warning' stat,
            # which in PyMC is an object and very commonly present.
            expected_level = logging.DEBUG
            expected_text = "Dropped 'warning' variable from 'sample_stats'"
        else:
            # WARNING level for any other variable.
            expected_level = logging.WARNING
            expected_text = f"from 'sample_stats' group: ['{vname}']"
        var_record = caplog.records[0]
        assert var_record.levelno == expected_level
        assert expected_text in var_record.message

        # Second record: the dropped coordinate.
        coord_record = caplog.records[1]
        assert coord_record.levelno == logging.WARNING
        assert "object coords from 'sample_stats' group: ['adim']" in coord_record.message

0 commit comments

Comments
 (0)