From c6f2cf05b137061fd3ef7fdebcc2b82977791692 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 14 May 2021 15:40:13 -0700 Subject: [PATCH] More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints (#5296) * More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints The new version check magic numbers in files on disk, not just already open file objects. I've also added a bunch of unit-tests. Fixes GH5295 * Fix failures and warning in test_backends.py * format black --- doc/whats-new.rst | 4 ++ xarray/backends/h5netcdf_.py | 18 +++--- xarray/backends/netCDF4_.py | 11 +++- xarray/backends/scipy_.py | 21 ++++--- xarray/core/utils.py | 28 +++++++++- xarray/tests/test_backends.py | 100 +++++++++++++++++++++++++++++++++- 6 files changed, 162 insertions(+), 20 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d3ab043e5f8..5bad23ad705 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,10 @@ Deprecations Bug fixes ~~~~~~~~~ +- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying + an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in + 0.18.0. + By `Stephan Hoyer `_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 9f744d0c1ef..a6e04fe7567 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -6,7 +6,12 @@ import numpy as np from ..core import indexing -from ..core.utils import FrozenDict, is_remote_uri, read_magic_number +from ..core.utils import ( + FrozenDict, + is_remote_uri, + read_magic_number_from_file, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -140,10 +145,10 @@ def open( "try passing a path or file-like object" ) elif isinstance(filename, io.IOBase): - magic_number = read_magic_number(filename) + magic_number = read_magic_number_from_file(filename) if not magic_number.startswith(b"\211HDF\r\n\032\n"): raise ValueError( - f"{magic_number} is not the signature of a valid netCDF file" + f"{magic_number} is not the signature of a valid netCDF4 file" ) if format not in [None, "NETCDF4"]: @@ -333,10 +338,9 @@ def close(self, **kwargs): class H5netcdfBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): - try: - return read_magic_number(filename_or_obj).startswith(b"\211HDF\r\n\032\n") - except TypeError: - pass + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None: + return magic_number.startswith(b"\211HDF\r\n\032\n") try: _, ext = os.path.splitext(filename_or_obj) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 694b0d2fdd2..95e8943dacb 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,7 +9,12 @@ from .. import coding from ..coding.variables import pop_to from ..core import indexing -from ..core.utils import FrozenDict, close_on_error, is_remote_uri +from ..core.utils import ( + FrozenDict, + close_on_error, + is_remote_uri, + try_read_magic_number_from_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -510,6 +515,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True + magic_number = try_read_magic_number_from_path(filename_or_obj) + if magic_number is not None: + # netcdf 3 or HDF5 + return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) try: _, ext = os.path.splitext(filename_or_obj) except TypeError: diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 9c33b172639..7394770cbe8 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,10 +1,16 @@ +import gzip import io import os import numpy as np from ..core.indexing import NumpyIndexingAdapter -from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number +from ..core.utils import ( + Frozen, + FrozenDict, + close_on_error, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -72,8 +78,6 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): - import gzip - # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -235,10 +239,13 @@ def close(self): class ScipyBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): - try: - return read_magic_number(filename_or_obj).startswith(b"CDF") - except TypeError: - pass + + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): + with gzip.open(filename_or_obj) as f: + magic_number = try_read_magic_number_from_file_or_path(f) + if magic_number is not None: + return magic_number.startswith(b"CDF") try: _, ext = os.path.splitext(filename_or_obj) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 31ac43ed214..62b66278b24 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -4,6 +4,7 @@ import functools import io import itertools +import os import re import warnings from enum import Enum @@ -646,7 +647,7 @@ def is_remote_uri(path: str) -> bool: return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path)) -def read_magic_number(filename_or_obj, count=8): +def read_magic_number_from_file(filename_or_obj, count=8) -> bytes: # check byte header to determine file type if isinstance(filename_or_obj, bytes): magic_number = filename_or_obj[:count] @@ -657,13 +658,36 @@ def read_magic_number(filename_or_obj, count=8): "file-like object read/write pointer not at the start of the file, " "please close and reopen, or use a context manager" ) - magic_number = filename_or_obj.read(count) + magic_number = filename_or_obj.read(count) # type: ignore filename_or_obj.seek(0) else: raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}") return magic_number +def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]: + if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"): + path = os.fspath(pathlike) + try: + with open(path, "rb") as f: + return read_magic_number_from_file(f, count) + except (FileNotFoundError, TypeError): + pass + return None + + +def try_read_magic_number_from_file_or_path( + filename_or_obj, count=8 +) -> Optional[bytes]: + magic_number = try_read_magic_number_from_path(filename_or_obj, count) + if magic_number is None: + try: + magic_number = read_magic_number_from_file(filename_or_obj, count) + except TypeError: + pass + return magic_number + + def is_uniform_spaced(arr, **kwargs) -> bool: """Return True if values of an array are uniformly spaced and sorted. diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3e3d6e8b8d0..60eb5b924ca 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1,8 +1,10 @@ import contextlib +import gzip import itertools import math import os.path import pickle +import re import shutil import sys import tempfile @@ -30,9 +32,14 @@ save_mfdataset, ) from xarray.backends.common import robust_getitem +from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint from xarray.backends.netcdf3 import _nc3_dtype_coercions -from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding +from xarray.backends.netCDF4_ import ( + NetCDF4BackendEntrypoint, + _extract_nc4_variable_encoding, +) from xarray.backends.pydap_ import PydapDataStore +from xarray.backends.scipy_ import ScipyBackendEntrypoint from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates from xarray.core import indexes, indexing @@ -2771,7 +2778,7 @@ def test_open_badbytes(self): with open_dataset(b"garbage", engine="netcdf4"): pass with pytest.raises( - ValueError, match=r"not the signature of a valid netCDF file" + ValueError, match=r"not the signature of a valid netCDF4 file" ): with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): pass @@ -2817,7 +2824,11 @@ def test_open_fileobj(self): with open(tmp_file, "rb") as f: f.seek(8) with pytest.raises(ValueError, match="cannot guess the engine"): - open_dataset(f) + with pytest.warns( + RuntimeWarning, + match=re.escape("'h5netcdf' fails while guessing"), + ): + open_dataset(f) @requires_h5netcdf @@ -5161,3 +5172,86 @@ def test_chunking_consintency(chunks, tmp_path): with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual: xr.testing.assert_chunks_equal(actual, expected) + + +def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): + assert entrypoint.guess_can_open(obj) + with open_dataset(obj, engine=engine) as actual: + assert_identical(expected, actual) + + +@requires_netCDF4 +def test_netcdf4_entrypoint(tmp_path): + entrypoint = NetCDF4BackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, format="netcdf3_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + path = tmp_path / "bar" + ds.to_netcdf(path, format="netcdf4_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + assert entrypoint.guess_can_open("http://something/remote") + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + + path = tmp_path / "baz" + with open(path, "wb") as f: + f.write(b"not-a-netcdf-file") + assert not entrypoint.guess_can_open(path) + + +@requires_scipy +def test_scipy_entrypoint(tmp_path): + entrypoint = ScipyBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="scipy") + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) + + contents = ds.to_netcdf(engine="scipy") + _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) + _check_guess_can_open_and_open( + entrypoint, BytesIO(contents), engine="scipy", expected=ds + ) + + path = tmp_path / "foo.nc.gz" + with gzip.open(path, mode="wb") as f: + f.write(contents) + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc.gz") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") + + +@requires_h5netcdf +def test_h5netcdf_entrypoint(tmp_path): + entrypoint = H5netcdfBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds) + _check_guess_can_open_and_open( + entrypoint, str(path), engine="h5netcdf", expected=ds + ) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension")