Add initialize_zarr #8460

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
994af64
Add initialize_zarr
dcherian Nov 16, 2023
4437aab
Update xarray/backends/zarr.py
dcherian Nov 16, 2023
e8bf524
Update xarray/backends/zarr.py
dcherian Nov 16, 2023
c0cf4ee
add init_zarr_v2 to demonstrate another approach
jhamman Nov 17, 2023
45b6b27
Merge branch 'main' into init-zarr
dcherian Dec 20, 2023
c273e62
clean up _put_attrs
dcherian Dec 20, 2023
be75bb2
Use add_array_to_store in main code path.
dcherian Dec 20, 2023
a88a878
Rewrite
dcherian Dec 20, 2023
1bc2d84
Some fixes
dcherian Dec 20, 2023
a449220
Switch to using Xarray's Zarr store
dcherian Dec 21, 2023
e4cced7
Some typing
dcherian Dec 21, 2023
9f69e51
Move to api.py
dcherian Dec 21, 2023
82a9747
minor typing
dcherian Dec 21, 2023
eac8e66
Add tests
dcherian Dec 21, 2023
2b78a12
Skip checking indexes for region writes.
dcherian Jan 3, 2024
bffda0a
WIP
dcherian Jan 3, 2024
84fe5f6
Support encoding
dcherian Jan 3, 2024
c4f75e9
Avoid SerializationWarning
dcherian Jan 3, 2024
6dd0d78
Fix mode="w-" test
dcherian Jan 3, 2024
0a78c52
Remove store=None case
dcherian Jan 3, 2024
d197cc3
Fix typing
dcherian Jan 4, 2024
ce3b17d
Explicitly avoid encoding coordinates multiple times
dcherian Jan 4, 2024
1498c35
Merge remote-tracking branch 'upstream/main' into init-zarr
dcherian Jan 4, 2024
8d11876
Comprehensively test with every store.
dcherian Jan 4, 2024
9bec06d
Avoid region change.
dcherian Jan 4, 2024
c01edf1
Handle scalar vars properly
dcherian Jan 4, 2024
56bb1a1
Fix docstring?
dcherian Jan 4, 2024
bb9f72f
Small updates
dcherian Jan 5, 2024
a529f1d
Merge branch 'main' into init-zarr
dcherian Jan 5, 2024
1 change: 1 addition & 0 deletions doc/api.rst
@@ -27,6 +27,7 @@ Top-level functions
combine_nested
where
infer_freq
initialize_zarr
full_like
zeros_like
ones_like
2 changes: 2 additions & 0 deletions xarray/__init__.py
@@ -2,6 +2,7 @@

from xarray import testing, tutorial
from xarray.backends.api import (
initialize_zarr,
load_dataarray,
load_dataset,
open_dataarray,
@@ -75,6 +76,7 @@
"full_like",
"get_options",
"infer_freq",
"initialize_zarr",
Comment (dcherian, PR author): Do we want xarray.zarr.initialize_zarr or xarray.backends.initialize_zarr instead?

Comment (Collaborator): Is there a case for Dataset.to_zarr_initialize? It's a bit awkward, but it avoids adding another top-level function, which would be harder for non-expert users to find.

"load_dataarray",
"load_dataset",
"map_blocks",
249 changes: 243 additions & 6 deletions xarray/backends/api.py
@@ -39,6 +39,7 @@
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.parallelcompat import guess_chunkmanager
from xarray.core.pycompat import is_chunked_array
from xarray.core.types import ZarrWriteModes
from xarray.core.utils import is_remote_uri

@@ -1338,7 +1339,14 @@ def to_netcdf(


def dump_to_store(
dataset, store, writer=None, encoder=None, encoding=None, unlimited_dims=None
dataset,
store,
writer=None,
encoder=None,
encoding=None,
unlimited_dims=None,
*,
encode_coordinates=True,
):
"""Store dataset contents to a backends.*DataStore object."""
if writer is None:
@@ -1347,7 +1355,12 @@ def dump_to_store(
if encoding is None:
encoding = {}

variables, attrs = conventions.encode_dataset_coordinates(dataset)
# IMPORTANT: Any changes here will need to be duplicated in initialize_zarr
if encode_coordinates:
variables, attrs = conventions.encode_dataset_coordinates(dataset)
else:
variables = {k: v.copy(deep=False) for k, v in dataset._variables.items()}
attrs = dataset.attrs

check_encoding = set()
for k, enc in encoding.items():
@@ -1827,18 +1840,242 @@ def to_zarr(
"mode='r+'. To allow writing new variables, set mode='a'."
)

return write_to_zarr_store(
dataset,
zstore,
encoding,
compute,
chunkmanager_store_kwargs,
encode_coordinates=True,
)


def write_to_zarr_store(
dataset, store, encoding, compute, chunkmanager_store_kwargs, encode_coordinates
):
writer = ArrayWriter()
# TODO: figure out how to properly handle unlimited_dims
dump_to_store(dataset, zstore, writer, encoding=encoding)
dump_to_store(
dataset, store, writer, encoding=encoding, encode_coordinates=encode_coordinates
)
writes = writer.sync(
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
)

if compute:
_finalize_store(writes, zstore)
_finalize_store(writes, store)
else:
import dask

return dask.delayed(_finalize_store)(writes, zstore)
return dask.delayed(_finalize_store)(writes, store)
return store


def initialize_zarr(
ds: Dataset,
store: MutableMapping,
*,
region_dims: Iterable[Hashable] = tuple(),
mode: Literal["w", "w-"] = "w-",
zarr_version: int | None = None,
consolidated: bool | None = None,
encoding: dict | None = None,
**kwargs,
) -> Dataset:
"""
Initialize a Zarr store with metadata.

This function initializes a Zarr store with all indexed coordinate variables, and
metadata for every variable in the dataset.
If ``region_dims`` is specified, it will also
Comment (Collaborator): Not a necessary change, but one plausible way to frame this is

  • default region_dims=[...]
  • ...and then there's no need for the if — we default to writing the minimum and returning the maximum.

Comment (dcherian, PR author): Interesting suggestion!

Comment (dcherian, PR author): There's an edge case around scalars. We would write them if the user wants region writes, but not write them if they don't. So we need to distinguish between region_dims=(), region_dims=(...,) and region_dims=tuple(ds.dims). If those last two are identical then we can't handle the scalar case.

Comment (Collaborator): OK, interesting point.

Out of interest, is this function helpful if we're not doing region writes?

Comment (Collaborator): Yeah, on reflection — is there a case where we really don't want to write scalars? (genuine question, very possibly there is...)

1. Write variables that don't have any of ``region_dims``, and
2. Return a dataset with the remaining variables, which contain one or more of ``region_dims``.
Comment (max-sixty, Collaborator, Jan 4, 2024):

Suggested change:
- 2. Return a dataset with the remaining variables, which contain one or more of ``region_dims``.
+ 2. Return a dataset with the remaining variables and indexes, which contain one or more of ``region_dims``.

...is that right?

Should we drop the non-region indexes?

My mental model of this (but it's from a few months ago):

  • For indexes that are part of the region, we want to keep them now that we have region="auto". When we use region="auto" in each process, it correctly won't rewrite those indexes.
  • For other indexes, we want to drop them, because otherwise each process will incorrectly attempt to write them, and this can lead to conflicts.

Comment (dcherian, PR author, Jan 4, 2024):

> When we use region="auto" in each process, it correctly won't rewrite those indexes

Surprising to me, but you are correct! #8589

Currently we keep any indexes needed for any variable that has dimensions that overlap with region_dims. So for the test case, there is a variable with dims (x, y). If region_dims == ("x",), then the index y is also returned. This works fine with region="auto" but requires adding "y": slice(None) when explicitly specifying region. I think this is OK: if you're specifying region explicitly, specify the whole thing, as region="auto" does.

> For other indexes, we want to drop them, because otherwise each process will incorrectly attempt to write them, and this can lead to conflicts

We do end up dropping those indexes. The test case has a variable with dimensions ("z",) but the regions only contain "x" and/or "y". The z index ends up getting dropped.
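
As a concrete illustration of the behavior described above, here is a minimal sketch (hypothetical data; assumes the initialize_zarr signature proposed in this PR and that dask is installed):

import numpy as np
import xarray as xr

# Hypothetical dataset mirroring the test case discussed above:
# "a" overlaps region_dims=("x",); "b" does not.
ds = xr.Dataset(
    {
        "a": (("x", "y"), np.ones((4, 3))),
        "b": (("z",), np.arange(2.0)),
    },
    coords={"x": np.arange(4), "y": np.arange(3), "z": [10, 20]},
).chunk({"x": 2})

template = xr.initialize_zarr(ds, store={}, region_dims=("x",), mode="w")

# Only "a" (plus the indexes needed to locate it) should remain for later
# region writes; "b" is handled up front and dropped from the result.
print(sorted(template.data_vars))  # expected: ['a']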

This dataset can then be used for region writes in parallel.

Parameters
----------
ds : Dataset
Dataset to write.
store : MutableMapping or str
Zarr store to write to.
region_dims : Iterable[Hashable], optional
An iterable of dimension names that will be passed to the ``region``
kwarg of ``to_zarr`` later.
mode : {'w', 'w-'}
Write mode for initializing the store.
- "w" means create (overwrite if exists);
- "w-" means create (fail if exists);
zarr_version : int or None, optional
The desired zarr spec version to target (currently 2 or 3). The
default of None will attempt to determine the zarr version from
``store`` when possible, otherwise defaulting to 2.
consolidated : bool, optional
If True, apply zarr's `consolidate_metadata` function to the store
after writing metadata and read existing stores with consolidated
metadata; if False, do not. The default (`consolidated=None`) means
write consolidated metadata and attempt to read consolidated
metadata for existing stores (falling back to non-consolidated).

When using the experimental ``zarr_version=3``, ``consolidated`` must be
either ``None`` or ``False``.
**kwargs
Passed on to ``to_zarr``.

Returns
-------
Dataset
Dataset containing variables with one or more ``region_dims``
dimensions. Use this for writing to the store in parallel later.

Raises
------
ValueError
If the dataset contains no chunked variables, or if an unsupported ``mode``, ``zarr_version``, or ``compute`` argument is passed.
"""
import zarr

from xarray.backends.zarr import add_array_to_store, encode_zarr_variable

if encoding is None:
encoding = {}

if "compute" in kwargs:
raise ValueError("The ``compute`` kwarg is not supported in `initialize_zarr`.")

if not any(is_chunked_array(v._data) for v in ds._variables.values()):
raise ValueError("This function should be used with chunked Datasets.")

if mode not in ["w", "w-"]:
raise ValueError(
f"Only mode='w' or mode='w-' is allowed for initialize_zarr. Received mode={mode!r}"
)

if zarr_version is None:
# default to 2 if store doesn't specify its version (e.g. a path)
zarr_version = int(getattr(store, "_store_version", 2))

# The design here is to write to an in-memory temporary store,
# and flush that to the actual `store`. This is a major improvement
# for V3 high-latency stores (e.g. cloud buckets)
if zarr_version == 2:
temp_store = zarr.MemoryStore()
elif zarr_version == 3:
temp_store = zarr.MemoryStoreV3()
if consolidated:
raise ValueError("Consolidating metadata is not supported by Zarr V3.")
else:
raise ValueError(f"Invalid zarr_version={zarr_version}.")

# Needed to handle `store` being a string path
store = zarr.hierarchy.normalize_store_arg(
store,
zarr_version=zarr_version,
mode=mode,
storage_options=kwargs.get("storage_options", None),
)
if mode == "w-":
# assert that the path does not already exist
zarr.open_group(
Comment (dcherian, PR author) on lines +1979 to +1980: Is there a better way?

store,
mode=mode,
storage_options=kwargs.get("storage_options", None),
Comment (Collaborator): Obviously don't change this, but FYI: None is already the default for .get() in Python here.

path=kwargs.get("group", None),
)

# Use this to open the group once with all the expected default options.
# We will reuse xtempstore.zarr_group later.
xtempstore = backends.ZarrStore.open_group(
temp_store,
mode="w", # always write to the temp store
zarr_version=zarr_version,
consolidated=False,
**kwargs,
)

# need to do this separately to get the "coordinates" attribute correct
variables, attrs = conventions.encode_dataset_coordinates(ds)
ds = ds._replace(variables=variables, attrs=attrs)

all_variables = set(ds._variables)
# TODO: how do we work with the new index API?
index_vars = {dim for dim in ds.dims if dim in all_variables}
Comment (Contributor): No expert, but maybe:
index_vars = set(ds.indexes.keys())

vars_without_region = {
key
for key in all_variables - index_vars
if (not (set(ds._variables[key].dims) & set(region_dims)))
}
chunked_vars_without_region = {
key for key in vars_without_region if is_chunked_array(ds._variables[key]._data)
}

def extract_encoding(varnames: Iterable[Hashable]) -> dict:
return {k: v for k, v in encoding.items() if k in varnames}

# Always write index variables, and any in-memory variables without region dims
eager_write_vars = index_vars | (vars_without_region - chunked_vars_without_region)

write_to_zarr_store(
ds[eager_write_vars],
xtempstore,
encoding=extract_encoding(eager_write_vars),
compute=True,
chunkmanager_store_kwargs=kwargs.get("chunkmanager_store_kwargs", None),
encode_coordinates=False,
)

# Now initialize metadata for the arrays we have not written yet,
# but skip any chunked vars without region dims; those get written later
vars_to_init = (all_variables - eager_write_vars) - chunked_vars_without_region
array_kwargs = {
key: kwargs[key]
for key in ["safe_chunks", "write_empty", "raise_on_invalid"]
if key in kwargs
}
array_kwargs.setdefault("write_empty", False)
if mode == "w":
array_kwargs.setdefault("overwrite", True)
enc = extract_encoding(vars_to_init)
for var in vars_to_init:
variable = ds._variables[var]
# duplicates dump_to_store
if var in enc:
variable.encoding = enc[var]
encoded = encode_zarr_variable(variable)
add_array_to_store(var, encoded, group=xtempstore.zarr_group, **array_kwargs)

if zarr_version == 2 and consolidated in (True, None):
zarr.consolidate_metadata(temp_store)

# flush the temp store to the target store all at once
try:
store.setitems(temp_store) # type: ignore[attr-defined]
except AttributeError: # not all stores have setitems :(
store.update(temp_store)

# Return a Dataset that can be easily used for further region writes.
if region_dims:
# Write any variables that don't overlap with region dimensions
# Note that these are potentially quite big dask arrays, so we
# do not want to write these to the MemoryStore first.
if chunked_vars_without_region:
xstore = backends.ZarrStore.open_group(
store,
mode="a-", # append new variables, don't overwrite indexes
zarr_version=zarr_version,
consolidated=consolidated,
**kwargs,
)

write_to_zarr_store(
ds[tuple(chunked_vars_without_region)],
xstore,
encoding=extract_encoding(chunked_vars_without_region),
compute=True,
chunkmanager_store_kwargs=kwargs.get("chunkmanager_store_kwargs", None),
encode_coordinates=False,
)

to_drop = (eager_write_vars | chunked_vars_without_region) - index_vars
Comment (Contributor): If #8904 went in we wouldn't need to drop index_vars here.

return ds.drop_vars(to_drop)

else:
return ds
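
To make the intended usage concrete, here is a minimal end-to-end sketch (hypothetical path and data; assumes this branch, ``region="auto"`` support in ``to_zarr``, and that dask is installed):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"sst": (("time", "x"), np.random.rand(10, 4))},
    coords={"time": np.arange(10), "x": np.arange(4)},
).chunk({"time": 5})

store = "out.zarr"  # hypothetical path

# 1. Write indexes and array metadata once, up front.
template = xr.initialize_zarr(ds, store, region_dims=("time",), mode="w")

# 2. Each worker then writes only its own region; with region="auto" the
#    region is inferred from the index values and indexes are not rewritten.
for region in (slice(0, 5), slice(5, 10)):
    template.isel(time=region).to_zarr(store, region="auto")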