Skip to content

Commit

Permalink
More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints (py…
Browse files Browse the repository at this point in the history
…data#5296)

* More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints

The new version checks magic numbers in files on disk, not just in
already-open file objects.

I've also added a bunch of unit-tests.

Fixes GH5295

* Fix failures and warning in test_backends.py

* format black
  • Loading branch information
shoyer authored May 14, 2021
1 parent 9e84d09 commit c6f2cf0
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 20 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ Deprecations
Bug fixes
~~~~~~~~~

- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying
an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in
0.18.0.
By `Stephan Hoyer <https://github.com/shoyer>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
18 changes: 11 additions & 7 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import numpy as np

from ..core import indexing
from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
from ..core.utils import (
FrozenDict,
is_remote_uri,
read_magic_number_from_file,
try_read_magic_number_from_file_or_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -140,10 +145,10 @@ def open(
"try passing a path or file-like object"
)
elif isinstance(filename, io.IOBase):
magic_number = read_magic_number(filename)
magic_number = read_magic_number_from_file(filename)
if not magic_number.startswith(b"\211HDF\r\n\032\n"):
raise ValueError(
f"{magic_number} is not the signature of a valid netCDF file"
f"{magic_number} is not the signature of a valid netCDF4 file"
)

if format not in [None, "NETCDF4"]:
Expand Down Expand Up @@ -333,10 +338,9 @@ def close(self, **kwargs):

class H5netcdfBackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
try:
return read_magic_number(filename_or_obj).startswith(b"\211HDF\r\n\032\n")
except TypeError:
pass
magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
if magic_number is not None:
return magic_number.startswith(b"\211HDF\r\n\032\n")

try:
_, ext = os.path.splitext(filename_or_obj)
Expand Down
11 changes: 10 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from .. import coding
from ..coding.variables import pop_to
from ..core import indexing
from ..core.utils import FrozenDict, close_on_error, is_remote_uri
from ..core.utils import (
FrozenDict,
close_on_error,
is_remote_uri,
try_read_magic_number_from_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -510,6 +515,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj):
return True
magic_number = try_read_magic_number_from_path(filename_or_obj)
if magic_number is not None:
# netcdf 3 or HDF5
return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n"))
try:
_, ext = os.path.splitext(filename_or_obj)
except TypeError:
Expand Down
21 changes: 14 additions & 7 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import gzip
import io
import os

import numpy as np

from ..core.indexing import NumpyIndexingAdapter
from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
from ..core.utils import (
Frozen,
FrozenDict,
close_on_error,
try_read_magic_number_from_file_or_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -72,8 +78,6 @@ def __setitem__(self, key, value):


def _open_scipy_netcdf(filename, mode, mmap, version):
import gzip

# if the string ends with .gz, then gunzip and open as netcdf file
if isinstance(filename, str) and filename.endswith(".gz"):
try:
Expand Down Expand Up @@ -235,10 +239,13 @@ def close(self):

class ScipyBackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
try:
return read_magic_number(filename_or_obj).startswith(b"CDF")
except TypeError:
pass

magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
if magic_number is not None and magic_number.startswith(b"\x1f\x8b"):
with gzip.open(filename_or_obj) as f:
magic_number = try_read_magic_number_from_file_or_path(f)
if magic_number is not None:
return magic_number.startswith(b"CDF")

try:
_, ext = os.path.splitext(filename_or_obj)
Expand Down
28 changes: 26 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import io
import itertools
import os
import re
import warnings
from enum import Enum
Expand Down Expand Up @@ -646,7 +647,7 @@ def is_remote_uri(path: str) -> bool:
return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path))


def read_magic_number(filename_or_obj, count=8):
def read_magic_number_from_file(filename_or_obj, count=8) -> bytes:
# check byte header to determine file type
if isinstance(filename_or_obj, bytes):
magic_number = filename_or_obj[:count]
Expand All @@ -657,13 +658,36 @@ def read_magic_number(filename_or_obj, count=8):
"file-like object read/write pointer not at the start of the file, "
"please close and reopen, or use a context manager"
)
magic_number = filename_or_obj.read(count)
magic_number = filename_or_obj.read(count) # type: ignore
filename_or_obj.seek(0)
else:
raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}")
return magic_number


def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]:
    """Best-effort read of the leading bytes of the file at ``pathlike``.

    Parameters
    ----------
    pathlike : str or os.PathLike
        Candidate filesystem path. Objects of any other type are ignored.
    count : int, default: 8
        Number of leading bytes to read.

    Returns
    -------
    bytes or None
        The first ``count`` bytes of the file, or ``None`` when ``pathlike``
        is not a path or the path cannot be opened and read as a file.
    """
    if not (isinstance(pathlike, str) or hasattr(pathlike, "__fspath__")):
        return None

    path = os.fspath(pathlike)
    try:
        with open(path, "rb") as f:
            return read_magic_number_from_file(f, count)
    except (OSError, TypeError):
        # This helper only *tries*: a path that is missing, is a directory,
        # is unreadable or is otherwise rejected by the OS must yield None so
        # that callers can fall back to other heuristics (e.g. the file
        # extension) instead of failing. FileNotFoundError, IsADirectoryError
        # and PermissionError are all OSError subclasses.
        return None


def try_read_magic_number_from_file_or_path(
    filename_or_obj, count=8
) -> Optional[bytes]:
    """Return the leading ``count`` bytes of a path or in-memory/file object.

    The argument is first tried as a filesystem path. If that yields nothing,
    it is handed to ``read_magic_number_from_file``. ``None`` is returned
    when that call rejects the argument's type with ``TypeError``.
    """
    from_path = try_read_magic_number_from_path(filename_or_obj, count)
    if from_path is not None:
        return from_path

    try:
        return read_magic_number_from_file(filename_or_obj, count)
    except TypeError:
        return None


def is_uniform_spaced(arr, **kwargs) -> bool:
"""Return True if values of an array are uniformly spaced and sorted.
Expand Down
100 changes: 97 additions & 3 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import contextlib
import gzip
import itertools
import math
import os.path
import pickle
import re
import shutil
import sys
import tempfile
Expand Down Expand Up @@ -30,9 +32,14 @@
save_mfdataset,
)
from xarray.backends.common import robust_getitem
from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
from xarray.backends.netcdf3 import _nc3_dtype_coercions
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.backends.netCDF4_ import (
NetCDF4BackendEntrypoint,
_extract_nc4_variable_encoding,
)
from xarray.backends.pydap_ import PydapDataStore
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.coding.variables import SerializationWarning
from xarray.conventions import encode_dataset_coordinates
from xarray.core import indexes, indexing
Expand Down Expand Up @@ -2771,7 +2778,7 @@ def test_open_badbytes(self):
with open_dataset(b"garbage", engine="netcdf4"):
pass
with pytest.raises(
ValueError, match=r"not the signature of a valid netCDF file"
ValueError, match=r"not the signature of a valid netCDF4 file"
):
with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"):
pass
Expand Down Expand Up @@ -2817,7 +2824,11 @@ def test_open_fileobj(self):
with open(tmp_file, "rb") as f:
f.seek(8)
with pytest.raises(ValueError, match="cannot guess the engine"):
open_dataset(f)
with pytest.warns(
RuntimeWarning,
match=re.escape("'h5netcdf' fails while guessing"),
):
open_dataset(f)


@requires_h5netcdf
Expand Down Expand Up @@ -5161,3 +5172,86 @@ def test_chunking_consintency(chunks, tmp_path):

with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual:
xr.testing.assert_chunks_equal(actual, expected)


def _check_guess_can_open_and_open(entrypoint, obj, engine, expected):
    """Assert that ``entrypoint`` claims ``obj`` and that it opens as ``expected``.

    ``obj`` is first passed to ``entrypoint.guess_can_open`` and then opened
    with ``open_dataset`` using the explicit ``engine``; the resulting dataset
    must be identical to ``expected``.
    """
    can_open = entrypoint.guess_can_open(obj)
    assert can_open
    with open_dataset(obj, engine=engine) as opened:
        assert_identical(expected, opened)


@requires_netCDF4
def test_netcdf4_entrypoint(tmp_path):
    """Check what the netCDF4 entrypoint's ``guess_can_open`` accepts."""
    entrypoint = NetCDF4BackendEntrypoint()
    ds = create_test_data()

    # Extension-less files written in each format must be accepted and
    # opened, whether the path is given as a path object or as a string.
    for name, fmt in (("foo", "netcdf3_classic"), ("bar", "netcdf4_classic")):
        path = tmp_path / name
        ds.to_netcdf(path, format=fmt)
        for target in (path, str(path)):
            _check_guess_can_open_and_open(
                entrypoint, target, engine="netcdf4", expected=ds
            )

    # Names that are accepted without the file existing locally.
    accepted = (
        "http://something/remote",
        "something-local.nc",
        "something-local.nc4",
        "something-local.cdf",
    )
    for name in accepted:
        assert entrypoint.guess_can_open(name)
    assert not entrypoint.guess_can_open("not-found-and-no-extension")

    # An existing file with unrelated contents must be rejected.
    path = tmp_path / "baz"
    with open(path, "wb") as f:
        f.write(b"not-a-netcdf-file")
    assert not entrypoint.guess_can_open(path)


@requires_scipy
def test_scipy_entrypoint(tmp_path):
    """Check what the scipy entrypoint's ``guess_can_open`` accepts."""
    entrypoint = ScipyBackendEntrypoint()
    ds = create_test_data()

    def check(obj):
        # Shorthand: the entrypoint must claim ``obj`` and it must open as ``ds``.
        _check_guess_can_open_and_open(entrypoint, obj, engine="scipy", expected=ds)

    # Extension-less file on disk: path object, string path, open file object.
    path = tmp_path / "foo"
    ds.to_netcdf(path, engine="scipy")
    check(path)
    check(str(path))
    with open(path, "rb") as f:
        check(f)

    # In-memory contents: the value returned by ``to_netcdf`` without a
    # path, and the same value wrapped in BytesIO.
    contents = ds.to_netcdf(engine="scipy")
    check(contents)
    check(BytesIO(contents))

    # Gzip-compressed file on disk.
    path = tmp_path / "foo.nc.gz"
    with gzip.open(path, mode="wb") as f:
        f.write(contents)
    check(path)
    check(str(path))

    # Non-existent paths are accepted by extension; everything else is rejected.
    for name in ("something-local.nc", "something-local.nc.gz"):
        assert entrypoint.guess_can_open(name)
    assert not entrypoint.guess_can_open("not-found-and-no-extension")
    assert not entrypoint.guess_can_open(b"not-a-netcdf-file")


@requires_h5netcdf
def test_h5netcdf_entrypoint(tmp_path):
    """Check what the h5netcdf entrypoint's ``guess_can_open`` accepts."""
    entrypoint = H5netcdfBackendEntrypoint()
    ds = create_test_data()

    # Extension-less file on disk: path object, string path, open file object.
    path = tmp_path / "foo"
    ds.to_netcdf(path, engine="h5netcdf")
    for target in (path, str(path)):
        _check_guess_can_open_and_open(
            entrypoint, target, engine="h5netcdf", expected=ds
        )
    with open(path, "rb") as f:
        _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds)

    # Non-existent paths are accepted by extension only.
    for name in ("something-local.nc", "something-local.nc4", "something-local.cdf"):
        assert entrypoint.guess_can_open(name)
    assert not entrypoint.guess_can_open("not-found-and-no-extension")

0 comments on commit c6f2cf0

Please sign in to comment.