Description
What is your issue?
Repost of xarray-contrib/datatree#277, with some updates.
Test case
Write a tree containing 13 nodes and negligible data to S3/GCS with fsspec:
import numpy as np
import xarray as xr
from xarray.core.datatree import DataTree

ds = xr.Dataset(
    data_vars={
        "a": xr.DataArray(np.ones((2, 2)), coords={"x": [1, 2], "y": [1, 2]}),
        "b": xr.DataArray(np.ones((2, 2)), coords={"x": [1, 2], "y": [1, 2]}),
        "c": xr.DataArray(np.ones((2, 2)), coords={"x": [1, 2], "y": [1, 2]}),
    }
)

# Build a tree with 13 nodes: the root, 3 children, and 9 grandchildren.
dt = DataTree()
for first_level in [1, 2, 3]:
    dt[f"{first_level}"] = DataTree(ds)
    for second_level in [1, 2, 3]:
        dt[f"{first_level}/{second_level}"] = DataTree(ds)

# Time a local write, then the same write to object storage.
%time dt.to_zarr("test.zarr", mode="w")

bucket = "s3|gs://your-bucket/path"  # placeholder: use an s3:// or gs:// URL
%time dt.to_zarr(f"{bucket}/test.zarr", mode="w")
Gives:
CPU times: user 287 ms, sys: 43.9 ms, total: 331 ms
Wall time: 331 ms
CPU times: user 3.22 s, sys: 219 ms, total: 3.44 s
Wall time: 1min 4s
This is a bit better than in the original issue due to improvements elsewhere in the stack, but still really slow for heavily nested but otherwise small datasets.
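To put those numbers in perspective, here is a quick back-of-the-envelope calculation using only the timings reported above. The per-node figures are plain averages (they assume the cost is spread evenly across nodes, which is a simplification), and the variable names are just for illustration.

```python
# Timings reported above for the 13-node example tree.
n_nodes = 13
local_wall_s = 0.331   # "Wall time: 331 ms" (local write)
remote_wall_s = 64.0   # "Wall time: 1min 4s" (bucket write)

# How much slower the bucket write is than the local write.
slowdown = remote_wall_s / local_wall_s

# Average wall time per node, assuming an even split across nodes.
remote_per_node_s = remote_wall_s / n_nodes
local_per_node_ms = 1000 * local_wall_s / n_nodes

print(f"slowdown: {slowdown:.0f}x")
print(f"remote: {remote_per_node_s:.1f} s per node")
print(f"local: {local_per_node_ms:.0f} ms per node")
```

Since the data itself is negligible, nearly all of that per-node time on the bucket is overhead rather than transfer.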
Potential Improvements
#9014 did make some decent improvements to read speed. When reading the dataset written above I get:
%timeit xr.backends.api.open_datatree(f"{bucket}/test.zarr", engine="zarr")
%timeit datatree.open_datatree(f"{bucket}/test.zarr", engine="zarr")
882 ms ± 47.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.47 s ± 86.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
We'll need similar optimizations on the write side. The fundamental issue is that DataTree.to_zarr relies on serial Dataset.to_zarr calls for each node:
xarray/xarray/core/datatree_io.py
Lines 153 to 171 in 12c690f
This results in many fsspec calls to list dirs, check file existence, and put small metadata and attribute files in the bucket. Here's snakeviz on the example:
(The 8s block on the right is metadata consolidation)
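To illustrate why serial per-node writes hurt so much on a high-latency store, here is a toy model that uses only the standard library. It is not xarray's implementation and says nothing about whether Dataset.to_zarr can safely be called concurrently: `write_node` is a hypothetical stand-in that only sleeps, and the latency and requests-per-node values are made-up assumptions.

```python
import time
from concurrent.futures import ThreadPoolExecutor

LATENCY_S = 0.01        # assumed round trip per request (real buckets are slower)
REQUESTS_PER_NODE = 5   # assumed: listings, existence checks, small metadata puts

# Group paths of the 13-node example tree: root, 3 children, 9 grandchildren.
NODE_PATHS = ["/"] + [f"/{i}" for i in (1, 2, 3)] + [
    f"/{i}/{j}" for i in (1, 2, 3) for j in (1, 2, 3)
]


def write_node(path):
    """Hypothetical per-node write: sleeps once per simulated request."""
    for _ in range(REQUESTS_PER_NODE):
        time.sleep(LATENCY_S)
    return path


def write_serial(paths):
    """Write nodes one after another, as the current loop over nodes does."""
    start = time.perf_counter()
    written = [write_node(p) for p in paths]
    return written, time.perf_counter() - start


def write_concurrent(paths, max_workers=13):
    """Write nodes from a thread pool so the request latencies overlap."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        written = list(pool.map(write_node, paths))
    return written, time.perf_counter() - start


serial_nodes, serial_s = write_serial(NODE_PATHS)
concurrent_nodes, concurrent_s = write_concurrent(NODE_PATHS)
print(f"serial: {serial_s:.2f}s, concurrent: {concurrent_s:.2f}s")
```

In the serial case the total is roughly nodes × requests × latency, so it grows with tree size even when the data is tiny. Overlapping the requests is one possible direction for a write-side optimization; batching the metadata writes would be another.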
Workaround
If your data is small enough to dump locally, this works great:
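from tempfile import TemporaryDirectory
from fsspec.core import url_to_fs

def to_zarr(dt, path):
    fs, _ = url_to_fs(path)  # assumption: filesystem inferred from the URL protocol
    with TemporaryDirectory() as tmp_path:
        dt.to_zarr(tmp_path)  # fast local write
        fs.put(tmp_path, path, recursive=True)  # one recursive upload to the bucket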
Takes about 1s.