Zarr v3 support #292

Closed
wants to merge 7 commits
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -23,4 +23,5 @@ jobs:
- name: Test with pytest
shell: bash -l {0}
run: |
export ZARR_V3_EXPERIMENTAL_API=1
pytest -v --cov
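The exported ZARR_V3_EXPERIMENTAL_API=1 is needed because zarr-python 2.x ships its zarr v3 support as an experimental, opt-in feature; the flag is typically read at import time. A minimal sketch of the same setup outside CI (assumption: zarr-python 2.13+ with the provisional v3 API):

    import os
    os.environ["ZARR_V3_EXPERIMENTAL_API"] = "1"  # opt in to the provisional v3 API before importing zarr

    import zarr

    out = {}  # a plain dict serves as an in-memory store, as in the converters below
    g = zarr.open(out, zarr_version=3)
    # expect v3-style metadata keys such as "zarr.json" and "meta/root.group.json"
    print(sorted(out))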
6 changes: 5 additions & 1 deletion kerchunk/fits.py
@@ -37,6 +37,7 @@ def process_file(
extension=None,
inline_threshold=100,
primary_attr_to_group=False,
zarr_version=None,
):
"""
Create JSON references for a single FITS file as a zarr group
@@ -55,6 +56,9 @@
primary_attr_to_group: bool
Whether the output top-level group contains the attributes of the primary extension
(which often contains no data, just a general description)
zarr_version: int
The desired zarr spec version to target (currently 2 or 3). The default
of None will use the default zarr version.

Returns
-------
@@ -64,7 +68,7 @@

storage_options = storage_options or {}
out = {}
g = zarr.open(out)
g = zarr.open(out, zarr_version=zarr_version)

with fsspec.open(url, mode="rb", **storage_options) as f:
infile = fits.open(f, do_not_scale_image_data=True)
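For context, a sketch of the call pattern this parameter enables (mirroring test_fits.py further down; "example.fits" is a placeholder path and ZARR_V3_EXPERIMENTAL_API=1 is assumed to be set):

    import fsspec
    import zarr
    import kerchunk.fits

    # Build references targeting the zarr v3 spec (placeholder input file).
    refs = kerchunk.fits.process_file("example.fits", zarr_version=3)

    # Open them back through fsspec's reference filesystem.
    m = fsspec.get_mapper("reference://", fo=refs)
    g = zarr.open(m, zarr_version=3)
    print(g)  # a zarr group backed by the reference filesystem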
7 changes: 5 additions & 2 deletions kerchunk/grib2.py
@@ -95,6 +95,7 @@ def scan_grib(
inline_threshold=100,
skip=0,
filter={},
zarr_version=None,
):
"""
Generate references for a GRIB2 file
@@ -116,7 +117,9 @@
the exact value or is in the given set, are processed.
E.g., the cf-style filter ``{'typeOfLevel': 'heightAboveGround', 'level': 2}``
only keeps messages where heightAboveGround==2.

zarr_version: int
The desired zarr spec version to target (currently 2 or 3). The default
of None will use the default zarr version.
Returns
-------

@@ -147,7 +150,7 @@
if good is False:
continue

z = zarr.open_group(store)
z = zarr.group(store=store, overwrite=True, zarr_version=zarr_version)
global_attrs = {
k: m[k] for k in cfgrib.dataset.GLOBAL_ATTRIBUTES_KEYS if k in m
}
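A corresponding usage sketch (mirrors test_grib.py below; the GRIB2 path is a placeholder and ZARR_V3_EXPERIMENTAL_API=1 is assumed):

    import xarray as xr
    from kerchunk.grib2 import scan_grib

    # scan_grib returns one reference set per GRIB message.
    out = scan_grib("forecast.grib2", zarr_version=3)  # placeholder filename

    ds = xr.open_dataset(
        "reference://",
        engine="zarr",
        backend_kwargs={
            "consolidated": False,
            "zarr_version": 3,
            "storage_options": {"fo": out[0]},
        },
    )
    print(ds)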
6 changes: 5 additions & 1 deletion kerchunk/hdf.py
@@ -74,6 +74,7 @@ def __init__(
storage_options=None,
error="warn",
vlen_encode="embed",
zarr_version=None,
):

# Open HDF5 file in read mode...
@@ -85,14 +86,17 @@
else:
self.input_file = h5f
self.spec = spec
self.zarr_version = zarr_version
self.inline = inline_threshold
if vlen_encode not in ["embed", "null", "leave", "encode"]:
raise NotImplementedError
self.vlen = vlen_encode
self._h5f = h5py.File(self.input_file, mode="r")

self.store = {}
self._zroot = zarr.group(store=self.store, overwrite=True)
self._zroot = zarr.group(
store=self.store, overwrite=True, zarr_version=self.zarr_version
)

self._uri = url
self.error = error
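Sketch of the intended entry point with the new keyword (mirrors test_hdf.py below; the input path is a placeholder and the experimental v3 flag is assumed to be set):

    import fsspec
    from kerchunk.hdf import SingleHdf5ToZarr

    url = "example.h5"  # placeholder for an HDF5/netCDF4 file
    with fsspec.open(url, mode="rb") as f:
        refs = SingleHdf5ToZarr(f, url, zarr_version=3).translate()

    # With zarr_version=2 the reference dict contains ".zgroup"/".zarray" keys;
    # with zarr_version=3 it contains "zarr.json" and "meta/root..." keys
    # (see the assertions added in test_hdf.py).
    print(sorted(refs["refs"])[:5])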
64 changes: 48 additions & 16 deletions kerchunk/netCDF3.py
@@ -2,6 +2,7 @@
from operator import mul

import numpy as np
import zarr
from .utils import do_inline, _encode_for_JSON

try:
@@ -30,6 +31,7 @@ def __init__(
storage_options=None,
inline_threshold=100,
max_chunk_size=0,
zarr_version=None,
**kwargs,
):
"""
@@ -47,7 +49,10 @@ def __init__(
subchunking, and there is never subchunking for coordinate/dimension arrays.
E.g., if an array contains 10,000bytes, and this value is 6000, there will
be two output chunks, split on the biggest available dimension. [TBC]
args, kwargs: passed to scipy superclass ``scipy.io.netcdf.netcdf_file``
zarr_version: int
The desired zarr spec version to target (currently 2 or 3). The default
of None will use the default zarr version.
args, kwargs: passed to scipy superclass ``scipy.io.netcdf.netcdf_file``
"""
assert kwargs.pop("mmap", False) is False
assert kwargs.pop("mode", "r") == "r"
@@ -58,6 +63,7 @@ def __init__(
self.chunks = {}
self.threshold = inline_threshold
self.max_chunk_size = max_chunk_size
self.zarr_version = zarr_version
self.out = {}
with fsspec.open(filename, **(storage_options or {})) as fp:
super().__init__(
@@ -150,10 +156,11 @@ def translate(self):
Parameters
----------
"""
import zarr

out = self.out
z = zarr.open(out, mode="w")
zroot = zarr.group(
store=self.out, overwrite=True, zarr_version=self.zarr_version
)

for dim, var in self.variables.items():
if dim in self.dimensions:
shape = self.dimensions[dim]
@@ -175,16 +182,25 @@
fill = float(fill)
if fill is not None and var.data.dtype.kind == "i":
fill = int(fill)
arr = z.create_dataset(
arr = zroot.create_dataset(
name=dim,
shape=shape,
dtype=var.data.dtype,
fill_value=fill,
chunks=shape,
compression=None,
overwrite=True,
)
part = ".".join(["0"] * len(shape)) or "0"
out[f"{dim}/{part}"] = [self.filename] + [

if self.zarr_version == 3:
part = "/".join(["0"] * len(shape)) or "0"
key = f"data/root/{dim}/c{part}"
else:
part = ".".join(["0"] * len(shape)) or "0"

key = f"{dim}/{part}"

self.out[key] = [self.filename] + [
int(self.chunks[dim][0]),
int(self.chunks[dim][1]),
]
@@ -218,13 +234,14 @@ def translate(self):
fill = float(fill)
if fill is not None and base.kind == "i":
fill = int(fill)
arr = z.create_dataset(
arr = zroot.create_dataset(
name=name,
shape=shape,
dtype=base,
fill_value=fill,
chunks=(1,) + dtype.shape,
compression=None,
overwrite=True,
)
arr.attrs.update(
{
@@ -239,18 +256,33 @@

arr.attrs["_ARRAY_DIMENSIONS"] = list(var.dimensions)

suffix = (
("." + ".".join("0" for _ in dtype.shape)) if dtype.shape else ""
)
if self.zarr_version == 3:
suffix = (
("/" + "/".join("0" for _ in dtype.shape))
if dtype.shape
else ""
)
else:
suffix = (
("." + ".".join("0" for _ in dtype.shape))
if dtype.shape
else ""
)
for i in range(outer_shape):
out[f"{name}/{i}{suffix}"] = [

if self.zarr_version == 3:
key = f"data/root/{name}/c{i}{suffix}"
else:
key = f"{name}/{i}{suffix}"

self.out[key] = [
self.filename,
int(offset + i * dt.itemsize),
int(dtype.itemsize),
]

offset += dtype.itemsize
z.attrs.update(
zroot.attrs.update(
{
k: v.decode() if isinstance(v, bytes) else str(v)
for k, v in self._attributes.items()
@@ -259,10 +291,10 @@
)

if self.threshold > 0:
out = do_inline(out, self.threshold)
out = _encode_for_JSON(out)
self.out = do_inline(self.out, self.threshold)
self.out = _encode_for_JSON(self.out)

return {"version": 1, "refs": out}
return {"version": 1, "refs": self.out}


netcdf_recording_file = NetCDF3ToZarr
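The chunk-key naming handled in translate() above is the main layout difference between the two spec versions. A standalone illustration of the two schemes (assumption: a hypothetical single-chunk 2-D variable named "temp"; plain string logic, not a kerchunk call):

    # zarr v2: dot-separated chunk indices directly under the array name.
    shape = (10, 20)
    name = "temp"
    v2_key = f"{name}/" + ".".join("0" for _ in shape)             # -> "temp/0.0"

    # zarr v3: slash-separated indices under the "data/root/.../c" prefix.
    v3_key = f"data/root/{name}/c" + "/".join("0" for _ in shape)  # -> "data/root/temp/c0/0"

    print(v2_key, v3_key)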
35 changes: 20 additions & 15 deletions kerchunk/tests/test_fits.py
@@ -13,12 +13,13 @@
var = os.path.join(testdir, "variable_length_table.fits")


def test_ascii_table():
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_ascii_table(zarr_version):
# this one directly hits a remote server - should cache?
url = "https://fits.gsfc.nasa.gov/samples/WFPC2u5780205r_c0fx.fits"
out = kerchunk.fits.process_file(url, extension=1)
out = kerchunk.fits.process_file(url, extension=1, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out, remote_protocol="https")
g = zarr.open(m)
g = zarr.open(m, zarr_version=zarr_version)
arr = g["u5780205r_cvt.c0h.tab"][:]
with fsspec.open(
"https://fits.gsfc.nasa.gov/samples/WFPC2u5780205r_c0fx.fits"
@@ -28,10 +29,11 @@ def test_ascii_table():
assert list(hdu.data.astype(arr.dtype) == arr) == [True, True, True, True]


def test_binary_table():
out = kerchunk.fits.process_file(btable, extension=1)
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_binary_table(zarr_version):
out = kerchunk.fits.process_file(btable, extension=1, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["1"]
with open(btable, "rb") as f:
hdul = fits.open(f)
@@ -45,38 +47,41 @@ def test_binary_table():
).all()  # strings come out as bytes


def test_cube():
out = kerchunk.fits.process_file(range_im)
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_cube(zarr_version):
out = kerchunk.fits.process_file(range_im, zarr_version=zarr_version)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["PRIMARY"]
with open(range_im, "rb") as f:
hdul = fits.open(f)
expected = hdul[0].data
assert (arr[:] == expected).all()


def test_with_class():
ftz = kerchunk.fits.FitsToZarr(range_im)
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_with_class(zarr_version):
ftz = kerchunk.fits.FitsToZarr(range_im, zarr_version=zarr_version)
out = ftz.translate()
assert "fits" in repr(ftz)
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["PRIMARY"]
with open(range_im, "rb") as f:
hdul = fits.open(f)
expected = hdul[0].data
assert (arr[:] == expected).all()


def test_var():
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_var(zarr_version):
data = fits.open(var)[1].data
expected = [_.tolist() for _ in data["var"]]

ftz = kerchunk.fits.FitsToZarr(var)
ftz = kerchunk.fits.FitsToZarr(var, zarr_version=zarr_version)
out = ftz.translate()
m = fsspec.get_mapper("reference://", fo=out)
z = zarr.open(m)
z = zarr.open(m, zarr_version=zarr_version)
arr = z["1"]
vars = [_.tolist() for _ in arr["var"]]

11 changes: 8 additions & 3 deletions kerchunk/tests/test_grib.py
@@ -11,14 +11,19 @@
here = os.path.dirname(__file__)


def test_one():
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_one(zarr_version):
# from https://dd.weather.gc.ca/model_gem_regional/10km/grib2/00/000
fn = os.path.join(here, "CMC_reg_DEPR_ISBL_10_ps10km_2022072000_P000.grib2")
out = scan_grib(fn)
out = scan_grib(fn, zarr_version=zarr_version)
ds = xr.open_dataset(
"reference://",
engine="zarr",
backend_kwargs={"consolidated": False, "storage_options": {"fo": out[0]}},
backend_kwargs={
"consolidated": False,
"zarr_version": zarr_version,
"storage_options": {"fo": out[0]},
},
)

assert ds.attrs["centre"] == "cwao"
18 changes: 15 additions & 3 deletions kerchunk/tests/test_hdf.py
@@ -13,18 +13,30 @@
here = osp.dirname(__file__)


def test_single():
@pytest.mark.parametrize("zarr_version", [2, 3])
def test_single(zarr_version):
"""Test creating references for a single HDF file"""
url = "s3://noaa-nwm-retro-v2.0-pds/full_physics/2017/201704010000.CHRTOUT_DOMAIN1.comp"
so = dict(anon=True, default_fill_cache=False, default_cache_type="none")
with fsspec.open(url, **so) as f:
h5chunks = SingleHdf5ToZarr(f, url, storage_options=so)
h5chunks = SingleHdf5ToZarr(
f, url, storage_options=so, zarr_version=zarr_version
)
test_dict = h5chunks.translate()

m = fsspec.get_mapper(
"reference://", fo=test_dict, remote_protocol="s3", remote_options=so
)
ds = xr.open_dataset(m, engine="zarr", backend_kwargs=dict(consolidated=False))

if zarr_version == 2:
assert ".zgroup" in test_dict["refs"]
elif zarr_version == 3:
assert "zarr.json" in test_dict["refs"]
assert "meta/root.group.json" in test_dict["refs"]

backend_kwargs = {"zarr_version": zarr_version, "consolidated": False}
# TODO: drop consolidated kwarg for v3 stores
ds = xr.open_dataset(m, engine="zarr", backend_kwargs=backend_kwargs)

with fsspec.open(url, **so) as f:
expected = xr.open_dataset(f, engine="h5netcdf")