From 2a433855789cd20d31f61917a01d0a26a0a9e91a Mon Sep 17 00:00:00 2001 From: alexamici Date: Mon, 18 Jan 2021 16:19:34 +0100 Subject: [PATCH 01/14] Remove the references to `_file_obj` outside low level code paths, change to `_close` (#4809) * Move from _file_obj object to _close function * Remove all references to _close outside of low level * Fix type hints * Cleanup code style * Fix non-trivial type hint problem * Revert adding the `close` argument and add a set_close instead * Remove helper class for an easier helper function + code style * Add set_close docstring * Code style * Revert changes in _replace to keep cose as an exception See: https://github.com/pydata/xarray/pull/4809/files#r557628298 * One more bit to revert * One more bit to revert * Add What's New entry * Use set_close setter * Apply suggestions from code review Co-authored-by: Stephan Hoyer * Rename user-visible argument * Sync wording in docstrings. Co-authored-by: Stephan Hoyer --- doc/whats-new.rst | 2 ++ xarray/backends/api.py | 25 +++++++++---------------- xarray/backends/apiv2.py | 2 +- xarray/backends/rasterio_.py | 2 +- xarray/backends/store.py | 3 +-- xarray/conventions.py | 6 +++--- xarray/core/common.py | 29 ++++++++++++++++++++++++----- xarray/core/dataarray.py | 5 +++-- xarray/core/dataset.py | 17 +++++++++-------- 9 files changed, 53 insertions(+), 38 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 88994a5bfc0..09bd56dbe94 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -108,6 +108,8 @@ Internal Changes By `Maximilian Roos `_. - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_. +- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for beckends to specify how to voluntary release + all resources. (:pull:`#4809`), By `Alessandro Amici `_. .. 
_whats-new.0.16.2: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4958062a262..81314588784 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks): else: ds2 = ds - ds2._file_obj = ds._file_obj + ds2.set_close(ds._close) return ds2 filename_or_obj = _normalize_path(filename_or_obj) @@ -701,7 +701,7 @@ def open_dataarray( else: (data_array,) = dataset.data_vars.values() - data_array._file_obj = dataset._file_obj + data_array.set_close(dataset._close) # Reset names if they were changed during saving # to ensure that we can 'roundtrip' perfectly @@ -715,17 +715,6 @@ def open_dataarray( return data_array -class _MultiFileCloser: - __slots__ = ("file_objs",) - - def __init__(self, file_objs): - self.file_objs = file_objs - - def close(self): - for f in self.file_objs: - f.close() - - def open_mfdataset( paths, chunks=None, @@ -918,14 +907,14 @@ def open_mfdataset( getattr_ = getattr datasets = [open_(p, **open_kwargs) for p in paths] - file_objs = [getattr_(ds, "_file_obj") for ds in datasets] + closers = [getattr_(ds, "_close") for ds in datasets] if preprocess is not None: datasets = [preprocess(ds) for ds in datasets] if parallel: # calling compute here will return the datasets/file_objs lists, # the underlying datasets will still be stored as dask arrays - datasets, file_objs = dask.compute(datasets, file_objs) + datasets, closers = dask.compute(datasets, closers) # Combine all datasets, closing them in case of a ValueError try: @@ -963,7 +952,11 @@ def open_mfdataset( ds.close() raise - combined._file_obj = _MultiFileCloser(file_objs) + def multi_file_closer(): + for closer in closers: + closer() + + combined.set_close(multi_file_closer) # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 0f98291983d..d31fc9ea773 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -90,7 +90,7 @@ def _dataset_from_backend_dataset( **extra_tokens, ) - ds._file_obj = backend_ds._file_obj + ds.set_close(backend_ds._close) # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index a0500c7e1c2..c689c1e99d7 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable - result._file_obj = manager + result.set_close(manager.close) return result diff --git a/xarray/backends/store.py b/xarray/backends/store.py index d314a9c3ca9..20fa13af202 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -19,7 +19,6 @@ def open_backend_dataset_store( decode_timedelta=None, ): vars, attrs = store.load() - file_obj = store encoding = store.get_encoding() vars, attrs, coord_names = conventions.decode_cf_variables( @@ -36,7 +35,7 @@ def open_backend_dataset_store( ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.intersection(vars)) - ds._file_obj = file_obj + ds.set_close(store.close) ds.encoding = encoding return ds diff --git a/xarray/conventions.py b/xarray/conventions.py index bb0b92c77a1..e33ae53b31d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -576,12 +576,12 @@ def decode_cf( vars = obj._variables attrs = obj.attrs 
extra_coords = set(obj.coords) - file_obj = obj._file_obj + close = obj._close encoding = obj.encoding elif isinstance(obj, AbstractDataStore): vars, attrs = obj.load() extra_coords = set() - file_obj = obj + close = obj.close encoding = obj.get_encoding() else: raise TypeError("can only decode Dataset or DataStore objects") @@ -599,7 +599,7 @@ def decode_cf( ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) - ds._file_obj = file_obj + ds.set_close(close) ds.encoding = encoding return ds diff --git a/xarray/core/common.py b/xarray/core/common.py index 283114770cf..a69ba03a7a4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -11,6 +11,7 @@ Iterator, List, Mapping, + Optional, Tuple, TypeVar, Union, @@ -330,7 +331,9 @@ def get_squeeze_dims( class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" - __slots__ = () + _close: Optional[Callable[[], None]] + + __slots__ = ("_close",) _rolling_exp_cls = RollingExp @@ -1263,11 +1266,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) + def set_close(self, close: Optional[Callable[[], None]]) -> None: + """Register the function that releases any resources linked to this object. + + This method controls how xarray cleans up resources associated + with this object when the ``.close()`` method is called. It is mostly + intended for backend developers and it is rarely needed by regular + end-users. + + Parameters + ---------- + close : callable + The function that when called like ``close()`` releases + any resources linked to this object. + """ + self._close = close + def close(self: Any) -> None: - """Close any files linked to this object""" - if self._file_obj is not None: - self._file_obj.close() - self._file_obj = None + """Release any resources linked to this object.""" + if self._close is not None: + self._close() + self._close = None def isnull(self, keep_attrs: bool = None): """Test each value in the array for whether it is a missing value. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6fdda8fc418..e13ea44baad 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords): _cache: Dict[str, Any] _coords: Dict[Any, Variable] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _name: Optional[Hashable] _variable: Variable @@ -351,7 +352,7 @@ class DataArray(AbstractArray, DataWithCoords): __slots__ = ( "_cache", "_coords", - "_file_obj", + "_close", "_indexes", "_name", "_variable", @@ -421,7 +422,7 @@ def __init__( # public interface. 
self._indexes = indexes - self._file_obj = None + self._close = None def _replace( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7edc2fab067..136edffb202 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _coord_names: Set[Hashable] _dims: Dict[Hashable, int] _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _variables: Dict[Hashable, Variable] @@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): "_coord_names", "_dims", "_encoding", - "_file_obj", + "_close", "_indexes", "_variables", "__weakref__", @@ -687,7 +688,7 @@ def __init__( ) self._attrs = dict(attrs) if attrs is not None else None - self._file_obj = None + self._close = None self._encoding = None self._variables = variables self._coord_names = coord_names @@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset": if decoder: variables, attributes = decoder(variables, attributes) obj = cls(variables, attrs=attributes) - obj._file_obj = store + obj.set_close(store.close) return obj @property @@ -876,7 +877,7 @@ def __dask_postcompute__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postcompute, args @@ -896,7 +897,7 @@ def __dask_postpersist__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postpersist, args @@ -1007,7 +1008,7 @@ def _construct_direct( attrs=None, indexes=None, encoding=None, - file_obj=None, + close=None, ): """Shortcut around __init__ for internal use when we want to skip costly validation @@ -1020,7 +1021,7 @@ def _construct_direct( obj._dims = dims obj._indexes = indexes obj._attrs = attrs - obj._file_obj = file_obj + obj._close = close obj._encoding = encoding return obj @@ -2122,7 +2123,7 @@ def isel( attrs=self._attrs, indexes=indexes, encoding=self._encoding, - file_obj=self._file_obj, + close=self._close, ) def _isel_fancy( From ba42c08af9afbd9e79d47bda404bf4a92a7314a0 Mon Sep 17 00:00:00 2001 From: alexamici Date: Mon, 18 Jan 2021 16:21:15 +0100 Subject: [PATCH 02/14] Fix RST. --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 09bd56dbe94..e873a76cab0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,8 +39,8 @@ Breaking changes always be set such that ``int64`` values can be used. In the past, no units finer than "seconds" were chosen, which would sometimes mean that ``float64`` values were required, which would lead to inaccurate I/O round-trips. -- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull: `4725`). - By `Aureliana Barghini `_ +- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`). + By `Aureliana Barghini `_. 
New Features From 295606707a0464cd13727794a979f5b709cd92a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mesejo-Le=C3=B3n?= Date: Tue, 19 Jan 2021 00:59:08 +0100 Subject: [PATCH 03/14] Add drop_isel (#4819) * Closes #4658 - Use get_index(dim) in drop_sel - Add drop_isel * address issues in PR * extract dict creation out of the loop --- doc/api.rst | 2 + doc/whats-new.rst | 1 + xarray/core/dataarray.py | 22 +++++++++++ xarray/core/dataset.py | 67 +++++++++++++++++++++++++++++++++- xarray/tests/test_dataarray.py | 6 +++ xarray/tests/test_dataset.py | 36 +++++++++++++++++- 6 files changed, 131 insertions(+), 3 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index ceab7dcc976..9cb02441d37 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -126,6 +126,7 @@ Indexing Dataset.isel Dataset.sel Dataset.drop_sel + Dataset.drop_isel Dataset.head Dataset.tail Dataset.thin @@ -307,6 +308,7 @@ Indexing DataArray.isel DataArray.sel DataArray.drop_sel + DataArray.drop_isel DataArray.head DataArray.tail DataArray.thin diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e873a76cab0..16b0cbf4ea1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -80,6 +80,7 @@ Bug fixes - Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`). By `Julien Seguinot `_. +- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` (:issue:`4658`, :pull:`4819`). By `Daniel Mesejo `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e13ea44baad..f062b70aac1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2248,6 +2248,28 @@ def drop_sel( ds = self._to_temp_dataset().drop_sel(labels, errors=errors) return self._from_temp_dataset(ds) + def drop_isel(self, indexers=None, **indexers_kwargs): + """Drop index positions from this DataArray. + + Parameters + ---------- + indexers : mapping of hashable to Any + Index locations to drop + **indexers_kwargs : {dim: position, ...}, optional + The keyword arguments form of ``dim`` and ``positions`` + + Returns + ------- + dropped : DataArray + + Raises + ------ + IndexError + """ + dataset = self._to_temp_dataset() + dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) + return self._from_temp_dataset(dataset) + def dropna( self, dim: Hashable, how: str = "any", thresh: int = None ) -> "DataArray": diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 136edffb202..8954ebfcc38 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4054,13 +4054,78 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): labels_for_dim = [labels_for_dim] labels_for_dim = np.asarray(labels_for_dim) try: - index = self.indexes[dim] + index = self.get_index(dim) except KeyError: raise ValueError("dimension %r does not have coordinate labels" % dim) new_index = index.drop(labels_for_dim, errors=errors) ds = ds.loc[{dim: new_index}] return ds + def drop_isel(self, indexers=None, **indexers_kwargs): + """Drop index positions from this Dataset. 
+ + Parameters + ---------- + indexers : mapping of hashable to Any + Index locations to drop + **indexers_kwargs : {dim: position, ...}, optional + The keyword arguments form of ``dim`` and ``positions`` + + Returns + ------- + dropped : Dataset + + Raises + ------ + IndexError + + Examples + -------- + >>> data = np.arange(6).reshape(2, 3) + >>> labels = ["a", "b", "c"] + >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_isel(y=[0, 2]) + + Dimensions: (x: 2, y: 1) + Coordinates: + * y (y) >> ds.drop_isel(y=1) + + Dimensions: (x: 2, y: 2) + Coordinates: + * y (y) "Dataset": diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3ead427e22e..afb234029dc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2327,6 +2327,12 @@ def test_drop_index_labels(self): with pytest.warns(DeprecationWarning): arr.drop([0, 1, 3], dim="y", errors="ignore") + def test_drop_index_positions(self): + arr = DataArray(np.random.randn(2, 3), dims=["x", "y"]) + actual = arr.drop_sel(y=[0, 1]) + expected = arr[:, 2:] + assert_identical(actual, expected) + def test_dropna(self): x = np.random.randn(4, 4) x[::2, 0] = np.nan diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bd1938455b1..f71b8ec7741 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2371,8 +2371,12 @@ def test_drop_index_labels(self): data.drop(DataArray(["a", "b", "c"]), dim="x", errors="ignore") assert_identical(expected, actual) - with raises_regex(ValueError, "does not have coordinate labels"): - data.drop_sel(y=1) + actual = data.drop_sel(y=[1]) + expected = data.isel(y=[0, 2]) + assert_identical(expected, actual) + + with raises_regex(KeyError, "not found in axis"): + data.drop_sel(x=0) def test_drop_labels_by_keyword(self): data = Dataset( @@ -2410,6 +2414,34 @@ def test_drop_labels_by_keyword(self): with pytest.raises(ValueError): data.drop(dim="x", x="a") + def test_drop_labels_by_position(self): + data = Dataset( + {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)} + ) + # Basic functionality. 
+ assert len(data.coords["x"]) == 2 + + actual = data.drop_isel(x=0) + expected = data.drop_sel(x="a") + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0]) + expected = data.drop_sel(x=["a"]) + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0, 1]) + expected = data.drop_sel(x=["a", "b"]) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + actual = data.drop_isel(x=[0, 1], y=range(0, 6, 2)) + expected = data.drop_sel(x=["a", "b"], y=range(0, 6, 2)) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + with pytest.raises(KeyError): + data.drop_isel(z=1) + def test_drop_dims(self): data = xr.Dataset( { From 7dbbdcafed7f796ab77039ff797bcd31d9185903 Mon Sep 17 00:00:00 2001 From: aurghs <35919497+aurghs@users.noreply.github.com> Date: Tue, 19 Jan 2021 11:10:25 +0100 Subject: [PATCH 04/14] Bugfix in list_engine (#4811) * fix list_engine * fix store engine and netcdf4 * reve * revert changes in guess_engine * add resister of backend if dependencies aere instralled * style mypy * fix import * use import instead of importlib * black * replace ImportError with ModuleNotFoundError * fix typo * fix typos * remove else * Revert remove imports inside backends functions * Revert remove imports inside cfgrib * modify check on imports inside the backends * remove not used import --- xarray/backends/cfgrib_.py | 20 ++++++++++++++++++-- xarray/backends/common.py | 4 ++++ xarray/backends/h5netcdf_.py | 20 ++++++++++++++++---- xarray/backends/netCDF4_.py | 15 +++++++++++++-- xarray/backends/plugins.py | 24 +----------------------- xarray/backends/pseudonetcdf_.py | 20 ++++++++++++++++++-- xarray/backends/pydap_.py | 20 ++++++++++++++++++-- xarray/backends/pynio_.py | 20 ++++++++++++++++++-- xarray/backends/scipy_.py | 20 +++++++++++++++++--- xarray/backends/store.py | 5 ++++- xarray/backends/zarr.py | 15 +++++++++++++-- xarray/tests/test_plugins.py | 2 +- 12 files changed, 141 insertions(+), 44 deletions(-) diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index d4933e370c7..4a0ac7d67f9 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -5,10 +5,23 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .locks import SerializableLock, ensure_lock from .store import open_backend_dataset_store +try: + import cfgrib + + has_cfgrib = True +except ModuleNotFoundError: + has_cfgrib = False + + # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe # in most circumstances. 
See: # https://confluence.ecmwf.int/display/ECC/Frequently+Asked+Questions @@ -38,7 +51,6 @@ class CfGribDataStore(AbstractDataStore): """ def __init__(self, filename, lock=None, **backend_kwargs): - import cfgrib if lock is None: lock = ECCODES_LOCK @@ -129,3 +141,7 @@ def open_backend_dataset_cfgrib( cfgrib_backend = BackendEntrypoint( open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib ) + + +if has_cfgrib: + BACKEND_ENTRYPOINTS["cfgrib"] = cfgrib_backend diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 72a63957662..adb70658fab 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,6 +1,7 @@ import logging import time import traceback +from typing import Dict import numpy as np @@ -349,3 +350,6 @@ def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=No self.open_dataset = open_dataset self.open_dataset_parameters = open_dataset_parameters self.guess_can_open = guess_can_open + + +BACKEND_ENTRYPOINTS: Dict[str, BackendEntrypoint] = {} diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index b2996369ee7..562600de4b6 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,7 +8,12 @@ from ..core import indexing from ..core.utils import FrozenDict, is_remote_uri, read_magic_number from ..core.variable import Variable -from .common import BackendEntrypoint, WritableCFDataStore, find_root_and_group +from .common import ( + BACKEND_ENTRYPOINTS, + BackendEntrypoint, + WritableCFDataStore, + find_root_and_group, +) from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( @@ -20,6 +25,13 @@ ) from .store import open_backend_dataset_store +try: + import h5netcdf + + has_h5netcdf = True +except ModuleNotFoundError: + has_h5netcdf = False + class H5NetCDFArrayWrapper(BaseNetCDF4Array): def get_array(self, needs_lock=True): @@ -85,8 +97,6 @@ class H5NetCDFStore(WritableCFDataStore): def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): - import h5netcdf - if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): if group is None: root, group = find_root_and_group(manager) @@ -122,7 +132,6 @@ def open( invalid_netcdf=None, phony_dims=None, ): - import h5netcdf if isinstance(filename, bytes): raise ValueError( @@ -375,3 +384,6 @@ def open_backend_dataset_h5netcdf( h5netcdf_backend = BackendEntrypoint( open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf ) + +if has_h5netcdf: + BACKEND_ENTRYPOINTS["h5netcdf"] = h5netcdf_backend diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0e35270ea9a..5bb4eec837b 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -12,6 +12,7 @@ from ..core.utils import FrozenDict, close_on_error, is_remote_uri from ..core.variable import Variable from .common import ( + BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, @@ -23,6 +24,14 @@ from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable from .store import open_backend_dataset_store +try: + import netCDF4 + + has_netcdf4 = True +except ModuleNotFoundError: + has_netcdf4 = False + + # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. 
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} @@ -298,7 +307,6 @@ class NetCDF4DataStore(WritableCFDataStore): def __init__( self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False ): - import netCDF4 if isinstance(manager, netCDF4.Dataset): if group is None: @@ -335,7 +343,6 @@ def open( lock_maker=None, autoclose=False, ): - import netCDF4 if isinstance(filename, pathlib.Path): filename = os.fspath(filename) @@ -563,3 +570,7 @@ def open_backend_dataset_netcdf4( netcdf4_backend = BackendEntrypoint( open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4 ) + + +if has_netcdf4: + BACKEND_ENTRYPOINTS["netcdf4"] = netcdf4_backend diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index d5799a78f91..6d3ec7e7da5 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -2,33 +2,11 @@ import inspect import itertools import logging -import typing as T import warnings import pkg_resources -from .cfgrib_ import cfgrib_backend -from .common import BackendEntrypoint -from .h5netcdf_ import h5netcdf_backend -from .netCDF4_ import netcdf4_backend -from .pseudonetcdf_ import pseudonetcdf_backend -from .pydap_ import pydap_backend -from .pynio_ import pynio_backend -from .scipy_ import scipy_backend -from .store import store_backend -from .zarr import zarr_backend - -BACKEND_ENTRYPOINTS: T.Dict[str, BackendEntrypoint] = { - "store": store_backend, - "netcdf4": netcdf4_backend, - "h5netcdf": h5netcdf_backend, - "scipy": scipy_backend, - "pseudonetcdf": pseudonetcdf_backend, - "zarr": zarr_backend, - "cfgrib": cfgrib_backend, - "pydap": pydap_backend, - "pynio": pynio_backend, -} +from .common import BACKEND_ENTRYPOINTS def remove_duplicates(backend_entrypoints): diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index d9128d1d503..c2bfd519bed 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -3,11 +3,24 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock from .store import open_backend_dataset_store +try: + from PseudoNetCDF import pncopen + + has_pseudonetcdf = True +except ModuleNotFoundError: + has_pseudonetcdf = False + + # psuedonetcdf can invoke netCDF libraries internally PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) @@ -40,7 +53,6 @@ class PseudoNetCDFDataStore(AbstractDataStore): @classmethod def open(cls, filename, lock=None, mode=None, **format_kwargs): - from PseudoNetCDF import pncopen keywords = {"kwargs": format_kwargs} # only include mode if explicitly passed @@ -138,3 +150,7 @@ def open_backend_dataset_pseudonetcdf( open_dataset=open_backend_dataset_pseudonetcdf, open_dataset_parameters=open_dataset_parameters, ) + + +if has_pseudonetcdf: + BACKEND_ENTRYPOINTS["pseudonetcdf"] = pseudonetcdf_backend diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 4995045a739..c5ce943a10a 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -4,9 +4,22 @@ from ..core.pycompat import integer_types from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri from ..core.variable import 
Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, + robust_getitem, +) from .store import open_backend_dataset_store +try: + import pydap.client + + has_pydap = True +except ModuleNotFoundError: + has_pydap = False + class PydapArrayWrapper(BackendArray): def __init__(self, array): @@ -74,7 +87,6 @@ def __init__(self, ds): @classmethod def open(cls, url, session=None): - import pydap.client ds = pydap.client.open_url(url, session=session) return cls(ds) @@ -133,3 +145,7 @@ def open_backend_dataset_pydap( pydap_backend = BackendEntrypoint( open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap ) + + +if has_pydap: + BACKEND_ENTRYPOINTS["pydap"] = pydap_backend diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index dc6c47935e8..261daa69880 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -3,11 +3,24 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock from .store import open_backend_dataset_store +try: + import Nio + + has_pynio = True +except ModuleNotFoundError: + has_pynio = False + + # PyNIO can invoke netCDF libraries internally # Add a dedicated lock just in case NCL as well isn't thread-safe. NCL_LOCK = SerializableLock() @@ -45,7 +58,6 @@ class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO""" def __init__(self, filename, mode="r", lock=None, **kwargs): - import Nio if lock is None: lock = PYNIO_LOCK @@ -119,3 +131,7 @@ def open_backend_dataset_pynio( pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio) + + +if has_pynio: + BACKEND_ENTRYPOINTS["pynio"] = pynio_backend diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 873a91f9c07..df51d07d686 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -6,12 +6,24 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number from ..core.variable import Variable -from .common import BackendArray, BackendEntrypoint, WritableCFDataStore +from .common import ( + BACKEND_ENTRYPOINTS, + BackendArray, + BackendEntrypoint, + WritableCFDataStore, +) from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name from .store import open_backend_dataset_store +try: + import scipy.io + + has_scipy = True +except ModuleNotFoundError: + has_scipy = False + def _decode_string(s): if isinstance(s, bytes): @@ -61,8 +73,6 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): import gzip - import scipy.io - # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -271,3 +281,7 @@ def open_backend_dataset_scipy( scipy_backend = BackendEntrypoint( open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy ) + + +if has_scipy: + 
BACKEND_ENTRYPOINTS["scipy"] = scipy_backend diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 20fa13af202..66fca0d39c3 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -1,6 +1,6 @@ from .. import conventions from ..core.dataset import Dataset -from .common import AbstractDataStore, BackendEntrypoint +from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint def guess_can_open_store(store_spec): @@ -44,3 +44,6 @@ def open_backend_dataset_store( store_backend = BackendEntrypoint( open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store ) + + +BACKEND_ENTRYPOINTS["store"] = store_backend diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 3b4b3a3d9d5..ceeb23cac9b 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -9,6 +9,7 @@ from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error from ..core.variable import Variable from .common import ( + BACKEND_ENTRYPOINTS, AbstractWritableDataStore, BackendArray, BackendEntrypoint, @@ -16,6 +17,14 @@ ) from .store import open_backend_dataset_store +try: + import zarr + + has_zarr = True +except ModuleNotFoundError: + has_zarr = False + + # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -289,7 +298,6 @@ def open_group( append_dim=None, write_region=None, ): - import zarr # zarr doesn't support pathlib.Path objects yet. zarr-python#601 if isinstance(store, pathlib.Path): @@ -409,7 +417,6 @@ def store( dimension on which the zarray will be appended only needed in append mode """ - import zarr existing_variables = { vn for vn in variables if _encode_variable_name(vn) in self.ds @@ -705,3 +712,7 @@ def open_backend_dataset_zarr( zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr) + + +if has_zarr: + BACKEND_ENTRYPOINTS["zarr"] = zarr_backend diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 110ef47209f..38ebce6da1a 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -92,7 +92,7 @@ def test_set_missing_parameters_raise_error(): with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = plugins.BackendEntrypoint( + backend = common.BackendEntrypoint( dummy_open_dataset_kwargs, ("filename_or_obj", "decoder") ) plugins.set_missing_parameters({"engine": backend}) From 93ea177bdd49e205047a1416c2342fb645afafa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mesejo-Le=C3=B3n?= Date: Wed, 20 Jan 2021 05:12:06 +0100 Subject: [PATCH 05/14] fix issues in drop_sel and drop_isel (#4828) --- xarray/core/dataset.py | 18 +++++++++++++----- xarray/tests/test_dataarray.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8954ebfcc38..874e26ff465 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4021,9 +4021,17 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): Examples -------- - >>> data = np.random.randn(2, 3) + >>> data = np.arange(6).reshape(2, 3) >>> labels = ["a", "b", "c"] >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_sel(y=["a", "c"]) Dimensions: (x: 2, y: 1) @@ -4031,7 +4039,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): * y (y) >> ds.drop_sel(y="b") Dimensions: (x: 2, y: 2) @@ -4039,12 +4047,12 @@ def drop_sel(self, labels=None, *, errors="raise", 
**labels_kwargs): * y (y) Date: Wed, 20 Jan 2021 23:24:04 -0800 Subject: [PATCH 06/14] Move skip ci instructions to contributing guide (#4829) --- .github/PULL_REQUEST_TEMPLATE.md | 8 -------- doc/contributing.rst | 1 + 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 09ef053bb39..c7ea19a53cb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,11 +5,3 @@ - [ ] Passes `pre-commit run --all-files` - [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst` - [ ] New functions/methods are listed in `api.rst` - - - -

- Overriding CI behaviors
- By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a [skip-ci] tag to the first line of the commit message -
diff --git a/doc/contributing.rst b/doc/contributing.rst index 9c4ce5a0af2..439791cbbd6 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -836,6 +836,7 @@ PR checklist - Write new tests if needed. See `"Test-driven development/code writing" `_. - Test the code using `Pytest `_. Running all tests (type ``pytest`` in the root directory) takes a while, so feel free to only run the tests you think are needed based on your PR (example: ``pytest xarray/tests/test_dataarray.py``). CI will catch any failing tests. + - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a "[skip-ci]" tag to the first line of the commit message. - **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit. From d555172c7d069ca9cf7a9a32bfd5f422be133861 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 24 Jan 2021 15:46:48 -0800 Subject: [PATCH 07/14] Allow swap_dims to take kwargs (#4841) --- doc/whats-new.rst | 5 ++++- xarray/core/dataarray.py | 9 ++++++++- xarray/core/dataset.py | 10 +++++++++- xarray/tests/test_dataarray.py | 10 ++++++++++ xarray/tests/test_dataset.py | 7 +++++++ 5 files changed, 38 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 16b0cbf4ea1..0f2bf423449 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,7 +46,10 @@ Breaking changes New Features ~~~~~~~~~~~~ - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. - By `Deepak Cherian `_ + By `Deepak Cherian `_. +- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims + in the form of kwargs as well as a dict, like most similar methods. + By `Maximilian Roos `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f062b70aac1..2fef3edbc43 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1699,7 +1699,9 @@ def rename( new_name_or_name_dict = cast(Hashable, new_name_or_name_dict) return self._replace(name=new_name_or_name_dict) - def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": + def swap_dims( + self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + ) -> "DataArray": """Returns a new DataArray with swapped dimensions. Parameters @@ -1708,6 +1710,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": Dictionary whose keys are current dimension names and whose values are new names. + **dim_kwargs : {dim: , ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. 
+ Returns ------- swapped : DataArray @@ -1749,6 +1755,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": DataArray.rename Dataset.swap_dims """ + dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims") ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 874e26ff465..f8718377104 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3155,7 +3155,9 @@ def rename_vars( ) return self._replace(variables, coord_names, dims=dims, indexes=indexes) - def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": + def swap_dims( + self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + ) -> "Dataset": """Returns a new object with swapped dimensions. Parameters @@ -3164,6 +3166,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": Dictionary whose keys are current dimension names and whose values are new names. + **dim_kwargs : {existing_dim: new_dim, ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. + Returns ------- swapped : Dataset @@ -3214,6 +3220,8 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": """ # TODO: deprecate this method in favor of a (less confusing) # rename_dims() method that only renames dimensions. + + dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims") for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index adf282ff34c..fc84687511e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1639,6 +1639,16 @@ def test_swap_dims(self): expected.indexes[dim_name], actual.indexes[dim_name] ) + # as kwargs + array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") + expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") + actual = array.swap_dims(x="y") + assert_identical(expected, actual) + for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + pd.testing.assert_index_equal( + expected.indexes[dim_name], actual.indexes[dim_name] + ) + # multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f71b8ec7741..fed9098701b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2748,6 +2748,13 @@ def test_swap_dims(self): actual = original.swap_dims({"x": "u"}) assert_identical(expected, actual) + # as kwargs + expected = Dataset( + {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])} + ) + actual = original.swap_dims(x="u") + assert_identical(expected, actual) + # handle multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42}) From a0c71c1508f34345ad7eef244cdbbe224e031c1b Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sun, 24 Jan 2021 15:48:04 -0800 Subject: [PATCH 08/14] Faster unstacking (#4746) * Significantly improve unstacking performance * Hack to get sparse tests passing * Use the existing unstack function for dask & sparse * Add whatsnew * Require numpy 1.17 for new unstack * Also special case pint * Revert "Also special case pint" 
This reverts commit b33adedbfbd92df0f4188568691c7e2915bf8c19. * Only run fast unstack on numpy arrays * Update asvs for unstacking * Update whatsnew --- asv_bench/benchmarks/unstacking.py | 15 ++++-- doc/whats-new.rst | 7 ++- xarray/core/dataset.py | 75 ++++++++++++++++++++++++++++-- xarray/core/variable.py | 68 +++++++++++++++++++++++++-- 4 files changed, 153 insertions(+), 12 deletions(-) diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py index 342475b96df..8d0c3932870 100644 --- a/asv_bench/benchmarks/unstacking.py +++ b/asv_bench/benchmarks/unstacking.py @@ -7,18 +7,23 @@ class Unstacking: def setup(self): - data = np.random.RandomState(0).randn(1, 1000, 500) - self.ds = xr.DataArray(data).stack(flat_dim=["dim_1", "dim_2"]) + data = np.random.RandomState(0).randn(500, 1000) + self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...]) + self.da_missing = self.da_full[:-1] + self.df_missing = self.da_missing.to_pandas() def time_unstack_fast(self): - self.ds.unstack("flat_dim") + self.da_full.unstack("flat_dim") def time_unstack_slow(self): - self.ds[:, ::-1].unstack("flat_dim") + self.da_missing.unstack("flat_dim") + + def time_unstack_pandas_slow(self): + self.df_missing.unstack() class UnstackingDask(Unstacking): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds = self.ds.chunk({"flat_dim": 50}) + self.da_full = self.da_full.chunk({"flat_dim": 50}) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0f2bf423449..488d8baa650 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,7 +17,7 @@ What's New .. _whats-new.0.16.3: -v0.16.3 (unreleased) +v0.17.0 (unreleased) -------------------- Breaking changes @@ -45,6 +45,11 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Significantly higher ``unstack`` performance on numpy-backed arrays which + contain missing values; 8x faster in our benchmark, and 2x faster than pandas. + (:pull:`4746`); + By `Maximilian Roos `_. + - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. By `Deepak Cherian `_. 
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f8718377104..a73e299e27a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4,6 +4,7 @@ import sys import warnings from collections import defaultdict +from distutils.version import LooseVersion from html import escape from numbers import Number from operator import methodcaller @@ -79,7 +80,7 @@ ) from .missing import get_clean_interp_index from .options import OPTIONS, _get_keep_attrs -from .pycompat import is_duck_dask_array +from .pycompat import is_duck_dask_array, sparse_array_type from .utils import ( Default, Frozen, @@ -3715,7 +3716,40 @@ def ensure_stackable(val): return data_array - def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset": + def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": + index = self.get_index(dim) + index = remove_unused_levels_categories(index) + + variables: Dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self.indexes.items() if k != dim} + + for name, var in self.variables.items(): + if name != dim: + if dim in var.dims: + if isinstance(fill_value, Mapping): + fill_value_ = fill_value[name] + else: + fill_value_ = fill_value + + variables[name] = var._unstack_once( + index=index, dim=dim, fill_value=fill_value_ + ) + else: + variables[name] = var + + for name, lev in zip(index.names, index.levels): + variables[name] = IndexVariable(name, lev) + indexes[name] = lev + + coord_names = set(self._coord_names) - {dim} | set(index.names) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def _unstack_full_reindex( + self, dim: Hashable, fill_value, sparse: bool + ) -> "Dataset": index = self.get_index(dim) index = remove_unused_levels_categories(index) full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) @@ -3812,7 +3846,38 @@ def unstack( result = self.copy(deep=False) for dim in dims: - result = result._unstack_once(dim, fill_value, sparse) + + if ( + # Dask arrays don't support assignment by index, which the fast unstack + # function requires. + # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125 + any(is_duck_dask_array(v.data) for v in self.variables.values()) + # Sparse doesn't currently support (though we could special-case + # it) + # https://github.com/pydata/sparse/issues/422 + or any( + isinstance(v.data, sparse_array_type) + for v in self.variables.values() + ) + or sparse + # numpy full_like only added `shape` in 1.17 + or LooseVersion(np.__version__) < LooseVersion("1.17") + # Until https://github.com/pydata/xarray/pull/4751 is resolved, + # we check explicitly whether it's a numpy array. Once that is + # resolved, explicitly exclude pint arrays. + # # pint doesn't implement `np.full_like` in a way that's + # # currently compatible. 
+ # # https://github.com/pydata/xarray/pull/4746#issuecomment-753425173 + # # or any( + # # isinstance(v.data, pint_array_type) for v in self.variables.values() + # # ) + or any( + not isinstance(v.data, np.ndarray) for v in self.variables.values() + ) + ): + result = result._unstack_full_reindex(dim, fill_value, sparse) + else: + result = result._unstack_once(dim, fill_value) return result def update(self, other: "CoercibleMapping") -> "Dataset": @@ -4982,6 +5047,10 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, values) return + # NB: similar, more general logic, now exists in + # variable.unstack_once; we could consider combining them at some + # point. + shape = tuple(lev.size for lev in idx.levels) indexer = tuple(idx.codes) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 797de65bbcf..64c1895da59 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -10,6 +10,7 @@ Any, Dict, Hashable, + List, Mapping, Optional, Sequence, @@ -1488,7 +1489,7 @@ def set_dims(self, dims, shape=None): ) return expanded_var.transpose(*dims) - def _stack_once(self, dims, new_dim): + def _stack_once(self, dims: List[Hashable], new_dim: Hashable): if not set(dims) <= set(self.dims): raise ValueError("invalid existing dimensions: %s" % dims) @@ -1544,7 +1545,15 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once(self, dims, old_dim): + def _unstack_once_full( + self, dims: Mapping[Hashable, int], old_dim: Hashable + ) -> "Variable": + """ + Unstacks the variable without needing an index. + + Unlike `_unstack_once`, this function requires the existing dimension to + contain the full product of the new dimensions. + """ new_dim_names = tuple(dims.keys()) new_dim_sizes = tuple(dims.values()) @@ -1573,6 +1582,53 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + def _unstack_once( + self, + index: pd.MultiIndex, + dim: Hashable, + fill_value=dtypes.NA, + ) -> "Variable": + """ + Unstacks this variable given an index to unstack and the name of the + dimension to which the index refers. + """ + + reordered = self.transpose(..., dim) + + new_dim_sizes = [lev.size for lev in index.levels] + new_dim_names = index.names + indexer = index.codes + + # Potentially we could replace `len(other_dims)` with just `-1` + other_dims = [d for d in self.dims if d != dim] + new_shape = list(reordered.shape[: len(other_dims)]) + new_dim_sizes + new_dims = reordered.dims[: len(other_dims)] + new_dim_names + + if fill_value is dtypes.NA: + is_missing_values = np.prod(new_shape) > np.prod(self.shape) + if is_missing_values: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype + fill_value = dtypes.get_fill_value(dtype) + else: + dtype = self.dtype + + # Currently fails on sparse due to https://github.com/pydata/sparse/issues/422 + data = np.full_like( + self.data, + fill_value=fill_value, + shape=new_shape, + dtype=dtype, + ) + + # Indexer is a list of lists of locations. Each list is the locations + # on the new dimension. This is robust to the data being sparse; in that + # case the destinations will be NaN / zero. + data[(..., *indexer)] = reordered + + return self._replace(dims=new_dims, data=data) + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. 
@@ -1580,6 +1636,10 @@ def unstack(self, dimensions=None, **dimensions_kwargs): New dimensions will be added at the end, and the order of the data along each new dimension will be in contiguous (C) order. + Note that unlike ``DataArray.unstack`` and ``Dataset.unstack``, this + method requires the existing dimension to contain the full product of + the new dimensions. + Parameters ---------- dimensions : mapping of hashable to mapping of hashable to int @@ -1598,11 +1658,13 @@ def unstack(self, dimensions=None, **dimensions_kwargs): See also -------- Variable.stack + DataArray.unstack + Dataset.unstack """ dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "unstack") result = self for old_dim, dims in dimensions.items(): - result = result._unstack_once(dims, old_dim) + result = result._unstack_once_full(dims, old_dim) return result def fillna(self, value): From d524d72c6cc97a87787117dd39c642254754bac4 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 26 Jan 2021 18:30:19 +0100 Subject: [PATCH 09/14] iris update doc url (#4845) --- doc/conf.py | 2 +- doc/faq.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index d83e966f3fa..14b28b4e471 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -411,7 +411,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), - "iris": ("https://scitools.org.uk/iris/docs/latest", None), + "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "numba": ("https://numba.pydata.org/numba-doc/latest", None), diff --git a/doc/faq.rst b/doc/faq.rst index a2b8be47e06..a2151cc4b37 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -166,7 +166,7 @@ different approaches to handling metadata: Iris strictly interprets `CF conventions`_. Iris particularly shines at mapping, thanks to its integration with Cartopy_. -.. _Iris: http://scitools.org.uk/iris/ +.. _Iris: https://scitools-iris.readthedocs.io/en/stable/ .. _Cartopy: http://scitools.org.uk/cartopy/docs/latest/ `UV-CDAT`__ is another Python library that implements in-memory netCDF-like From a4bb7e1dc80c7e413dd9b459671d10a666b395e7 Mon Sep 17 00:00:00 2001 From: Michael Mann Date: Tue, 26 Jan 2021 12:53:46 -0500 Subject: [PATCH 10/14] Update related-projects.rst (#4844) adding mention of geowombat for remote sensing applications Co-authored-by: Keewis --- doc/related-projects.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 456cb64197f..0a010195d6d 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -15,6 +15,7 @@ Geosciences - `aospy `_: Automated analysis and management of gridded climate data. - `climpred `_: Analysis of ensemble forecast models for climate prediction. - `geocube `_: Tool to convert geopandas vector data into rasterized xarray data. +- `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope). - `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data - `marc_analysis `_: Analysis package for CESM/MARC experiments and output. - `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data. 
From 9fea799761ae178e586c59d1a67f480abecf2637 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 27 Jan 2021 09:05:30 +0100 Subject: [PATCH 11/14] weighted: small improvements (#4818) * weighted: small improvements * use T_DataWithCoords --- xarray/core/common.py | 11 ++++++++- xarray/core/weighted.py | 49 +++++++++++++++-------------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a69ba03a7a4..c5836c68759 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -3,6 +3,7 @@ from html import escape from textwrap import dedent from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -32,6 +33,12 @@ ALL_DIMS = ... +if TYPE_CHECKING: + from .dataarray import DataArray + from .weighted import Weighted + +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + C = TypeVar("C") T = TypeVar("T") @@ -772,7 +779,9 @@ def groupby_bins( }, ) - def weighted(self, weights): + def weighted( + self: T_DataWithCoords, weights: "DataArray" + ) -> "Weighted[T_DataWithCoords]": """ Weighted operations. diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index dbd4e1ad103..449a7200ee7 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,13 +1,16 @@ -from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union from . import duck_array_ops from .computation import dot -from .options import _get_keep_attrs from .pycompat import is_duck_dask_array if TYPE_CHECKING: + from .common import DataWithCoords # noqa: F401 from .dataarray import DataArray, Dataset +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + + _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -56,7 +59,7 @@ """ -class Weighted: +class Weighted(Generic[T_DataWithCoords]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -70,15 +73,7 @@ class Weighted: __slots__ = ("obj", "weights") - @overload - def __init__(self, obj: "DataArray", weights: "DataArray") -> None: - ... - - @overload - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: - ... 
- - def __init__(self, obj, weights): + def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): """ Create a Weighted object @@ -121,8 +116,8 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj = obj - self.weights = weights + self.obj: T_DataWithCoords = obj + self.weights: "DataArray" = weights @staticmethod def _reduce( @@ -146,7 +141,6 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) return dot(da, weights, dims=dim) def _sum_of_weights( @@ -203,7 +197,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -214,7 +208,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -225,7 +219,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -239,22 +233,15 @@ def __repr__(self): return f"{klass} with weights along dimensions: {weight_dims}" -class DataArrayWeighted(Weighted): - def _implementation(self, func, dim, **kwargs): - - keep_attrs = kwargs.pop("keep_attrs") - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - - weighted = func(self.obj, dim=dim, **kwargs) - - if keep_attrs: - weighted.attrs = self.obj.attrs +class DataArrayWeighted(Weighted["DataArray"]): + def _implementation(self, func, dim, **kwargs) -> "DataArray": - return weighted + dataset = self.obj._to_temp_dataset() + dataset = dataset.map(func, dim=dim, **kwargs) + return self.obj._from_temp_dataset(dataset) -class DatasetWeighted(Weighted): +class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> "Dataset": return self.obj.map(func, dim=dim, **kwargs) From 8cc34cb412ba89ebca12fc84f76a9e452628f1bc Mon Sep 17 00:00:00 2001 From: Aureliana Barghini <35919497+aurghs@users.noreply.github.com> Date: Thu, 28 Jan 2021 16:20:59 +0100 Subject: [PATCH 12/14] WIP: backend interface, now it uses subclassing (#4836) * draft * working version * fix: instantiate BackendEtrypoints * rename AbstractBackendEntrypoint in BackendEntrypoint * fix plugins tests * style * style * raise NotImplemetedError if BackendEntrypoint.open_dataset is not implemented --- xarray/backends/cfgrib_.py | 106 +++++++++++++++--------------- xarray/backends/common.py | 15 +++-- xarray/backends/h5netcdf_.py | 107 +++++++++++++++---------------- xarray/backends/netCDF4_.py | 107 +++++++++++++++---------------- xarray/backends/plugins.py | 14 ++-- xarray/backends/pseudonetcdf_.py | 96 ++++++++++++++------------- xarray/backends/pydap_.py | 68 +++++++++----------- xarray/backends/pynio_.py | 64 +++++++++--------- xarray/backends/scipy_.py | 94 +++++++++++++-------------- xarray/backends/store.py | 84 ++++++++++++------------ xarray/backends/zarr.py | 81 ++++++++++++----------- xarray/tests/test_plugins.py | 65 
++++++++++--------- 12 files changed, 446 insertions(+), 455 deletions(-) diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 4a0ac7d67f9..65c5bc2a02b 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -12,7 +12,7 @@ BackendEntrypoint, ) from .locks import SerializableLock, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import cfgrib @@ -86,62 +86,58 @@ def get_encoding(self): return encoding -def guess_can_open_cfgrib(store_spec): - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".grib", ".grib2", ".grb", ".grb2"} - - -def open_backend_dataset_cfgrib( - filename_or_obj, - *, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - lock=None, - indexpath="{path}.{short_hash}.idx", - filter_by_keys={}, - read_keys=[], - encode_cf=("parameter", "time", "geography", "vertical"), - squeeze=True, - time_dims=("time", "step"), -): - - store = CfGribDataStore( +class CfgribfBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".grib", ".grib2", ".grb", ".grb2"} + + def open_dataset( + self, filename_or_obj, - indexpath=indexpath, - filter_by_keys=filter_by_keys, - read_keys=read_keys, - encode_cf=encode_cf, - squeeze=squeeze, - time_dims=time_dims, - lock=lock, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + lock=None, + indexpath="{path}.{short_hash}.idx", + filter_by_keys={}, + read_keys=[], + encode_cf=("parameter", "time", "geography", "vertical"), + squeeze=True, + time_dims=("time", "step"), + ): + + store = CfGribDataStore( + filename_or_obj, + indexpath=indexpath, + filter_by_keys=filter_by_keys, + read_keys=read_keys, + encode_cf=encode_cf, + squeeze=squeeze, + time_dims=time_dims, + lock=lock, ) - return ds - - -cfgrib_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib -) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_cfgrib: - BACKEND_ENTRYPOINTS["cfgrib"] = cfgrib_backend + BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint diff --git a/xarray/backends/common.py b/xarray/backends/common.py index adb70658fab..e2905d0866b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,7 +1,7 @@ import logging import time import traceback -from typing import Dict +from typing import Dict, Tuple, Type, Union import numpy as np @@ -344,12 +344,13 @@ def encode(self, variables, attributes): class BackendEntrypoint: - __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters") + 
open_dataset_parameters: Union[Tuple, None] = None - def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None): - self.open_dataset = open_dataset - self.open_dataset_parameters = open_dataset_parameters - self.guess_can_open = guess_can_open + def open_dataset(self): + raise NotImplementedError + def guess_can_open(self, store_spec): + return False -BACKEND_ENTRYPOINTS: Dict[str, BackendEntrypoint] = {} + +BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {} diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 562600de4b6..aa892c4f89c 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -23,7 +23,7 @@ _get_datatype, _nc4_require_group, ) -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import h5netcdf @@ -328,62 +328,61 @@ def close(self, **kwargs): self._manager.close(**kwargs) -def guess_can_open_h5netcdf(store_spec): - try: - return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") - except TypeError: - pass - - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - - return ext in {".nc", ".nc4", ".cdf"} - - -def open_backend_dataset_h5netcdf( - filename_or_obj, - *, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - format=None, - group=None, - lock=None, - invalid_netcdf=None, - phony_dims=None, -): - - store = H5NetCDFStore.open( +class H5netcdfBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") + except TypeError: + pass + + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + + return ext in {".nc", ".nc4", ".cdf"} + + def open_dataset( + self, filename_or_obj, - format=format, - group=group, - lock=lock, - invalid_netcdf=invalid_netcdf, - phony_dims=phony_dims, - ) + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + format=None, + group=None, + lock=None, + invalid_netcdf=None, + phony_dims=None, + ): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - return ds + store = H5NetCDFStore.open( + filename_or_obj, + format=format, + group=group, + lock=lock, + invalid_netcdf=invalid_netcdf, + phony_dims=phony_dims, + ) + store_entrypoint = StoreBackendEntrypoint() + + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -h5netcdf_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf -) if has_h5netcdf: - BACKEND_ENTRYPOINTS["h5netcdf"] = h5netcdf_backend + BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 5bb4eec837b..e3d87aaf83f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -22,7 +22,7 @@ from .file_manager import CachingFileManager, 
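With this refactor a backend is a subclass of BackendEntrypoint instead of an instance assembled from standalone functions: open_dataset must be overridden (the base implementation raises NotImplementedError) and guess_can_open defaults to returning False. A minimal sketch of what an in-tree backend now looks like; the engine name, file extension and open logic are invented for illustration:

import os

import xarray as xr
from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint


class MyFormatBackendEntrypoint(BackendEntrypoint):
    def guess_can_open(self, store_spec):
        # only claim paths with the made-up ".myf" extension
        try:
            _, ext = os.path.splitext(store_spec)
        except TypeError:
            return False
        return ext == ".myf"

    def open_dataset(self, filename_or_obj, *, drop_variables=None):
        # parse filename_or_obj and build an xarray.Dataset here
        return xr.Dataset()


# BACKEND_ENTRYPOINTS now maps engine names to classes, not instances;
# build_engines() in plugins.py instantiates them when collecting engines
BACKEND_ENTRYPOINTS["myformat"] = MyFormatBackendEntrypoint

Third-party backends register the same kind of class through package entry points, which build_engines() loads via pkg_resources before instantiating them.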
DummyFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import netCDF4 @@ -512,65 +512,62 @@ def close(self, **kwargs): self._manager.close(**kwargs) -def guess_can_open_netcdf4(store_spec): - if isinstance(store_spec, str) and is_remote_uri(store_spec): - return True - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf"} - - -def open_backend_dataset_netcdf4( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - format="NETCDF4", - clobber=True, - diskless=False, - persist=False, - lock=None, - autoclose=False, -): +class NetCDF4BackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + if isinstance(store_spec, str) and is_remote_uri(store_spec): + return True + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf"} - store = NetCDF4DataStore.open( + def open_dataset( + self, filename_or_obj, - mode=mode, - format=format, - group=group, - clobber=clobber, - diskless=diskless, - persist=persist, - lock=lock, - autoclose=autoclose, - ) + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + format="NETCDF4", + clobber=True, + diskless=False, + persist=False, + lock=None, + autoclose=False, + ): - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + store = NetCDF4DataStore.open( + filename_or_obj, + mode=mode, + format=format, + group=group, + clobber=clobber, + diskless=diskless, + persist=persist, + lock=lock, + autoclose=autoclose, ) - return ds - -netcdf4_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4 -) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_netcdf4: - BACKEND_ENTRYPOINTS["netcdf4"] = netcdf4_backend + BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 6d3ec7e7da5..b8cd2bf6378 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -36,6 +36,7 @@ def remove_duplicates(backend_entrypoints): def detect_parameters(open_dataset): signature = inspect.signature(open_dataset) parameters = signature.parameters + parameters_list = [] for name, param in parameters.items(): if param.kind in ( inspect.Parameter.VAR_KEYWORD, @@ -45,7 +46,9 @@ def detect_parameters(open_dataset): f"All the parameters in {open_dataset!r} signature should be explicit. 
" "*args and **kwargs is not supported" ) - return tuple(parameters) + if name != "self": + parameters_list.append(name) + return tuple(parameters_list) def create_engines_dict(backend_entrypoints): @@ -57,8 +60,8 @@ def create_engines_dict(backend_entrypoints): return engines -def set_missing_parameters(engines): - for name, backend in engines.items(): +def set_missing_parameters(backend_entrypoints): + for name, backend in backend_entrypoints.items(): if backend.open_dataset_parameters is None: open_dataset = backend.open_dataset backend.open_dataset_parameters = detect_parameters(open_dataset) @@ -70,7 +73,10 @@ def build_engines(entrypoints): external_backend_entrypoints = create_engines_dict(pkg_entrypoints) backend_entrypoints.update(external_backend_entrypoints) set_missing_parameters(backend_entrypoints) - return backend_entrypoints + engines = {} + for name, backend in backend_entrypoints.items(): + engines[name] = backend() + return engines @functools.lru_cache(maxsize=1) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index c2bfd519bed..80485fce459 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -11,7 +11,7 @@ ) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: from PseudoNetCDF import pncopen @@ -100,57 +100,55 @@ def close(self): self._manager.close() -def open_backend_dataset_pseudonetcdf( - filename_or_obj, - mask_and_scale=False, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode=None, - lock=None, - **format_kwargs, -): - - store = PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, mode=mode, **format_kwargs +class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): + + # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, + # unless the open_dataset_parameters are explicity defined like this: + open_dataset_parameters = ( + "filename_or_obj", + "mask_and_scale", + "decode_times", + "concat_characters", + "decode_coords", + "drop_variables", + "use_cftime", + "decode_timedelta", + "mode", + "lock", ) - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + def open_dataset( + self, + filename_or_obj, + mask_and_scale=False, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode=None, + lock=None, + **format_kwargs, + ): + store = PseudoNetCDFDataStore.open( + filename_or_obj, lock=lock, mode=mode, **format_kwargs ) - return ds - - -# *args and **kwargs are not allowed in open_backend_dataset_ kwargs, -# unless the open_dataset_parameters are explicity defined like this: -open_dataset_parameters = ( - "filename_or_obj", - "mask_and_scale", - "decode_times", - "concat_characters", - "decode_coords", - "drop_variables", - "use_cftime", - "decode_timedelta", - "mode", - "lock", -) -pseudonetcdf_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_pseudonetcdf, - open_dataset_parameters=open_dataset_parameters, -) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = 
store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_pseudonetcdf: - BACKEND_ENTRYPOINTS["pseudonetcdf"] = pseudonetcdf_backend + BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index c5ce943a10a..7f8622ca66e 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -11,7 +11,7 @@ BackendEntrypoint, robust_getitem, ) -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import pydap.client @@ -107,45 +107,41 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) -def guess_can_open_pydap(store_spec): - return isinstance(store_spec, str) and is_remote_uri(store_spec) +class PydapBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + return isinstance(store_spec, str) and is_remote_uri(store_spec) - -def open_backend_dataset_pydap( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - session=None, -): - - store = PydapDataStore.open( + def open_dataset( + self, filename_or_obj, - session=session, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + session=None, + ): + store = PydapDataStore.open( + filename_or_obj, + session=session, ) - return ds - -pydap_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap -) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_pydap: - BACKEND_ENTRYPOINTS["pydap"] = pydap_backend + BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 261daa69880..41c99efd076 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -11,7 +11,7 @@ ) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import Nio @@ -97,41 +97,39 @@ def close(self): self._manager.close() -def open_backend_dataset_pynio( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode="r", - lock=None, -): - - store = NioDataStore( +class PynioBackendEntrypoint(BackendEntrypoint): + def open_dataset( filename_or_obj, - mode=mode, - lock=lock, - ) - - with close_on_error(store): - ds = 
open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + lock=None, + ): + store = NioDataStore( + filename_or_obj, + mode=mode, + lock=lock, ) - return ds - -pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_pynio: - BACKEND_ENTRYPOINTS["pynio"] = pynio_backend + BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index df51d07d686..ddc157ed8e4 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -15,7 +15,7 @@ from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import scipy.io @@ -232,56 +232,54 @@ def close(self): self._manager.close() -def guess_can_open_scipy(store_spec): - try: - return read_magic_number(store_spec).startswith(b"CDF") - except TypeError: - pass +class ScipyBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + return read_magic_number(store_spec).startswith(b"CDF") + except TypeError: + pass - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf", ".gz"} - - -def open_backend_dataset_scipy( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode="r", - format=None, - group=None, - mmap=None, - lock=None, -): - - store = ScipyDataStore( - filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock - ) - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - return ds + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf", ".gz"} + + def open_dataset( + self, + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + format=None, + group=None, + mmap=None, + lock=None, + ): + store = ScipyDataStore( + filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock + ) -scipy_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy -) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( 
+ store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_scipy: - BACKEND_ENTRYPOINTS["scipy"] = scipy_backend + BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 66fca0d39c3..d57b3ab9df8 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -3,47 +3,43 @@ from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint -def guess_can_open_store(store_spec): - return isinstance(store_spec, AbstractDataStore) - - -def open_backend_dataset_store( - store, - *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, -): - vars, attrs = store.load() - encoding = store.get_encoding() - - vars, attrs, coord_names = conventions.decode_cf_variables( - vars, - attrs, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - - ds = Dataset(vars, attrs=attrs) - ds = ds.set_coords(coord_names.intersection(vars)) - ds.set_close(store.close) - ds.encoding = encoding - - return ds - - -store_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store -) - - -BACKEND_ENTRYPOINTS["store"] = store_backend +class StoreBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + return isinstance(store_spec, AbstractDataStore) + + def open_dataset( + self, + store, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + ): + vars, attrs = store.load() + encoding = store.get_encoding() + + vars, attrs, coord_names = conventions.decode_cf_variables( + vars, + attrs, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = Dataset(vars, attrs=attrs) + ds = ds.set_coords(coord_names.intersection(vars)) + ds.set_close(store.close) + ds.encoding = encoding + + return ds + + +BACKEND_ENTRYPOINTS["store"] = StoreBackendEntrypoint diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index ceeb23cac9b..1d667a38b53 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -15,7 +15,7 @@ BackendEntrypoint, _encode_variable_name, ) -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: import zarr @@ -670,49 +670,48 @@ def open_zarr( return ds -def open_backend_dataset_zarr( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - synchronizer=None, - consolidated=False, - consolidate_on_close=False, - chunk_store=None, -): - - store = ZarrStore.open_group( +class ZarrBackendEntrypoint(BackendEntrypoint): + def open_dataset( + self, filename_or_obj, - group=group, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, - consolidate_on_close=consolidate_on_close, - 
chunk_store=chunk_store, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + synchronizer=None, + consolidated=False, + consolidate_on_close=False, + chunk_store=None, + ): + store = ZarrStore.open_group( + filename_or_obj, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_store, ) - return ds - -zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds if has_zarr: - BACKEND_ENTRYPOINTS["zarr"] = zarr_backend + BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 38ebce6da1a..64a1c563dba 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -6,19 +6,24 @@ from xarray.backends import common, plugins -def dummy_open_dataset_args(filename_or_obj, *args): - pass +class DummyBackendEntrypointArgs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, *args): + pass -def dummy_open_dataset_kwargs(filename_or_obj, **kwargs): - pass +class DummyBackendEntrypointKwargs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, **kwargs): + pass -def dummy_open_dataset(filename_or_obj, *, decoder): - pass +class DummyBackendEntrypoint1(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass -dummy_cfgrib = common.BackendEntrypoint(dummy_open_dataset) +class DummyBackendEntrypoint2(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass @pytest.fixture @@ -65,46 +70,48 @@ def test_create_engines_dict(): def test_set_missing_parameters(): - backend_1 = common.BackendEntrypoint(dummy_open_dataset) - backend_2 = common.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",)) + backend_1 = DummyBackendEntrypoint1 + backend_2 = DummyBackendEntrypoint2 + backend_2.open_dataset_parameters = ("filename_or_obj",) engines = {"engine_1": backend_1, "engine_2": backend_2} plugins.set_missing_parameters(engines) assert len(engines) == 2 - engine_1 = engines["engine_1"] - assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder") - engine_2 = engines["engine_2"] - assert engine_2.open_dataset_parameters == ("filename_or_obj",) + assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder") + assert backend_2.open_dataset_parameters == ("filename_or_obj",) + + backend = DummyBackendEntrypointKwargs() + backend.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend}) + assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") + + backend = DummyBackendEntrypointArgs() + backend.open_dataset_parameters = ("filename_or_obj", 
"decoder") + plugins.set_missing_parameters({"engine": backend}) + assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") def test_set_missing_parameters_raise_error(): - backend = common.BackendEntrypoint(dummy_open_dataset_args) + backend = DummyBackendEntrypointKwargs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = common.BackendEntrypoint( - dummy_open_dataset_args, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - - backend = common.BackendEntrypoint(dummy_open_dataset_kwargs) + backend = DummyBackendEntrypointArgs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = common.BackendEntrypoint( - dummy_open_dataset_kwargs, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - -@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=dummy_cfgrib)) +@mock.patch( + "pkg_resources.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) def test_build_engines(): - dummy_cfgrib_pkg_entrypoint = pkg_resources.EntryPoint.parse( + dummy_pkg_entrypoint = pkg_resources.EntryPoint.parse( "cfgrib = xarray.tests.test_plugins:backend_1" ) - backend_entrypoints = plugins.build_engines([dummy_cfgrib_pkg_entrypoint]) - assert backend_entrypoints["cfgrib"] is dummy_cfgrib + backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1) assert backend_entrypoints["cfgrib"].open_dataset_parameters == ( "filename_or_obj", "decoder", From f4b95cd28f5f34ed5ef6cbd9280904fb5449c2a7 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Fri, 29 Jan 2021 22:59:29 +0000 Subject: [PATCH 13/14] dim -> coord in DataArray.integrate (#3993) * changed arg and docstrings * test warning is raised * updated kwarg in test * updated what's new * Fix error checking * Fix whats-new * fix docstring * small fix Co-authored-by: dcherian --- doc/whats-new.rst | 7 +++++++ xarray/core/dataarray.py | 35 ++++++++++++++++++++++++++++------- xarray/core/dataset.py | 14 ++++++++------ xarray/tests/test_dataset.py | 3 +++ xarray/tests/test_units.py | 2 +- 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 488d8baa650..471e91a8512 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,13 @@ Breaking changes - remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`). By `Aureliana Barghini `_. +Deprecations +~~~~~~~~~~~~ + +- ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in + favour of a ``coord`` argument, for consistency with :py:meth:`Dataset.integrate`. + For now using ``dim`` issues a ``FutureWarning``. By `Tom Nicholas `_. + New Features ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2fef3edbc43..0155cdc4e19 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3481,21 +3481,26 @@ def differentiate( return self._from_temp_dataset(ds) def integrate( - self, dim: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + self, + coord: Union[Hashable, Sequence[Hashable]] = None, + datetime_unit: str = None, + *, + dim: Union[Hashable, Sequence[Hashable]] = None, ) -> "DataArray": - """ integrate the array with the trapezoidal rule. + """Integrate along the given coordinate using the trapezoidal rule. .. 
note:: - This feature is limited to simple cartesian geometry, i.e. dim + This feature is limited to simple cartesian geometry, i.e. coord must be one dimensional. Parameters ---------- + coord: hashable, or a sequence of hashable + Coordinate(s) used for the integration. dim : hashable, or sequence of hashable Coordinate(s) used for the integration. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \ - "ps", "fs", "as"}, optional - Can be used to specify the unit if datetime coordinate is used. + datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional Returns ------- @@ -3503,6 +3508,7 @@ def integrate( See also -------- + Dataset.integrate numpy.trapz: corresponding numpy function Examples @@ -3528,7 +3534,22 @@ def integrate( array([5.4, 6.6, 7.8]) Dimensions without coordinates: y """ - ds = self._to_temp_dataset().integrate(dim, datetime_unit) + if dim is not None and coord is not None: + raise ValueError( + "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead." + ) + + if dim is not None and coord is None: + coord = dim + msg = ( + "The `dim` keyword argument to `DataArray.integrate` is " + "being replaced with `coord`, for consistency with " + "`Dataset.integrate`. Please pass `coord` instead." + " `dim` will be removed in version 0.19.0." + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + + ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) def unify_chunks(self) -> "DataArray": diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a73e299e27a..8376b4875f9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5963,8 +5963,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): variables[k] = v return self._replace(variables) - def integrate(self, coord, datetime_unit=None): - """ integrate the array with the trapezoidal rule. + def integrate( + self, coord: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + ) -> "Dataset": + """Integrate along the given coordinate using the trapezoidal rule. .. note:: This feature is limited to simple cartesian geometry, i.e. coord @@ -5972,11 +5974,11 @@ def integrate(self, coord, datetime_unit=None): Parameters ---------- - coord: str, or sequence of str + coord: hashable, or a sequence of hashable Coordinate(s) used for the integration. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \ - "ps", "fs", "as"}, optional - Can be specify the unit if datetime coordinate is used. + datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional + Specify the unit if datetime coordinate is used. 
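In short, coord is now the supported spelling for both DataArray.integrate and Dataset.integrate; the old dim keyword of DataArray.integrate keeps working but emits a FutureWarning and is scheduled for removal in 0.19.0. A minimal usage sketch with illustrative values:

import xarray as xr

da = xr.DataArray([5, 5, 6, 6], dims="x", coords={"x": [0.1, 1.1, 1.2, 1.5]})

da.integrate(coord="x")  # preferred, consistent with Dataset.integrate
da.integrate(dim="x")    # still works, but emits a FutureWarning

ds = da.to_dataset(name="v")
ds.integrate(coord="x")  # Dataset.integrate already used coord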
Returns ------- diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fed9098701b..db47faa8d2b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6603,6 +6603,9 @@ def test_integrate(dask): with pytest.raises(ValueError): da.integrate("x2d") + with pytest.warns(FutureWarning): + da.integrate(dim="x") + @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index bb3127e90b5..76dd830de23 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3681,7 +3681,7 @@ def test_stacking_reordering(self, func, dtype): ( method("diff", dim="x"), method("differentiate", coord="x"), - method("integrate", dim="x"), + method("integrate", coord="x"), method("quantile", q=[0.25, 0.75]), method("reduce", func=np.sum, dim="x"), pytest.param(lambda x: x.dot(x), id="method_dot"), From 39048f95c5048b95505abc3afaec3bf386cbdf10 Mon Sep 17 00:00:00 2001 From: keewis Date: Sat, 30 Jan 2021 00:05:57 +0100 Subject: [PATCH 14/14] speed up the repr for big MultiIndex objects (#4846) * print the repr of a multiindex using only a subset of the coordinate values * don't index if we have less items than available width * don't try to shorten arrays which are way too short * col_width seems to be the maximum number of elements, not characters * add a asv benchmark * Apply suggestions from code review Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- asv_bench/benchmarks/repr.py | 18 ++++++++++++++++++ xarray/core/formatting.py | 11 +++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 asv_bench/benchmarks/repr.py diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py new file mode 100644 index 00000000000..b218c0be870 --- /dev/null +++ b/asv_bench/benchmarks/repr.py @@ -0,0 +1,18 @@ +import pandas as pd + +import xarray as xr + + +class ReprMultiIndex: + def setup(self, key): + index = pd.MultiIndex.from_product( + [range(10000), range(10000)], names=("level_0", "level_1") + ) + series = pd.Series(range(100000000), index=index) + self.da = xr.DataArray(series) + + def time_repr(self): + repr(self.da) + + def time_repr_html(self): + self.da._repr_html_() diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 282620e3569..0c1be1cc175 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -300,11 +300,18 @@ def _summarize_coord_multiindex(coord, col_width, marker): def _summarize_coord_levels(coord, col_width, marker="-"): + if len(coord) > 100 and col_width < len(coord): + n_values = col_width + indices = list(range(0, n_values)) + list(range(-n_values, 0)) + subset = coord[indices] + else: + subset = coord + return "\n".join( summarize_variable( - lname, coord.get_level_variable(lname), col_width, marker=marker + lname, subset.get_level_variable(lname), col_width, marker=marker ) - for lname in coord.level_names + for lname in subset.level_names )
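The speed-up comes from formatting only the first and last col_width values of each MultiIndex level instead of materialising every element. The same indexing trick in isolation, using a plain NumPy array and a made-up width purely for illustration:

import numpy as np

values = np.arange(1_000_000)
col_width = 8  # stand-in for the width that formatting.py computes

if len(values) > 100 and col_width < len(values):
    indices = list(range(0, col_width)) + list(range(-col_width, 0))
    subset = values[indices]  # only 2 * col_width items are ever formatted
else:
    subset = values

print(subset)  # the first eight and the last eight values, 16 in total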