diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5e9aa06f507..d1c79953a9b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,4 +1,3 @@ - [ ] Closes #xxxx (remove if there is no corresponding issue, which should only be the case for minor changes) - [ ] Tests added (for all bug fixes or enhancements) - - [ ] Tests passed (for all non-documentation changes) - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later) diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000000..cd610907007 --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,11 @@ +# File : .pep8speaks.yml + +scanner: + diff_only: True # If True, errors caused by only the patch are shown + +pycodestyle: + max-line-length: 79 + ignore: # Errors and warnings to ignore + - E402, # module level import not at top of file + - E731, # do not assign a lambda expression, use a def + - W503 # line break before binary operator diff --git a/.stickler.yml b/.stickler.yml deleted file mode 100644 index 79d8b7fb717..00000000000 --- a/.stickler.yml +++ /dev/null @@ -1,11 +0,0 @@ -linters: - flake8: - max-line-length: 79 - fixer: false - ignore: I002 - # stickler doesn't support 'exclude' for flake8 properly, so we disable it - # below with files.ignore: - # https://github.com/markstory/lint-review/issues/184 -files: - ignore: - - doc/**/*.py diff --git a/.travis.yml b/.travis.yml index 951b151d829..defb37ec8aa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ # Based on http://conda.pydata.org/docs/travis.html -language: python +language: minimal sudo: false # use container based build notifications: email: false @@ -10,72 +10,48 @@ branches: matrix: fast_finish: true include: - - python: 2.7 - env: CONDA_ENV=py27-min - - python: 2.7 - env: CONDA_ENV=py27-cdat+iris+pynio - - python: 3.5 - env: CONDA_ENV=py35 - - python: 3.6 - env: CONDA_ENV=py36 - - python: 3.6 - env: + - env: CONDA_ENV=py27-min + - env: CONDA_ENV=py27-cdat+iris+pynio + - env: CONDA_ENV=py35 + - env: CONDA_ENV=py36 + - env: CONDA_ENV=py37 + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - env: CONDA_ENV=py36-dask-dev - - python: 3.6 - env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-rasterio-0.36 - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev - - python: 3.5 - env: CONDA_ENV=docs - - python: 3.6 - env: CONDA_ENV=py36-hypothesis + - env: CONDA_ENV=py36-dask-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-rasterio-0.36 + - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=docs + - env: CONDA_ENV=py36-hypothesis + allow_failures: - - python: 3.6 - env: + - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" - - python: 3.6 - env: CONDA_ENV=py36-netcdf4-dev + - env: CONDA_ENV=py36-netcdf4-dev addons: apt_packages: - libhdf5-serial-dev - netcdf-bin - libnetcdf-dev - - python: 3.6 - 
env: CONDA_ENV=py36-pandas-dev - - python: 3.6 - env: CONDA_ENV=py36-bottleneck-dev - - python: 3.6 - env: CONDA_ENV=py36-condaforge-rc - - python: 3.6 - env: CONDA_ENV=py36-pynio-dev - - python: 3.6 - env: CONDA_ENV=py36-zarr-dev + - env: CONDA_ENV=py36-pandas-dev + - env: CONDA_ENV=py36-bottleneck-dev + - env: CONDA_ENV=py36-condaforge-rc + - env: CONDA_ENV=py36-pynio-dev + - env: CONDA_ENV=py36-zarr-dev before_install: - - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then - wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh; - else - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - fi + - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - hash -r @@ -95,9 +71,9 @@ install: - python xarray/util/print_versions.py script: - # TODO: restore this check once the upstream pandas issue is fixed: - # https://github.com/pandas-dev/pandas/issues/21071 - # - python -OO -c "import xarray" + - which python + - python --version + - python -OO -c "import xarray" - if [[ "$CONDA_ENV" == "docs" ]]; then conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html; diff --git a/README.rst b/README.rst index 94beea1dba4..0ac71d33954 100644 --- a/README.rst +++ b/README.rst @@ -15,6 +15,8 @@ xarray: N-D labeled arrays and datasets :target: https://zenodo.org/badge/latestdoi/13221727 .. image:: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat :target: http://pandas.pydata.org/speed/xarray/ +.. image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A + :target: http://numfocus.org **xarray** (formerly **xray**) is an open source project and Python package that aims to bring the labeled data power of pandas_ to the physical sciences, by providing @@ -103,20 +105,36 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray +NumFOCUS +-------- + +.. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png + :scale: 25 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= + History ------- xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org License ------- -Copyright 2014-2017, xarray Developers +Copyright 2014-2018, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index b5953436387..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -64,6 +64,7 @@ "scipy": [""], "bottleneck": ["", null], "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 54ed9ac9fa2..da18d541a16 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, division, print_function +import os + import numpy as np import pandas as pd import xarray as xr -from . import randn, randint, requires_dask +from . import randint, randn, requires_dask try: import dask @@ -14,6 +16,9 @@ pass +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' + + class IOSingleNetCDF(object): """ A few examples that benchmark reading/writing a single netCDF file with @@ -405,3 +410,39 @@ def time_open_dataset_scipy_with_time_chunks(self): with dask.set_options(get=dask.multiprocessing.get): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', engine='netcdf4', compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py new file mode 100644 index 00000000000..54436b422e9 --- /dev/null +++ b/asv_bench/benchmarks/unstacking.py @@ -0,0 +1,26 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +import xarray as xr + +from . 
import requires_dask + + +class Unstacking(object): + def setup(self): + data = np.random.RandomState(0).randn(1, 1000, 500) + self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2']) + + def time_unstack_fast(self): + self.ds.unstack('flat_dim') + + def time_unstack_slow(self): + self.ds[:, ::-1].unstack('flat_dim') + + +class UnstackingDask(Unstacking): + def setup(self, *args, **kwargs): + requires_dask() + super(UnstackingDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'flat_dim': 50}) diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml new file mode 100644 index 00000000000..5f973936f63 --- /dev/null +++ b/ci/requirements-py37.yml @@ -0,0 +1,13 @@ +name: test_env +channels: + - defaults +dependencies: + - python=3.7 + - pip: + - pytest + - flake8 + - mock + - numpy + - pandas + - coveralls + - pytest-cov diff --git a/doc/_static/numfocus_logo.png b/doc/_static/numfocus_logo.png new file mode 100644 index 00000000000..af3c84209e0 Binary files /dev/null and b/doc/_static/numfocus_logo.png differ diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 1826cc86892..0e8143c72ea 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -151,3 +151,5 @@ plot.FacetGrid.set_titles plot.FacetGrid.set_ticks plot.FacetGrid.map + + CFTimeIndex.shift diff --git a/doc/api.rst b/doc/api.rst index 927c0aa072c..662ef567710 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -150,6 +150,7 @@ Computation Dataset.resample Dataset.diff Dataset.quantile + Dataset.differentiate **Aggregation**: :py:attr:`~Dataset.all` @@ -317,6 +318,7 @@ Computation DataArray.diff DataArray.dot DataArray.quantile + DataArray.differentiate **Aggregation**: :py:attr:`~DataArray.all` @@ -555,6 +557,13 @@ Custom Indexes CFTimeIndex +Creating custom indexes +----------------------- +.. autosummary:: + :toctree: generated/ + + cftime_range + Plotting ======== @@ -615,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/computation.rst b/doc/computation.rst index 6793e667e06..759c87a6cc7 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -200,6 +200,31 @@ You can also use ``construct`` to compute a weighted rolling sum: To avoid this, use ``skipna=False`` as the above example. +Computation using Coordinates +============================= + +Xarray objects have some handy methods for the computation with their +coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by +central finite differences using their coordinates, + +.. ipython:: python + + a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]]) + a + a.differentiate('x') + +This method can be used also for multidimensional arrays, + +.. ipython:: python + + a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], + coords={'x': [0.1, 0.11, 0.2, 0.3]}) + a.differentiate('x') + +.. note:: + This method is limited to simple cartesian geometry. Differentiation along + multidimensional coordinate is not supported. + .. 
_compute.broadcasting: Broadcasting by dimension name diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 10d83ca448f..618ccccff3e 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -408,13 +408,6 @@ operations keep around coordinates: list(ds[['x']]) list(ds.drop('temperature')) -If a dimension name is given as an argument to ``drop``, it also drops all -variables that use that dimension: - -.. ipython:: python - - list(ds.drop('time')) - As an alternate to dictionary-like modifications, you can use :py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`. These methods return a new dataset with additional (or replaced) or values: diff --git a/doc/index.rst b/doc/index.rst index e66c448f780..45897f4bccb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -120,12 +120,20 @@ Get in touch .. _mailing list: https://groups.google.com/forum/#!forum/xarray .. _on GitHub: http://github.com/pydata/xarray -License -------- +NumFOCUS +-------- -xarray is available under the open source `Apache License`__. +.. image:: _static/numfocus_logo.png + :scale: 50 % + :target: https://numfocus.org/ + +Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated +to supporting the open source scientific computing community. If you like +Xarray and want to support our mission, please consider making a donation_ +to support our efforts. + +.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU= -__ http://www.apache.org/licenses/LICENSE-2.0.html History ------- @@ -133,6 +141,15 @@ History xarray is an evolution of an internal tool developed at `The Climate Corporation`__. It was originally written by Climate Corp researchers Stephan Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in -May 2014. The project was renamed from "xray" in January 2016. +May 2014. The project was renamed from "xray" in January 2016. Xarray became a +fiscally sponsored project of NumFOCUS_ in August 2018. __ http://climate.com/ +.. _NumFOCUS: https://numfocus.org + +License +------- + +xarray is available under the open source `Apache License`__. + +__ http://www.apache.org/licenses/LICENSE-2.0.html diff --git a/doc/installing.rst b/doc/installing.rst index 85cd5a02568..eb74eb7162b 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -6,7 +6,7 @@ Installation Required dependencies --------------------- -- Python 2.7 [1]_, 3.5, or 3.6 +- Python 2.7 [1]_, 3.5, 3.6, or 3.7 - `numpy `__ (1.12 or later) - `pandas `__ (0.19.2 or later) diff --git a/doc/interpolation.rst b/doc/interpolation.rst index e5230e95dae..10e46331d0a 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -63,6 +63,9 @@ by specifing the time periods required. da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3)) +Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also +allowed. See :ref:`CFTimeIndex` for examples. + .. note:: Currently, our interpolation only works for regular grids. diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 9b75d0e1b3e..524ea3b9d8d 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -35,7 +35,6 @@ Geosciences - `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. - `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. 
- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. -- `xskillscore `_: Metrics for verifying forecasts. Machine Learning ~~~~~~~~~~~~~~~~ @@ -48,10 +47,11 @@ Extend xarray capabilities ~~~~~~~~~~~~~~~~~~~~~~~~~~ - `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions - `eofs `_: EOF analysis in Python. -- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. intergrations/interpolations). +- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. - `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xskillscore `_: Metrics for verifying forecasts. - `xyzpy `_: Easily generate high dimensional data, including parallelization. Visualization diff --git a/doc/roadmap.rst b/doc/roadmap.rst index 2708cb7cf8f..34d203c3f48 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -1,3 +1,5 @@ +.. _roadmap: + Development roadmap =================== diff --git a/doc/time-series.rst b/doc/time-series.rst index a7ce9226d4d..c1a686b409f 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -70,9 +70,9 @@ You can manual decode arrays in this form by passing a dataset to One unfortunate limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262. When a netCDF file contains dates outside of these bounds, dates will be -returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex`` -can be used for indexing. The ``CFTimeIndex`` enables only a subset of -the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled +returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` +can be used for indexing. The :py:class:`~xarray.CFTimeIndex` enables only a subset of +the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only enabled when using the standalone version of ``cftime`` (not the version packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information. @@ -219,12 +219,12 @@ Non-standard calendars and dates outside the Timestamp-valid range ------------------------------------------------------------------ Through the standalone ``cftime`` library and a custom subclass of -``pandas.Index``, xarray supports a subset of the indexing functionality enabled -through the standard ``pandas.DatetimeIndex`` for dates from non-standard -calendars or dates using a standard calendar, but outside the -`Timestamp-valid range`_ (approximately between years 1678 and 2262). This -behavior has not yet been turned on by default; to take advantage of this -functionality, you must have the ``enable_cftimeindex`` option set to +:py:class:`pandas.Index`, xarray supports a subset of the indexing +functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for +dates from non-standard calendars or dates using a standard calendar, but +outside the `Timestamp-valid range`_ (approximately between years 1678 and +2262). This behavior has not yet been turned on by default; to take advantage +of this functionality, you must have the ``enable_cftimeindex`` option set to ``True`` within your context (see :py:func:`~xarray.set_options` for more information). It is expected that this will become the default behavior in xarray version 0.11. 
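A minimal sketch of the opt-in pattern described in the hunk above, assuming xarray 0.10.9+ with the standalone ``cftime`` package installed (the names ``dates`` and ``da`` are illustrative only, not part of this patch):

    import numpy as np
    import xarray as xr
    from cftime import DatetimeNoLeap

    # Opt in to CFTimeIndex-based indexing for dates outside the
    # Timestamp-valid range (here: year 1 on a no-leap calendar).
    with xr.set_options(enable_cftimeindex=True):
        dates = [DatetimeNoLeap(year, month, 1)
                 for year in range(1, 3) for month in range(1, 13)]
        da = xr.DataArray(np.arange(24), coords=[dates],
                          dims=['time'], name='foo')

    # The time coordinate is now backed by a CFTimeIndex, so the subset of
    # pandas-style indexing documented above (e.g. da.sel(time='0001-05'))
    # is available.
    print(type(da.indexes['time']))
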
@@ -232,7 +232,7 @@ xarray version 0.11. For instance, you can create a DataArray indexed by a time coordinate with a no-leap calendar within a context manager setting the ``enable_cftimeindex`` option, and the time index will be cast to a -``CFTimeIndex``: +:py:class:`~xarray.CFTimeIndex`: .. ipython:: python @@ -247,19 +247,28 @@ coordinate with a no-leap calendar within a context manager setting the .. note:: - With the ``enable_cftimeindex`` option activated, a ``CFTimeIndex`` + With the ``enable_cftimeindex`` option activated, a :py:class:`~xarray.CFTimeIndex` will be used for time indexing if any of the following are true: - The dates are from a non-standard calendar - Any dates are outside the Timestamp-valid range - Otherwise a ``pandas.DatetimeIndex`` will be used. In addition, if any + Otherwise a :py:class:`pandas.DatetimeIndex` will be used. In addition, if any variable (not just an index variable) is encoded using a non-standard - calendar, its times will be decoded into ``cftime.datetime`` objects, + calendar, its times will be decoded into :py:class:`cftime.datetime` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. - -For data indexed by a ``CFTimeIndex`` xarray currently supports: + +xarray also includes a :py:func:`~xarray.cftime_range` function, which enables +creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For instance, we can +create the same dates and DataArray we created above using: + +.. ipython:: python + + dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + +For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: - `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial datetime strings: @@ -285,7 +294,25 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports: .. ipython:: python da.groupby('time.month').sum() - + +- Interpolation using :py:class:`cftime.datetime` objects: + +.. ipython:: python + + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + +- Interpolation using datetime strings: + +.. ipython:: python + + da.interp(time=['0001-01-15', '0001-02-15']) + +- Differentiation: + +.. ipython:: python + + da.differentiate('time') + - And serialization: .. ipython:: python @@ -296,7 +323,7 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports: .. note:: Currently resampling along the time dimension for data indexed by a - ``CFTimeIndex`` is not supported. + :py:class:`~xarray.CFTimeIndex` is not supported. .. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations .. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 718e68af04b..93083c23353 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,24 +25,117 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ -.. _whats-new.0.10.9: +.. _whats-new.0.11.0: -v0.10.9 (unreleased) +v0.11.0 (unreleased) -------------------- +Breaking changes +~~~~~~~~~~~~~~~~ + +- Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. 
A + global least-recently-used cache is used to store open files; the default + limit of 128 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. + + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. By `Stephan Hoyer `_ + Documentation ~~~~~~~~~~~~~ +- Reduction of :py:meth:`DataArray.groupby` and :py:meth:`DataArray.resample` + without dimension argument will change in the next release. + Now we warn a FutureWarning. + By `Keisuke Fujii `_. Enhancements ~~~~~~~~~~~~ -- :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. +- Added support for Python 3.7. (:issue:`2271`). + By `Joe Hamman `_. +- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. +- Added support for using ``cftime.datetime`` coordinates with + :py:meth:`~xarray.DataArray.differentiate`, + :py:meth:`~xarray.Dataset.differentiate`, + :py:meth:`~xarray.DataArray.interp`, and + :py:meth:`~xarray.Dataset.interp`. + By `Spencer Clark `_ + +Bug fixes +~~~~~~~~~ + +- Addition and subtraction operators used with a CFTimeIndex now preserve the + index's type. (:issue:`2244`). + By `Spencer Clark `_. +- ``xarray.DataArray.roll`` correctly handles multidimensional arrays. + (:issue:`2445`) + By `Keisuke Fujii `_. +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. +- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. + (:issue:`2240`) + By `Keisuke Fujii `_. +- Restore matplotlib's default of plotting dashed negative contours when + a single color is passed to ``DataArray.contour()`` e.g. ``colors='k'``. + By `Deepak Cherian `_. + + +- Fix a bug that caused some indexing operations on arrays opened with + ``open_rasterio`` to error (:issue:`2454`). + By `Stephan Hoyer `_. + +.. _whats-new.0.10.9: + +v0.10.9 (21 September 2018) +--------------------------- + +This minor release contains a number of backwards compatible enhancements. + +Announcements of note: + +- Xarray is now a NumFOCUS fiscally sponsored project! Read + `the anouncement `_ + for more details. +- We have a new :doc:`roadmap` that outlines our future development plans. + +Enhancements +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.differentiate` and + :py:meth:`~xarray.Dataset.differentiate` are newly added. + (:issue:`1332`) + By `Keisuke Fujii `_. +- Default colormap for sequential and divergent data can now be set via + :py:func:`~xarray.set_options()` + (:issue:`2394`) + By `Julius Busecke `_. + +- min_count option is newly supported in :py:meth:`~xarray.DataArray.sum`, + :py:meth:`~xarray.DataArray.prod` and :py:meth:`~xarray.Dataset.sum`, and + :py:meth:`~xarray.Dataset.prod`. + (:issue:`2230`) + By `Keisuke Fujii `_. + +- :py:meth:`plot()` now accepts the kwargs + ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. 
Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits. By `Deepak Cherian `_. (:issue:`2224`) - DataArray coordinates and Dataset coordinates and data variables are now displayed as `a b ... y z` rather than `a b c d ...`. (:issue:`1186`) By `Seth P `_. +- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in + generating dates from standard or non-standard calendars. By `Spencer Clark + `_. - When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')`` (:issue:`2284`) @@ -52,10 +145,30 @@ Enhancements (:issue:`2331`) By `Maximilian Roos `_. +- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices. + (:issue:`1560`) + By `Maximilian Maahn `_. + +- You can now control whether or not to offset the coordinates when using + the ``roll`` method and the current behavior, coordinates rolled by default, + raises a deprecation warning unless explicitly setting the keyword argument. + (:issue:`1875`) + By `Andrew Huang `_. + +- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset. + By `Julia Signell `_. + +- Added the ability to pass a data kwarg to ``copy`` to create a new object with the + same metadata as the original object but using new values. + By `Julia Signell `_. Bug fixes ~~~~~~~~~ +- ``xarray.plot.imshow()`` correctly uses the ``origin`` argument. + (:issue:`2379`) + By `Deepak Cherian `_. + - Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by falling back to creating ``AuxCoord``. Fixed dependency on ``var_name`` attribute being set. @@ -69,6 +182,9 @@ Bug fixes - Tests can be run in parallel with pytest-xdist By `Tony Tung `_. +- Follow up the renamings in dask; from dask.ghost to dask.overlap + By `Keisuke Fujii `_. + - Now raises a ValueError when there is a conflict between dimension names and level names of MultiIndex. (:issue:`2299`) By `Keisuke Fujii `_. @@ -77,10 +193,16 @@ Bug fixes By `Keisuke Fujii `_. - Now :py:func:`xr.apply_ufunc` raises a ValueError when the size of -``input_core_dims`` is inconsistent with the number of arguments. + ``input_core_dims`` is inconsistent with the number of arguments. (:issue:`2341`) By `Keisuke Fujii `_. +- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``. + When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()`` it will now return a Dataset with variables which pass + all the filters. + (:issue:`2315`) + By `Andrew Barna `_. + .. _whats-new.0.10.8: v0.10.8 (18 July 2018) @@ -112,7 +234,6 @@ Enhancements :py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`). By `Stephane Raynaud `_. 
- Bug fixes ~~~~~~~~~ diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 8d84c0f6815..13f63f259cf 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -6,14 +6,15 @@ """ from __future__ import absolute_import, division, print_function -from hypothesis import given, settings -import hypothesis.strategies as st import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given, settings import xarray as xr # Run for a while - arrays are a bigger search space than usual -settings.deadline = None +settings.register_profile("ci", deadline=None) +settings.load_profile("ci") an_array = npst.arrays( diff --git a/setup.py b/setup.py index 88c27c95118..a7519bac6da 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import sys -from setuptools import find_packages, setup - import versioneer - +from setuptools import find_packages, setup DISTNAME = 'xarray' LICENSE = 'Apache' @@ -69,5 +67,6 @@ install_requires=INSTALL_REQUIRES, tests_require=TESTS_REQUIRE, url=URL, + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', packages=find_packages(), package_data={'xarray': ['tests/data/*']}) diff --git a/versioneer.py b/versioneer.py index 64fea1c8927..dffd66b69a6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -277,10 +277,7 @@ """ from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser + import errno import json import os @@ -288,6 +285,11 @@ import subprocess import sys +try: + import configparser +except ImportError: + import ConfigParser as configparser + class VersioneerConfig: """Container for Versioneer configuration parameters.""" diff --git a/xarray/__init__.py b/xarray/__init__.py index 7cc7811b783..59a961c6b56 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -10,7 +10,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine -from .core.computation import apply_ufunc, where, dot +from .core.computation import apply_ufunc, dot, where from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate @@ -26,6 +26,7 @@ from .conventions import decode_cf, SerializationWarning +from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex from .util.print_versions import show_versions @@ -33,3 +34,5 @@ from . import tutorial from . import ufuncs from . import testing + +from .core.common import ALL_DIMS diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 47a2011a3af..a2f0d79a6d1 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,6 +4,7 @@ formats. They should not be used directly, but rather through Dataset objects. 
""" from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore @@ -15,6 +16,9 @@ __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b2c0df7b01b..65112527045 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,6 +4,7 @@ from glob import glob from io import BytesIO from numbers import Number +import warnings import numpy as np @@ -12,8 +13,9 @@ from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri -from .common import ( - HDF5_LOCK, ArrayWriter, CombinedLock, _get_scheduler, _get_scheduler_lock) +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -52,27 +54,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): - if filename.endswith('.gz'): - lock = False - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if is_remote_uri(filename): - lock = False - else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = HDF5_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = HDF5_LOCK - else: - lock = False - return lock - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -130,29 +111,14 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - - locks = [] - if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: - locks.append(HDF5_LOCK) - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - - # When we have more than one lock, use the CombinedLock wrapper class - lock = CombinedLock(locks) if len(locks) > 1 else locks[0] - - return lock - - def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first - store.sync() store.close() def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -204,12 +170,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. 
Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -235,6 +200,14 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) if mask_and_scale is None: mask_and_scale = not engine == 'pseudonetcdf' @@ -272,18 +245,11 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -314,36 +280,28 @@ def maybe_decode_store(store, lock=False): engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj, - **backend_kwargs) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, autoclose=autoclose, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -355,7 +313,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, 
lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -390,10 +348,6 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -409,12 +363,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -490,7 +443,7 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', - autoclose=False, parallel=False, **kwargs): + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. Requires dask to be installed. See documentation for details on dask [1]. @@ -537,15 +490,11 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. 
data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -604,9 +553,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_lock(paths[0], engine) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -656,19 +602,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None, - compute=True): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. """ if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -676,6 +624,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -695,45 +647,78 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None - # handle scheduler specific logic scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) - if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - engine != 'netcdf4'): + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': raise NotImplementedError("Writing netCDF files with the %s backend " "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) - lock = _get_lock(engine, scheduler, format, path_or_file) - autoclose = (have_chunks and - scheduler in ['distributed', 'multiprocessing']) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, - autoclose=autoclose, lock=lock) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) if isinstance(unlimited_dims, basestring): unlimited_dims = [unlimited_dims] + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims, compute=compute) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if 
multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) + + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) - if not sync: - return store def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, engine=None, compute=True): @@ -806,7 +791,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, for obj in datasets: if not isinstance(obj, Dataset): raise TypeError('save_mfdataset only supports writing Dataset ' - 'objects, recieved type %s' % type(obj)) + 'objects, received type %s' % type(obj)) if groups is None: groups = [None] * len(datasets) @@ -816,22 +801,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() if compute else None - stores = [to_netcdf(ds, path, mode, format, group, engine, writer, - compute=compute) - for ds, path, group in zip(datasets, paths, groups)] - - if not compute: - import dask - return dask.delayed(stores) + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) try: - delayed = writer.sync(compute=compute) - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, @@ -852,13 +837,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 99f7698ee92..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,10 @@ from __future__ import absolute_import, division, 
print_function -import contextlib import logging -import multiprocessing -import threading import time import traceback import warnings from collections import Mapping, OrderedDict -from functools import partial import numpy as np @@ -17,13 +13,6 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin -# Import default lock -try: - from dask.utils import SerializableLock - HDF5_LOCK = SerializableLock() -except ImportError: - HDF5_LOCK = threading.Lock() - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -31,62 +20,6 @@ NONE_VAR_NAME = '__values__' -def _get_scheduler(get=None, collection=None): - """ Determine the dask scheduler that is being used. - - None is returned if not dask scheduler is active. - - See also - -------- - dask.base.get_scheduler - """ - try: - # dask 0.18.1 and later - from dask.base import get_scheduler - actual_get = get_scheduler(get, collection) - except ImportError: - try: - from dask.utils import effective_get - actual_get = effective_get(get, collection) - except ImportError: - return None - - try: - from dask.distributed import Client - if isinstance(actual_get.__self__, Client): - return 'distributed' - except (ImportError, AttributeError): - try: - import dask.multiprocessing - if actual_get == dask.multiprocessing.get: - return 'multiprocessing' - else: - return 'threaded' - except ImportError: - return 'threaded' - - -def _get_scheduler_lock(scheduler, path_or_file=None): - """ Get the appropriate lock for a certain situation based onthe dask - scheduler used. - - See Also - -------- - dask.utils.get_scheduler_lock - """ - - if scheduler == 'distributed': - from dask.distributed import Lock - return Lock(path_or_file) - elif scheduler == 'multiprocessing': - return multiprocessing.Lock() - elif scheduler == 'threaded': - from dask.utils import SerializableLock - return SerializableLock() - else: - return threading.Lock() - - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -133,39 +66,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, time.sleep(1e-3 * next_delay) -class CombinedLock(object): - """A combination of multiple locks. - - Like a locked door, a CombinedLock is locked if any of its constituent - locks are locked. 
- """ - - def __init__(self, locks): - self.locks = tuple(set(locks)) # remove duplicates - - def acquire(self, *args): - return all(lock.acquire(*args) for lock in self.locks) - - def release(self, *args): - for lock in self.locks: - lock.release(*args) - - def __enter__(self): - for lock in self.locks: - lock.__enter__() - - def __exit__(self, *args): - for lock in self.locks: - lock.__exit__(*args) - - @property - def locked(self): - return any(lock.locked for lock in self.locks) - - def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) - - class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): @@ -174,9 +74,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = None - _ds = None - _isopen = False def __iter__(self): return iter(self.variables) @@ -259,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=HDF5_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -274,6 +171,9 @@ def add(self, source, target): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernable difference in perforance, e.g., + # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store(self.sources, self.targets, lock=self.lock, compute=compute, flush=True) @@ -283,11 +183,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None, lock=HDF5_LOCK): - if writer is None: - writer = ArrayWriter(lock=lock) - self.writer = writer - self.delayed_store = None def encode(self, variables, attributes): """ @@ -329,12 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self, compute=True): - if self._isopen and self._autoclose: - # datastore will be reopened during write - self.close() - self.delayed_store = self.writer.sync(compute=compute) - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates @@ -345,7 +234,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -361,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. 
""" + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -386,7 +278,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -399,6 +291,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. @@ -410,7 +303,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ @@ -457,87 +350,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. - """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - del state['_isopen'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._ds = None - self._isopen = False - - @property - def ds(self): - if self._ds is not None and self._isopen: - return self._ds - ds = self._opener(mode=self._mode) - self._isopen = True - return ds - - @contextlib.contextmanager - def ensure_open(self, autoclose=None): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. 
- """ - - if autoclose is None: - autoclose = self._autoclose - - if not self._isopen: - try: - self._ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') - - -class PickleByReconstructionWrapper(object): - - def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) - self.mode = mode - self._ds = None - - @property - def value(self): - self._ds = self.opener() - return self._ds - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - if self.mode == 'w': - # file has already been created, don't override when restoring - state['mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def close(self): - self._ds.close() diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..a93285370b2 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,206 @@ +import threading + +from ..core import utils +from ..core.options import OPTIONS +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' + + +_DEFAULT_MODE = utils.ReprObject('') + + +class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self, needs_lock=True): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): + """Wrapper for automatically opening and closing file objects. + + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. + + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + + Example usage: + + manager = FileManager(open, 'example.txt', mode='w') + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + + """ + + def __init__(self, opener, *args, **keywords): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. + *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). All arguments must be + hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. 
``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. + """ + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + + self._opener = opener + self._args = args + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None or lock is False + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache + self._key = self._make_key() + + def _make_key(self): + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) + + def acquire(self): + """Acquiring a file object from the manager. + + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a reentrant lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. + + Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. + """ + with self._lock: + try: + file = self._cache[self._key] + except KeyError: + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': + # ensure file doesn't get overriden when opened again + self._mode = 'a' + self._key = self._make_key() + self._cache[self._key] = file + return file + + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): + """Explicitly close any associated file object (if necessary).""" + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() + + def __getstate__(self): + """State for pickling.""" + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) + + def __setstate__(self, state): + """Restore from a pickle.""" + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) + + +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. + + Based on what Python uses internally in functools.lru_cache. 
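# Usage sketch for the class being defined here: the tuple's hash is computed
# once at construction and reused for every later cache lookup, instead of
# being recomputed from the full key each time. The key contents below are
# hypothetical.
key = _HashedSequence(('_open_netcdf4_group', 'example.nc', 'r'))
cache = {}
cache[key] = 'an open file handle'
assert cache[key] == 'an open file handle'   # lookup reuses key.hashvalue
assert hash(key) == key.hashvalue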
+ + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ + + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) + + def __hash__(self): + return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. + """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self, needs_lock=True): + del needs_lock # ignored + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 959cd221734..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,11 +8,12 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import ( - HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) +from .common import WritableCFDataStore +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -25,8 +26,9 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] def maybe_decode_bytes(txt): @@ -61,104 +63,102 @@ def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_require_group( + ds = _nc4_require_group( ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False, lock=HDF5_LOCK): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self._ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) + self.format = format - self._opener = opener self._filename = filename self._mode = mode - super(H5NetCDFStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): import h5py - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = 
indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, - } - # Convert h5py-style compression options to NetCDF4-Python - # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts - elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - - vlen_dtype = h5py.check_dtype(vlen=var.dtype) - if vlen_dtype is unicode_type: - encoding['dtype'] = str - elif vlen_dtype is not None: # pragma: no cover - # xarray doesn't support writing arbitrary vlen dtypes yet. - pass - else: - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. 
+ pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.dimensions[name] = None - self.ds.resize_dimension(name, length) - else: - self.ds.dimensions[name] = length + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.attrs[key] = value + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -226,18 +226,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the h5netcdf backend yet') - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync(compute=compute) - self.ds.sync() - - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + def sync(self): + self.ds.sync() + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) + + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..f633280ef1d --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? 
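# Behaviour sketch for _get_threaded_lock above (names are this module's own,
# file paths are hypothetical): requests for the same key share a single
# threading.Lock, and the WeakValueDictionary lets an entry disappear once no
# backend holds a reference to its lock any more.
lock_a = _get_threaded_lock('example.nc')
lock_b = _get_threaded_lock('example.nc')
assert lock_a is lock_b                          # same file -> same lock
assert _get_threaded_lock('other.nc') is not lock_a
del lock_a, lock_b                               # entry may now be collected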
+ # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. + + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_write_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. 
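# Sketch of how the write-side backends in this diff put these pieces
# together: a library-wide lock combined with a per-file lock chosen for the
# active scheduler. The filename is hypothetical; HDF5_LOCK, combine_locks and
# get_write_lock are defined in this module.
from xarray.backends.locks import HDF5_LOCK, combine_locks, get_write_lock

write_lock = combine_locks([HDF5_LOCK, get_write_lock('example.nc')])
with write_lock:          # a CombinedLock acquires every constituent lock
    pass                  # ... perform the HDF5/netCDF write here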
+ """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..321a1ca4da4 --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,91 @@ +import collections +import threading + +from ..core.pycompat import move_to_end + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. + """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. 
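# Usage sketch for the class defined here, mirroring how file_manager.py
# builds its global FILE_CACHE: when the cache overflows, the oldest value is
# handed to on_evict, which closes it. The cached values below are plain
# in-memory buffers standing in for open files.
import io
cache = LRUCache(maxsize=2, on_evict=lambda key, value: value.close())
cache['a'] = io.StringIO('first')
cache['b'] = io.StringIO('second')
cache['c'] = io.StringIO('third')    # evicts 'a' and closes it
assert 'a' not in cache and len(cache) == 2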
+ """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') + if maxsize < 0: + raise ValueError('maxsize must be non-negative') + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + move_to_end(self._cache, key) + return value + + def _enforce_size_limit(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if key in self._cache: + # insert the new value at the end + del self._cache[key] + self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._enforce_size_limit(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') + with self._lock: + self._enforce_size_limit(size) + self._maxsize = size diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index dcf092557b8..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. """ - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 5c6d82fd126..08ba085b77e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -10,12 +10,13 @@ from .. 
import Variable, coding from ..coding.variables import pop_to from ..core import indexing -from ..core.pycompat import ( - PY3, OrderedDict, basestring, iteritems, suppress) +from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, - find_root, robust_getitem) + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_write_lock) +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -26,6 +27,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -43,12 +47,13 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): + with self.datastore.lock: data = self.get_array() data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] @@ -64,20 +69,22 @@ def _getitem(self, key): else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - array = getitem(self.get_array(), key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) + original_array = self.get_array() + + try: + with self.datastore.lock: + array = getitem(original_array, key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) return array @@ -224,7 +231,17 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +class GroupWrapper(object): + """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) @@ -234,7 +251,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -280,40 +297,33 @@ def _set_nc_attribute(obj, key, value): obj.setncattr(key, value) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False, lock=HDF5_LOCK): - - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): + import netCDF4 - _disable_auto_decode_group(netcdf4_dataset) + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - self._ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener - super(NetCDF4DataStore, self).__init__(writer, lock=lock) + self.lock = ensure_lock(lock) + self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, - autoclose=False, lock=HDF5_LOCK): - import netCDF4 as nc4 + clobber=True, diskless=False, persist=False, + lock=None, lock_maker=None, autoclose=False): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( 'A segmentation fault may occur when the ' 'file path has exactly 88 characters as it does ' @@ -324,86 +334,91 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose, lock=lock) - def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # 
netCDF4 specific encoding; save _FillValue for later - encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_write_lock(filename)]) + + manager = CachingFileManager( + _open_netcdf4_group, filename, lock, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, lock=lock, autoclose=autoclose) + + @property + def ds(self): + return self._manager.acquire().value + + def open_store_variable(self, name, var): + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def 
set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - _set_nc_attribute(self.ds, key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + _set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -461,15 +476,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync(compute=compute) - self.ds.sync() + def sync(self): + self.ds.sync() - def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index d946c6fa927..e4691d1f7e1 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,17 +1,18 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools +from __future__ import absolute_import, division, print_function import numpy as np from .. 
import Variable -from ..core.pycompat import OrderedDict -from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing +from ..core.pycompat import OrderedDict +from ..core.utils import Frozen +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock + -from .common import AbstractDataStore, DataStorePickleMixin, BackendArray +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) class PncArrayWrapper(BackendArray): @@ -24,7 +25,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.dtype) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -33,57 +33,55 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + array = self.get_array() + with self.datastore.lock: + return array[key] -class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): +class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, format=None, writer=None, - autoclose=False, **format_kwds): + def open(cls, filename, lock=None, **format_kwds): from PseudoNetCDF import pncopen - opener = functools.partial(pncopen, filename, **format_kwds) - ds = opener() - mode = format_kwds.get('mode', 'r') - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, - autoclose=False): + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + if lock is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return cls(manager, lock) - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager, lock=None): + self._manager = manager + self.lock = ensure_lock(lock) - self._ds = pnc_dataset - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - super(PseudoNetCDFDataStore, self).__init__() + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return ((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(dict([(k, getattr(self.ds, k)) - for k in self.ds.ncattrs()])) + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -93,6 +91,4 @@ def get_encoding(self): return encoding def close(self): - if 
self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 98b76928597..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,13 +1,20 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + + +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. +NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -20,7 +27,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -28,46 +34,45 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() + array = self.get_array() + with self.datastore.lock: if key == () and self.ndim == 0: return array.get_value() - return array[key] -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r', lock=None): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. 
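# The refactoring pattern repeated across these backends, in miniature: keep a
# CachingFileManager, re-acquire the (possibly re-opened) handle through a
# `ds` property, and delegate close() to the manager. The opener here is the
# builtin open(); a real store passes Nio.open_file, pncopen, etc.
from xarray.backends.file_manager import CachingFileManager

class MiniatureStore(object):
    def __init__(self, filename):
        self._manager = CachingFileManager(open, filename, mode='r')

    @property
    def ds(self):
        return self._manager.acquire()   # cheap while the file is still cached

    def close(self):
        self._manager.close()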
self.ds.set_option('MaskedArrayMode', 'MaskedNever') + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -76,6 +81,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 5221cf0e913..5746b4e748d 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -1,21 +1,20 @@ import os +import warnings from collections import OrderedDict from distutils.version import LooseVersion -import warnings import numpy as np from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray, PickleByReconstructionWrapper +from .common import BackendArray +from .file_manager import CachingFileManager +from .locks import SerializableLock -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock -RASTERIO_LOCK = Lock() +# TODO: should this be GDAL_LOCK instead? +RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. 
Try to load your data with ds.load()' @@ -25,18 +24,22 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, riods): - self.riods = riods - self._shape = (riods.value.count, riods.value.height, - riods.value.width) - self._ndims = len(self.shape) + def __init__(self, manager): + self.manager = manager - @property - def dtype(self): - dtypes = self.riods.value.dtypes + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() + + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): @@ -95,7 +98,7 @@ def _get_indexer(self, key): if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray): # do outer-style indexing - np_inds[1:] = np.ix_(*np_inds[1:]) + np_inds[-2:] = np.ix_(*np_inds[-2:]) return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds) @@ -108,7 +111,8 @@ def _getitem(self, key): stop - start for (start, stop) in window) out = np.zeros(shape, dtype=self.dtype) else: - out = self.riods.value.read(band_key, window=window) + riods = self.manager.acquire() + out = riods.read(band_key, window=window) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) @@ -203,7 +207,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') + manager = CachingFileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None @@ -211,20 +216,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.value.count < 1: + if riods.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.value.indexes) + coords['band'] = np.asarray(riods.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.value.affine + transform = riods.affine else: - transform = riods.value.transform + transform = riods.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, ny = riods.value.width, riods.value.height + nx, ny = riods.width, riods.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -234,57 +239,60 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # 2d coordinates parse = False if (parse_coordinates is None) else parse_coordinates if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. 
Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) # Attributes attrs = dict() # Affine transformation matrix (always available) # This describes coefficients mapping pixel coordinates to CRS # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods.value, 'crs') and riods.value.crs: + if hasattr(riods, 'crs') and riods.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.value.crs.to_string() - if hasattr(riods.value, 'res'): + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.value.res - if hasattr(riods.value, 'is_tiled'): + attrs['res'] = riods.res + if hasattr(riods, 'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.value.is_tiled) - if hasattr(riods.value, 'nodatavals'): + attrs['is_tiled'] = np.uint8(riods.is_tiled) + if hasattr(riods, 'nodatavals'): # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.value.nodatavals]) + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.value.driver + driver = riods.driver if driver in parsers: - meta = parsers[driver](riods.value.tags(ns=driver)) + meta = parsers[driver](riods.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.value.count): + len(v) == riods.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), @@ -306,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index cd84431f6b7..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import functools import warnings from distutils.version import LooseVersion from io import BytesIO @@ -11,7 +10,9 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict -from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore +from .common import BackendArray, WritableCFDataStore +from .locks import get_write_lock +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ 
-40,31 +41,26 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even + # after closing associated files. + copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.datastore.ds.variables[self.variable_name] - try: - data[key] = value - except TypeError: - if key is Ellipsis: - # workaround for GH: scipy/scipy#6880 - data[:] = value - else: - raise + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -106,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -116,7 +112,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False, lock=None): + mmap=None, lock=None): import scipy import scipy.io @@ -140,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) + + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager - super(ScipyDataStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return 
FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -176,22 +176,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -219,27 +217,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the scipy backend yet') - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + def sync(self): + self.ds.sync() def close(self): - self.ds.close() - self._isopen = False - - def __exit__(self, type, value, tb): - self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._ds = None - self._isopen = False + self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 47b90c8a617..5f19c826289 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -217,8 +217,7 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr min_zarr = '2.2' @@ -230,24 +229,14 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None, "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer - - # do we need to define attributes for all of the opener keyword args? 
- super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, @@ -334,8 +323,8 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + def sync(self): + pass def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py new file mode 100644 index 00000000000..83e8c7a7e4b --- /dev/null +++ b/xarray/coding/cftime_offsets.py @@ -0,0 +1,735 @@ +"""Time offset classes for use with cftime.datetime objects""" +# The offset classes and mechanisms for generating time ranges defined in +# this module were copied/adapted from those defined in pandas. See in +# particular the objects and methods defined in pandas.tseries.offsets +# and pandas.core.indexes.datetimes. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
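# A usage sketch of the offset classes defined below (requires the cftime
# package; MonthEnd and YearBegin appear later in this module):
def _offset_usage_examples():
    from cftime import DatetimeNoLeap
    date = DatetimeNoLeap(2000, 2, 14)
    assert MonthEnd() + date == DatetimeNoLeap(2000, 2, 28)    # noleap February
    assert MonthEnd(n=2) + date == DatetimeNoLeap(2000, 3, 31)
    assert YearBegin().rollback(date) == DatetimeNoLeap(2000, 1, 1)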
+ +import re +from datetime import timedelta +from functools import partial + +import numpy as np + +from ..core.pycompat import basestring +from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso +from .times import format_cftime_datetime + + +def get_date_type(calendar): + """Return the cftime date type for a given calendar name.""" + try: + import cftime + except ImportError: + raise ImportError( + 'cftime is required for dates with non-standard calendars') + else: + calendars = { + 'noleap': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + '365_day': cftime.DatetimeNoLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian, + 'julian': cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + 'standard': cftime.DatetimeProlepticGregorian + } + return calendars[calendar] + + +class BaseCFTimeOffset(object): + _freq = None + + def __init__(self, n=1): + if not isinstance(n, int): + raise TypeError( + "The provided multiple 'n' must be an integer. " + "Instead a value of type {!r} was provided.".format(type(n))) + self.n = n + + def rule_code(self): + return self._freq + + def __eq__(self, other): + return self.n == other.n and self.rule_code() == other.rule_code() + + def __ne__(self, other): + return not self == other + + def __add__(self, other): + return self.__apply__(other) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract a cftime.datetime ' + 'from a time offset.') + elif type(other) == type(self): + return type(self)(self.n - other.n) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n) + + def __neg__(self): + return self * -1 + + def __rmul__(self, other): + return self.__mul__(other) + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): + raise TypeError('Cannot subtract cftime offsets of differing ' + 'types') + return -self + other + + def __apply__(self): + return NotImplemented + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + test_date = (self + date) - self + return date == test_date + + def rollforward(self, date): + if self.onOffset(date): + return date + else: + return date + type(self)() + + def rollback(self, date): + if self.onOffset(date): + return date + else: + return date - type(self)() + + def __str__(self): + return '<{}: n={}>'.format(type(self).__name__, self.n) + + def __repr__(self): + return str(self) + + +def _days_in_month(date): + """The number of days in the month of the given date""" + if date.month == 12: + reference = type(date)(date.year + 1, 1, 1) + else: + reference = type(date)(date.year, date.month + 1, 1) + return (reference - timedelta(days=1)).day + + +def _adjust_n_months(other_day, n, reference_day): + """Adjust the number of times a monthly offset is applied based + on the day of a given date, and the reference day provided. 
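# Concrete values for the adjustment described above, using the function as
# completed just below: adding one MonthEnd (reference_day = last day of the
# month) to the 14th first rolls to that month's own end, while subtracting
# one MonthBegin (reference_day = 1) from the 14th lands on the 1st of the
# same month.
assert _adjust_n_months(other_day=14, n=1, reference_day=28) == 0
assert _adjust_n_months(other_day=28, n=1, reference_day=28) == 1
assert _adjust_n_months(other_day=14, n=-1, reference_day=1) == 0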
+ """ + if n > 0 and other_day < reference_day: + n = n - 1 + elif n <= 0 and other_day > reference_day: + n = n + 1 + return n + + +def _adjust_n_years(other, n, month, reference_day): + """Adjust the number of times an annual offset is applied based on + another date, and the reference day provided""" + if n > 0: + if other.month < month or (other.month == month and + other.day < reference_day): + n -= 1 + else: + if other.month > month or (other.month == month and + other.day > reference_day): + n += 1 + return n + + +def _shift_months(date, months, day_option='start'): + """Shift the date to a month start or end a given number of months away. + """ + delta_year = (date.month + months) // 12 + month = (date.month + months) % 12 + + if month == 0: + month = 12 + delta_year = delta_year - 1 + year = date.year + delta_year + + if day_option == 'start': + day = 1 + elif day_option == 'end': + reference = type(date)(year, month, 1) + day = _days_in_month(reference) + else: + raise ValueError(day_option) + return date.replace(year=year, month=month, day=day) + + +class MonthBegin(BaseCFTimeOffset): + _freq = 'MS' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, 1) + return _shift_months(other, n, 'start') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 + + +class MonthEnd(BaseCFTimeOffset): + _freq = 'M' + + def __apply__(self, other): + n = _adjust_n_months(other.day, self.n, _days_in_month(other)) + return _shift_months(other, n, 'end') + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) + + +_MONTH_ABBREVIATIONS = { + 1: 'JAN', + 2: 'FEB', + 3: 'MAR', + 4: 'APR', + 5: 'MAY', + 6: 'JUN', + 7: 'JUL', + 8: 'AUG', + 9: 'SEP', + 10: 'OCT', + 11: 'NOV', + 12: 'DEC' +} + + +class YearOffset(BaseCFTimeOffset): + _freq = None + _day_option = None + _default_month = None + + def __init__(self, n=1, month=None): + BaseCFTimeOffset.__init__(self, n) + if month is None: + self.month = self._default_month + else: + self.month = month + if not isinstance(self.month, int): + raise TypeError("'self.month' must be an integer value between 1 " + "and 12. Instead, it was set to a value of " + "{!r}".format(self.month)) + elif not (1 <= self.month <= 12): + raise ValueError("'self.month' must be an integer value between 1 " + "and 12. 
Instead, it was set to a value of " + "{!r}".format(self.month)) + + def __apply__(self, other): + if self._day_option == 'start': + reference_day = 1 + elif self._day_option == 'end': + reference_day = _days_in_month(other) + else: + raise ValueError(self._day_option) + years = _adjust_n_years(other, self.n, self.month, reference_day) + months = years * 12 + (self.month - other.month) + return _shift_months(other, months, self._day_option) + + def __sub__(self, other): + import cftime + + if isinstance(other, cftime.datetime): + raise TypeError('Cannot subtract cftime.datetime from offset.') + elif type(other) == type(self) and other.month == self.month: + return type(self)(self.n - other.n, month=self.month) + else: + return NotImplemented + + def __mul__(self, other): + return type(self)(n=other * self.n, month=self.month) + + def rule_code(self): + return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month]) + + def __str__(self): + return '<{}: n={}, month={}>'.format( + type(self).__name__, self.n, self.month) + + +class YearBegin(YearOffset): + _freq = 'AS' + _day_option = 'start' + _default_month = 1 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == 1 and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date + YearBegin(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest start of year""" + if self.onOffset(date): + return date + else: + return date - YearBegin(month=self.month) + + +class YearEnd(YearOffset): + _freq = 'A' + _day_option = 'end' + _default_month = 12 + + def onOffset(self, date): + """Check if the given date is in the set of possible dates created + using a length-one version of this offset class.""" + return date.day == _days_in_month(date) and date.month == self.month + + def rollforward(self, date): + """Roll date forward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date + YearEnd(month=self.month) + + def rollback(self, date): + """Roll date backward to nearest end of year""" + if self.onOffset(date): + return date + else: + return date - YearEnd(month=self.month) + + +class Day(BaseCFTimeOffset): + _freq = 'D' + + def __apply__(self, other): + return other + timedelta(days=self.n) + + +class Hour(BaseCFTimeOffset): + _freq = 'H' + + def __apply__(self, other): + return other + timedelta(hours=self.n) + + +class Minute(BaseCFTimeOffset): + _freq = 'T' + + def __apply__(self, other): + return other + timedelta(minutes=self.n) + + +class Second(BaseCFTimeOffset): + _freq = 'S' + + def __apply__(self, other): + return other + timedelta(seconds=self.n) + + +_FREQUENCIES = { + 'A': YearEnd, + 'AS': YearBegin, + 'Y': YearEnd, + 'YS': YearBegin, + 'M': MonthEnd, + 'MS': MonthBegin, + 'D': Day, + 'H': Hour, + 'T': Minute, + 'min': Minute, + 'S': Second, + 'AS-JAN': partial(YearBegin, month=1), + 'AS-FEB': partial(YearBegin, month=2), + 'AS-MAR': partial(YearBegin, month=3), + 'AS-APR': partial(YearBegin, month=4), + 'AS-MAY': partial(YearBegin, month=5), + 'AS-JUN': partial(YearBegin, month=6), + 'AS-JUL': partial(YearBegin, month=7), + 'AS-AUG': partial(YearBegin, month=8), + 'AS-SEP': partial(YearBegin, month=9), + 'AS-OCT': partial(YearBegin, month=10), + 'AS-NOV': partial(YearBegin, month=11), + 'AS-DEC': partial(YearBegin, month=12), + 'A-JAN': 
partial(YearEnd, month=1), + 'A-FEB': partial(YearEnd, month=2), + 'A-MAR': partial(YearEnd, month=3), + 'A-APR': partial(YearEnd, month=4), + 'A-MAY': partial(YearEnd, month=5), + 'A-JUN': partial(YearEnd, month=6), + 'A-JUL': partial(YearEnd, month=7), + 'A-AUG': partial(YearEnd, month=8), + 'A-SEP': partial(YearEnd, month=9), + 'A-OCT': partial(YearEnd, month=10), + 'A-NOV': partial(YearEnd, month=11), + 'A-DEC': partial(YearEnd, month=12) +} + + +_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys()) +_PATTERN = '^((?P<multiple>\d+)|())(?P<freq>({0}))$'.format( + _FREQUENCY_CONDITION) + + +def to_offset(freq): + """Convert a frequency string to the appropriate subclass of + BaseCFTimeOffset.""" + if isinstance(freq, BaseCFTimeOffset): + return freq + else: + try: + freq_data = re.match(_PATTERN, freq).groupdict() + except AttributeError: + raise ValueError('Invalid frequency string provided') + + freq = freq_data['freq'] + multiples = freq_data['multiple'] + if multiples is None: + multiples = 1 + else: + multiples = int(multiples) + + return _FREQUENCIES[freq](n=multiples) + + +def to_cftime_datetime(date_str_or_date, calendar=None): + import cftime + + if isinstance(date_str_or_date, basestring): + if calendar is None: + raise ValueError( + 'If converting a string to a cftime.datetime object, ' + 'a calendar type must be provided') + date, _ = _parse_iso8601_with_reso(get_date_type(calendar), + date_str_or_date) + return date + elif isinstance(date_str_or_date, cftime.datetime): + return date_str_or_date + else: + raise TypeError("date_str_or_date must be a string or a " + 'subclass of cftime.datetime. Instead got ' + '{!r}.'.format(date_str_or_date)) + + +def normalize_date(date): + """Round datetime down to midnight.""" + return date.replace(hour=0, minute=0, second=0, microsecond=0) + + +def _maybe_normalize_date(date, normalize): + """Round datetime down to midnight if normalize is True.""" + if normalize: + return normalize_date(date) + else: + return date + + +def _generate_linear_range(start, end, periods): + """Generate an equally-spaced sequence of cftime.datetime objects between + and including two dates (whose length equals the number of periods).""" + import cftime + + total_seconds = (end - start).total_seconds() + values = np.linspace(0., total_seconds, periods, endpoint=True) + units = 'seconds since {}'.format(format_cftime_datetime(start)) + calendar = start.calendar + return cftime.num2date(values, units=units, calendar=calendar, + only_use_cftime_datetimes=True) + + +def _generate_range(start, end, periods, offset): + """Generate a regular range of cftime.datetime objects with a + given time offset. + + Adapted from pandas.tseries.offsets.generate_range.
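Before moving on to the range-generation helpers, a brief sketch of how to_offset resolves the frequency aliases registered in _FREQUENCIES above; the commented results are illustrative, not verbatim reprs.

    from xarray.coding.cftime_offsets import to_offset

    to_offset('5MS')    # MonthBegin(n=5): every fifth month start
    to_offset('A-JUN')  # YearEnd(n=1, month=6): year end anchored to June
    to_offset('T')      # Minute(n=1)
    to_offset('3Q')     # raises ValueError: quarterly offsets are not defined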
+ + Parameters + ---------- + start : cftime.datetime, or None + Start of range + end : cftime.datetime, or None + End of range + periods : int, or None + Number of elements in the sequence + offset : BaseCFTimeOffset + An offset class designed for working with cftime.datetime objects + + Returns + ------- + A generator object + """ + if start: + start = offset.rollforward(start) + + if end: + end = offset.rollback(end) + + if periods is None and end < start: + end = None + periods = 0 + + if end is None: + end = start + (periods - 1) * offset + + if start is None: + start = end - (periods - 1) * offset + + current = start + if offset.n >= 0: + while current <= end: + yield current + + next_date = current + offset + if next_date <= current: + raise ValueError('Offset {offset} did not increment date' + .format(offset=offset)) + current = next_date + else: + while current >= end: + yield current + + next_date = current + offset + if next_date >= current: + raise ValueError('Offset {offset} did not decrement date' + .format(offset=offset)) + current = next_date + + +def _count_not_none(*args): + """Compute the number of non-None arguments.""" + return sum([arg is not None for arg in args]) + + +def cftime_range(start=None, end=None, periods=None, freq='D', + tz=None, normalize=False, name=None, closed=None, + calendar='standard'): + """Return a fixed frequency CFTimeIndex. + + Parameters + ---------- + start : str or cftime.datetime, optional + Left bound for generating dates. + end : str or cftime.datetime, optional + Right bound for generating dates. + periods : integer, optional + Number of periods to generate. + freq : str, default 'D', BaseCFTimeOffset, or None + Frequency strings can have multiples, e.g. '5H'. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. + name : str, default None + Name of the resulting index + closed : {None, 'left', 'right'}, optional + Make the interval closed with respect to the given frequency to the + 'left', 'right', or both sides (None, the default). + calendar : str + Calendar type for the datetimes (default 'standard'). + + Returns + ------- + CFTimeIndex + + Notes + ----- + + This function is an analog of ``pandas.date_range`` for use in generating + sequences of ``cftime.datetime`` objects. It supports most of the + features of ``pandas.date_range`` (e.g. specifying how the index is + ``closed`` on either side, or whether or not to ``normalize`` the start and + end bounds); however, there are some notable exceptions: + + - You cannot specify a ``tz`` (time zone) argument. + - Start or end dates specified as partial-datetime strings must use the + `ISO-8601 format `_. + - It supports many, but not all, frequencies supported by + ``pandas.date_range``. For example it does not currently support any of + the business-related, semi-monthly, or sub-second frequencies. + - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as + these can easily be written in terms of the finest common resolution, + e.g. '61min'. + + Valid simple frequency strings for use with ``cftime``-calendars include + any multiples of the following. 
+ + +--------+-----------------------+ + | Alias | Description | + +========+=======================+ + | A, Y | Year-end frequency | + +--------+-----------------------+ + | AS, YS | Year-start frequency | + +--------+-----------------------+ + | M | Month-end frequency | + +--------+-----------------------+ + | MS | Month-start frequency | + +--------+-----------------------+ + | D | Day frequency | + +--------+-----------------------+ + | H | Hour frequency | + +--------+-----------------------+ + | T, min | Minute frequency | + +--------+-----------------------+ + | S | Second frequency | + +--------+-----------------------+ + + Any multiples of the following anchored offsets are also supported. + + +----------+-------------------------------------------------------------------+ + | Alias | Description | + +==========+===================================================================+ + | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January | + +----------+-------------------------------------------------------------------+ + | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February | + +----------+-------------------------------------------------------------------+ + | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March | + +----------+-------------------------------------------------------------------+ + | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April | + +----------+-------------------------------------------------------------------+ + | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May | + +----------+-------------------------------------------------------------------+ + | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June | + +----------+-------------------------------------------------------------------+ + | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July | + +----------+-------------------------------------------------------------------+ + | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August | + +----------+-------------------------------------------------------------------+ + | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September | + +----------+-------------------------------------------------------------------+ + | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October | + +----------+-------------------------------------------------------------------+ + | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November | + +----------+-------------------------------------------------------------------+ + | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December | + +----------+-------------------------------------------------------------------+ + + Finally, the following calendar aliases are supported. 
+ + +--------------------------------+---------------------------------------+ + | Alias | Date type | + +================================+=======================================+ + | standard, proleptic_gregorian | ``cftime.DatetimeProlepticGregorian`` | + +--------------------------------+---------------------------------------+ + | gregorian | ``cftime.DatetimeGregorian`` | + +--------------------------------+---------------------------------------+ + | noleap, 365_day | ``cftime.DatetimeNoLeap`` | + +--------------------------------+---------------------------------------+ + | all_leap, 366_day | ``cftime.DatetimeAllLeap`` | + +--------------------------------+---------------------------------------+ + | 360_day | ``cftime.Datetime360Day`` | + +--------------------------------+---------------------------------------+ + | julian | ``cftime.DatetimeJulian`` | + +--------------------------------+---------------------------------------+ + + Examples + -------- + + This function returns a ``CFTimeIndex``, populated with ``cftime.datetime`` + objects associated with the specified calendar type, e.g. + + >>> xr.cftime_range(start='2000', periods=6, freq='2MS', calendar='noleap') + CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00, + 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00], + dtype='object') + + As in the standard pandas function, three of the ``start``, ``end``, + ``periods``, or ``freq`` arguments must be specified at a given time, with + the other set to ``None``. See the `pandas documentation + `_ + for more examples of the behavior of ``date_range`` with each of the + parameters. + + See Also + -------- + pandas.date_range + """ # noqa: E501 + # Adapted from pandas.core.indexes.datetimes._generate_range. + if _count_not_none(start, end, periods, freq) != 3: + raise ValueError( + "Of the arguments 'start', 'end', 'periods', and 'freq', three " + "must be specified at a time.") + + if start is not None: + start = to_cftime_datetime(start, calendar) + start = _maybe_normalize_date(start, normalize) + if end is not None: + end = to_cftime_datetime(end, calendar) + end = _maybe_normalize_date(end, normalize) + + if freq is None: + dates = _generate_linear_range(start, end, periods) + else: + offset = to_offset(freq) + dates = np.array(list(_generate_range(start, end, periods, offset))) + + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == 'left': + left_closed = True + elif closed == 'right': + right_closed = True + else: + raise ValueError("Closed must be either 'left', 'right' or None") + + if (not left_closed and len(dates) and + start is not None and dates[0] == start): + dates = dates[1:] + if (not right_closed and len(dates) and + end is not None and dates[-1] == end): + dates = dates[:-1] + + return CFTimeIndex(dates, name=name) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index eb8cae2f398..dea896c199a 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -1,4 +1,46 @@ +"""DatetimeIndex analog for cftime.datetime objects""" +# The pandas.Index subclass defined here was copied and adapted for +# use with cftime.datetime objects based on the source code defining +# pandas.DatetimeIndex. + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. 
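Stepping back to the cftime_range implementation that closes the file above: a hedged sketch of the closed and freq=None branches, assuming the function is re-exported at the package level as its docstring suggests.

    import xarray as xr

    # closed='left' drops a right endpoint that coincides with `end`
    xr.cftime_range('2000-01-01', '2000-01-04', freq='D',
                    calendar='noleap', closed='left')
    # -> 2000-01-01, 2000-01-02, 2000-01-03

    # freq=None generates an equally spaced (linear) sequence instead
    xr.cftime_range(start='2000-01-01', end='2000-01-02',
                    periods=5, freq=None, calendar='noleap')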
+ +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from __future__ import absolute_import + import re from datetime import timedelta @@ -116,28 +158,43 @@ def f(self): def get_date_type(self): - return type(self._data[0]) + if self._data.size: + return type(self._data[0]) + else: + return None def assert_all_valid_date_type(data): import cftime - sample = data[0] - date_type = type(sample) - if not isinstance(sample, cftime.datetime): - raise TypeError( - 'CFTimeIndex requires cftime.datetime ' - 'objects. Got object of {}.'.format(date_type)) - if not all(isinstance(value, date_type) for value in data): - raise TypeError( - 'CFTimeIndex requires using datetime ' - 'objects of all the same type. Got\n{}.'.format(data)) + if data.size: + sample = data[0] + date_type = type(sample) + if not isinstance(sample, cftime.datetime): + raise TypeError( + 'CFTimeIndex requires cftime.datetime ' + 'objects. Got object of {}.'.format(date_type)) + if not all(isinstance(value, date_type) for value in data): + raise TypeError( + 'CFTimeIndex requires using datetime ' + 'objects of all the same type. Got\n{}.'.format(data)) class CFTimeIndex(pd.Index): """Custom Index for working with CF calendars and dates All elements of a CFTimeIndex must be cftime.datetime objects. 
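The data.size guards above are what allow zero-length indexes to exist at all; a minimal sketch of the behaviour they enable:

    import numpy as np
    from xarray.coding.cftimeindex import CFTimeIndex

    empty = CFTimeIndex(np.array([], dtype=object))
    len(empty)        # 0
    empty.date_type   # None: no element is available to inspect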
+ + Parameters + ---------- + data : array or CFTimeIndex + Sequence of cftime.datetime objects to use in index + name : str, default None + Name of the resulting index + + See Also + -------- + cftime_range """ year = _field_accessor('year', 'The year of the datetime') month = _field_accessor('month', 'The month of the datetime') @@ -149,10 +206,14 @@ class CFTimeIndex(pd.Index): 'The microseconds of the datetime') date_type = property(get_date_type) - def __new__(cls, data): + def __new__(cls, data, name=None): + if name is None and hasattr(data, 'name'): + name = data.name + result = object.__new__(cls) - assert_all_valid_date_type(data) - result._data = np.array(data) + result._data = np.array(data, dtype='O') + assert_all_valid_date_type(result._data) + result.name = name return result def _partial_date_slice(self, resolution, parsed): @@ -254,3 +315,80 @@ def __contains__(self, key): def contains(self, key): """Needed for .loc based partial-string indexing""" return self.__contains__(key) + + def shift(self, n, freq): + """Shift the CFTimeIndex a multiple of the given frequency. + + See the documentation for :py:func:`~xarray.cftime_range` for a + complete listing of valid frequency strings. + + Parameters + ---------- + n : int + Periods to shift by + freq : str or datetime.timedelta + A frequency string or datetime.timedelta object to shift by + + Returns + ------- + CFTimeIndex + + See also + -------- + pandas.DatetimeIndex.shift + + Examples + -------- + >>> index = xr.cftime_range('2000', periods=1, freq='M') + >>> index + CFTimeIndex([2000-01-31 00:00:00], dtype='object') + >>> index.shift(1, 'M') + CFTimeIndex([2000-02-29 00:00:00], dtype='object') + """ + from .cftime_offsets import to_offset + + if not isinstance(n, int): + raise TypeError("'n' must be an int, got {}.".format(n)) + if isinstance(freq, timedelta): + return self + n * freq + elif isinstance(freq, pycompat.basestring): + return self + n * to_offset(freq) + else: + raise TypeError( + "'freq' must be of type " + "str or datetime.timedelta, got {}.".format(freq)) + + def __add__(self, other): + return CFTimeIndex(np.array(self) + other) + + def __radd__(self, other): + return CFTimeIndex(other + np.array(self)) + + def __sub__(self, other): + return CFTimeIndex(np.array(self) - other) + + +def _parse_iso8601_without_reso(date_type, datetime_str): + date, _ = _parse_iso8601_with_reso(date_type, datetime_str) + return date + + +def _parse_array_of_cftime_strings(strings, date_type): + """Create a numpy array from an array of strings. + + For use in generating dates from strings for use with interp. Assumes the + array is either 0-dimensional or 1-dimensional. 
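A short sketch of the new shift method and the elementwise arithmetic that backs it; it uses cftime_range from the module added earlier in this diff, and the commented results are indicative only.

    from datetime import timedelta
    import xarray as xr

    index = xr.cftime_range('2000-01-31', periods=2, freq='M',
                            calendar='noleap')

    index.shift(1, 'M')                # shift by one month-end offset
    index.shift(2, timedelta(days=1))  # or by a plain timedelta
    index + timedelta(hours=12)        # __add__ applies elementwise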
+ + Parameters + ---------- + strings : array of strings + Strings to convert to dates + date_type : cftime.datetime type + Calendar type to use for dates + + Returns + ------- + np.array + """ + return np.array([_parse_iso8601_without_reso(date_type, s) + for s in strings.ravel()]).reshape(strings.shape) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 87b17d9175e..3502fd773d7 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -9,8 +9,8 @@ from ..core.pycompat import bytes_type, dask_array_type, unicode_type from ..core.variable import Variable from .variables import ( - VariableCoder, lazy_elemwise_func, pop_to, - safe_setitem, unpack_for_decoding, unpack_for_encoding) + VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, + unpack_for_decoding, unpack_for_encoding) def create_vlen_dtype(element_type): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index d946e2ed378..dff7e75bdcf 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,8 +9,8 @@ import numpy as np import pandas as pd -from ..core.common import contains_cftime_datetimes from ..core import indexing +from ..core.common import contains_cftime_datetimes from ..core.formatting import first_n_items, format_timestamp, last_item from ..core.options import OPTIONS from ..core.pycompat import PY3 @@ -183,8 +183,11 @@ def decode_cf_datetime(num_dates, units, calendar=None, # fixes: https://github.com/pydata/pandas/issues/14068 # these lines check if the the lowest or the highest value in dates # cause an OutOfBoundsDatetime (Overflow) error - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'invalid value encountered', + RuntimeWarning) + pd.to_timedelta(flat_num_dates.min(), delta) + ref_date + pd.to_timedelta(flat_num_dates.max(), delta) + ref_date # Cast input dates to integers of nanoseconds because `pd.to_datetime` # works much faster when dealing with integers diff --git a/xarray/conventions.py b/xarray/conventions.py index 67dcb8d6d4e..f60ee6b2c15 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,11 +6,11 @@ import numpy as np import pandas as pd -from .coding import times, strings, variables +from .coding import strings, times, variables from .coding.variables import SerializationWarning from .core import duck_array_ops, indexing from .core.pycompat import ( - OrderedDict, basestring, bytes_type, iteritems, dask_array_type, + OrderedDict, basestring, bytes_type, dask_array_type, iteritems, unicode_type) from .core.variable import IndexVariable, Variable, as_variable diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 81af0532d93..72791ed73ec 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from .common import is_np_datetime_like, _contains_datetime_like_objects +from .common import _contains_datetime_like_objects, is_np_datetime_like from .pycompat import dask_array_type diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index b0d2a49c29f..f82ddef25ba 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -174,11 +174,14 @@ def deep_align(objects, join='inner', copy=True, indexes=None, This function is not public API. 
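For completeness, a sketch of the private string-parsing helper documented above; it is not public API, and the 'noleap' calendar here is only an example.

    import numpy as np
    from xarray.coding.cftime_offsets import get_date_type
    from xarray.coding.cftimeindex import _parse_array_of_cftime_strings

    strings = np.array(['2000-01-01', '2000-06-01', '2001-01-01'])
    dates = _parse_array_of_cftime_strings(strings, get_date_type('noleap'))
    # dates is an object array of cftime.DatetimeNoLeap, same shape as strings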
""" + from .dataarray import DataArray + from .dataset import Dataset + if indexes is None: indexes = {} def is_alignable(obj): - return hasattr(obj, 'indexes') and hasattr(obj, 'reindex') + return isinstance(obj, (DataArray, Dataset)) positions = [] keys = [] diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 430f0e564d6..6853939c02d 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -8,8 +8,8 @@ from .alignment import align from .merge import merge from .pycompat import OrderedDict, basestring, iteritems -from .variable import concat as concat_vars from .variable import IndexVariable, Variable, as_variable +from .variable import concat as concat_vars def concat(objs, dim=None, data_vars='all', coords='different', @@ -125,16 +125,17 @@ def _calc_concat_dim_coord(dim): Infer the dimension name and 1d coordinate variable (if appropriate) for concatenating along the new dimension. """ + from .dataarray import DataArray + if isinstance(dim, basestring): coord = None - elif not hasattr(dim, 'dims'): - # dim is not a DataArray or IndexVariable + elif not isinstance(dim, (DataArray, Variable)): dim_name = getattr(dim, 'name', None) if dim_name is None: dim_name = 'concat_dim' coord = IndexVariable(dim_name, dim) dim = dim_name - elif not hasattr(dim, 'name'): + elif not isinstance(dim, DataArray): coord = as_variable(dim).to_index_variable() dim, = coord.dims else: diff --git a/xarray/core/common.py b/xarray/core/common.py index 3f934fcc769..c74b1fa080b 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -2,14 +2,18 @@ import warnings from distutils.version import LooseVersion +from textwrap import dedent import numpy as np import pandas as pd -from . import duck_array_ops, dtypes, formatting, ops +from . import dtypes, duck_array_ops, formatting, ops from .arithmetic import SupportsArithmetic from .pycompat import OrderedDict, basestring, dask_array_type, suppress -from .utils import Frozen, SortedKeysDict +from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs + +# Used as a sentinel value to indicate a all dimensions +ALL_DIMS = ReprObject('') class ImplementsArrayReduce(object): @@ -27,20 +31,20 @@ def wrapped_func(self, dim=None, axis=None, keep_attrs=False, allow_lazy=True, **kwargs) return wrapped_func - _reduce_extra_args_docstring = \ - """dim : str or sequence of str, optional + _reduce_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension(s) over which to apply `{name}`. axis : int or sequence of int, optional Axis(es) over which to apply `{name}`. Only one of the 'dim' and 'axis' arguments can be supplied. If neither are supplied, then - `{name}` is calculated over axes.""" + `{name}` is calculated over axes.""") - _cum_extra_args_docstring = \ - """dim : str or sequence of str, optional + _cum_extra_args_docstring = dedent("""\ + dim : str or sequence of str, optional Dimension over which to apply `{name}`. axis : int or sequence of int, optional Axis over which to apply `{name}`. Only one of the 'dim' - and 'axis' arguments can be supplied.""" + and 'axis' arguments can be supplied.""") class ImplementsDatasetReduce(object): @@ -308,12 +312,12 @@ def assign_coords(self, **kwargs): assigned : same type as caller A new object with the new coordinates in addition to the existing data. - + Examples -------- - + Convert longitude coordinates from 0-359 to -180-179: - + >>> da = xr.DataArray(np.random.rand(4), ... coords=[np.array([358, 359, 0, 1])], ... 
dims='lon') @@ -445,11 +449,11 @@ def groupby(self, group, squeeze=True): grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. - + Examples -------- Calculate daily anomalies for daily data: - + >>> da = xr.DataArray(np.linspace(0, 1826, num=1827), ... coords=[pd.date_range('1/1/2000', '31/12/2004', ... freq='D')], @@ -465,7 +469,7 @@ def groupby(self, group, squeeze=True): Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... - + See Also -------- core.groupby.DataArrayGroupBy @@ -525,24 +529,24 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3, 'precision': precision, 'include_lowest': include_lowest}) - def rolling(self, min_periods=None, center=False, **windows): + def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs): """ Rolling window object. Parameters ---------- + dim: dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. + **dim_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or dim_kwarg must be provided. Returns ------- @@ -581,15 +585,15 @@ def rolling(self, min_periods=None, center=False, **windows): core.rolling.DataArrayRolling core.rolling.DatasetRolling """ - - return self._rolling_cls(self, min_periods=min_periods, - center=center, **windows) + dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling') + return self._rolling_cls(self, dim, min_periods=min_periods, + center=center) def resample(self, freq=None, dim=None, how=None, skipna=None, closed=None, label=None, base=0, keep_attrs=False, **indexer): """Returns a Resample object for performing resampling operations. - Handles both downsampling and upsampling. If any intervals contain no + Handles both downsampling and upsampling. If any intervals contain no values from the original object, they will be given the value ``NaN``. Parameters @@ -616,11 +620,11 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, ------- resampled : same type as caller This object resampled. - + Examples -------- Downsample monthly time-series data to seasonal data: - + >>> da = xr.DataArray(np.linspace(0, 11, num=12), ... coords=[pd.date_range('15/12/1999', ... periods=12, freq=pd.DateOffset(months=1))], @@ -637,18 +641,20 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 Upsample monthly time-series data to daily data: - + >>> da.resample(time='1D').interpolate('linear') array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) Coordinates: * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... - + References ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ + # TODO support non-string indexer after removing the old API. 
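With the signature change above, the rolling window can be spelled either as keyword arguments or as an explicit mapping; a sketch of the two equivalent call forms:

    import numpy as np
    import pandas as pd
    import xarray as xr

    da = xr.DataArray(np.arange(10.0), dims='time',
                      coords={'time': pd.date_range('2000-01-01', periods=10)})

    r1 = da.rolling(time=3).mean()            # keyword form
    r2 = da.rolling(dim={'time': 3}).mean()   # explicit mapping form
    assert r1.identical(r2)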
+ from .dataarray import DataArray from .resample import RESAMPLE_DIM @@ -957,8 +963,8 @@ def contains_cftime_datetimes(var): sample = sample.item() return isinstance(sample, cftime_datetime) else: - return False - + return False + def _contains_datetime_like_objects(var): """Check if a variable contains datetime like objects (either diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bdba72cb48a..7998cc4f72f 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -2,19 +2,19 @@ Functions for applying functions that act on arrays to xarray's labeled data. """ from __future__ import absolute_import, division, print_function -from distutils.version import LooseVersion + import functools import itertools import operator from collections import Counter +from distutils.version import LooseVersion import numpy as np -from . import duck_array_ops -from . import utils +from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables -from .pycompat import OrderedDict, dask_array_type, basestring +from .pycompat import OrderedDict, basestring, dask_array_type from .utils import is_dict_like _DEFAULT_FROZEN_SET = frozenset() diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index c2417345f55..6b53dcffe6e 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,7 +1,10 @@ from __future__ import absolute_import, division, print_function -import numpy as np +from distutils.version import LooseVersion + import dask.array as da +import numpy as np +from dask import __version__ as dask_version try: from dask.array import isin @@ -30,3 +33,130 @@ def isin(element, test_elements, assume_unique=False, invert=False): if invert: result = ~result return result + + +if LooseVersion(dask_version) > LooseVersion('1.19.2'): + gradient = da.gradient + +else: # pragma: no cover + # Copied from dask v0.19.2 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + import math + from numbers import Integral, Real + + try: + AxisError = np.AxisError + except AttributeError: + try: + np.array([0]).sum(axis=5) + except Exception as e: + AxisError = type(e) + + def validate_axis(axis, ndim): + """ Validate an input to axis= keywords """ + if isinstance(axis, (tuple, list)): + return tuple(validate_axis(ax, ndim) for ax in axis) + if not isinstance(axis, Integral): + raise TypeError("Axis value must be an integer, got %s" % axis) + if axis < -ndim or axis >= ndim: + raise AxisError("Axis %d is out of bounds for array of dimension " + "%d" % (axis, ndim)) + if axis < 0: + axis += ndim + return axis + + def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs): + """ + x: nd-array + array of one block + coord: 1d-array or scalar + coordinate along which the gradient is computed. + axis: int + axis along which the gradient is computed + array_locs: + actual location along axis. 
None if coordinate is scalar + grad_kwargs: + keyword to be passed to np.gradient + """ + block_loc = block_id[axis] + if array_locs is not None: + coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]] + grad = np.gradient(x, coord, axis=axis, **grad_kwargs) + return grad + + def gradient(f, *varargs, **kwargs): + f = da.asarray(f) + + kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1)) + if kwargs["edge_order"] > 2: + raise ValueError("edge_order must be less than or equal to 2.") + + drop_result_list = False + axis = kwargs.pop("axis", None) + if axis is None: + axis = tuple(range(f.ndim)) + elif isinstance(axis, Integral): + drop_result_list = True + axis = (axis,) + + axis = validate_axis(axis, f.ndim) + + if len(axis) != len(set(axis)): + raise ValueError("duplicate axes not allowed") + + axis = tuple(ax % f.ndim for ax in axis) + + if varargs == (): + varargs = (1,) + if len(varargs) == 1: + varargs = len(axis) * varargs + if len(varargs) != len(axis): + raise TypeError( + "Spacing must either be a single scalar, or a scalar / " + "1d-array per axis" + ) + + if issubclass(f.dtype.type, (np.bool8, Integral)): + f = f.astype(float) + elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4: + f = f.astype(float) + + results = [] + for i, ax in enumerate(axis): + for c in f.chunks[ax]: + if np.min(c) < kwargs["edge_order"] + 1: + raise ValueError( + 'Chunk size must be larger than edge_order + 1. ' + 'Minimum chunk for aixs {} is {}. Rechunk to ' + 'proceed.'.format(np.min(c), ax)) + + if np.isscalar(varargs[i]): + array_locs = None + else: + if isinstance(varargs[i], da.Array): + raise NotImplementedError( + 'dask array coordinated is not supported.') + # coordinate position for each block taking overlap into + # account + chunk = np.array(f.chunks[ax]) + array_loc_stop = np.cumsum(chunk) + 1 + array_loc_start = array_loc_stop - chunk - 2 + array_loc_stop[-1] -= 1 + array_loc_start[0] = 0 + array_locs = (array_loc_start, array_loc_stop) + + results.append(f.map_overlap( + _gradient_kernel, + dtype=f.dtype, + depth={j: 1 if j == ax else 0 for j in range(f.ndim)}, + boundary="none", + coord=varargs[i], + axis=ax, + array_locs=array_locs, + grad_kwargs=kwargs, + )) + + if drop_result_list: + results = results[0] + + return results diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 423a65aa3c2..25c572edd54 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, division, print_function + from distutils.version import LooseVersion import numpy as np -from . import nputils -from . import dtypes +from . import dtypes, nputils try: import dask diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f215bc47df8..f131b003a69 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -677,14 +677,77 @@ def persist(self, **kwargs): ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this array. - If `deep=True`, a deep copy is made of all variables in the underlying - dataset. Otherwise, a shallow copy is made, so each variable in the new + If `deep=True`, a deep copy is made of the data array. + Otherwise, a shallow copy is made, so each variable in the new array's dataset is also a variable in this array's dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. 
+ + Parameters + ---------- + deep : bool, optional + Whether the data array and its coordinates are loaded into memory + and copied onto the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored for all data variables, + and only used for coords. + + Returns + ------- + object : DataArray + New object with dimensions, attributes, coordinates, name, + encoding, and optionally data copied from original. + + Examples + -------- + + Shallow versus deep copy + + >>> array = xr.DataArray([1, 2, 3], dims='x', + ... coords={'x': ['a', 'b', 'c']}) + >>> array.copy() + + array([1, 2, 3]) + Coordinates: + * x (x) >> array_0 = array.copy(deep=False) + >>> array_0[0] = 7 + >>> array_0 + + array([7, 2, 3]) + Coordinates: + * x (x) >> array + + array([7, 2, 3]) + Coordinates: + * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + Coordinates: + * x (x) >> array + + array([1, 2, 3]) + Coordinates: + * x (x) >> arr = DataArray(np.arange(6).reshape(2, 3), + ... coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) + >>> arr + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) |S1 'a' 'b' + * y (y) int64 0 1 2 + >>> stacked = arr.stack(z=('x', 'y')) + >>> stacked.indexes['z'] + MultiIndex(levels=[[u'a', u'b'], [0, 1, 2]], + labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[u'x', u'y']) + >>> roundtripped = stacked.unstack() + >>> arr.identical(roundtripped) + True + See also -------- DataArray.stack @@ -1856,7 +1955,7 @@ def _binary_op(f, reflexive=False, join=None, **ignored_kwargs): def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented - if hasattr(other, 'indexes'): + if isinstance(other, DataArray): align_type = (OPTIONS['arithmetic_join'] if join is None else join) self, other = align(self, other, join=align_type, copy=False) @@ -1974,11 +2073,14 @@ def diff(self, dim, n=1, label='upper'): Coordinates: * x (x) int64 3 4 + See Also + -------- + DataArray.differentiate """ ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label) return self._from_temp_dataset(ds) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """Shift this array by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. Values shifted from @@ -1987,10 +2089,13 @@ def shift(self, **shifts): Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : Mapping with the form of {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- @@ -2012,17 +2117,23 @@ def shift(self, **shifts): Coordinates: * x (x) int64 0 1 2 """ - variable = self.variable.shift(**shifts) - return self._replace(variable) + ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs) + return self._from_temp_dataset(ds) - def roll(self, **shifts): + def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): """Roll this array by an offset along one or more dimensions. - Unlike shift, roll rotates all variables, including coordinates. The - direction of rotation is consistent with :py:func:`numpy.roll`. + Unlike shift, roll may rotate all variables, including coordinates + if specified. The direction of rotation is consistent with + :py:func:`numpy.roll`. 
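shift and roll follow the same pattern as rolling above, taking their offsets as a mapping or as keywords, and roll additionally gains roll_coords; a sketch of the resulting call forms:

    import xarray as xr

    arr = xr.DataArray([5, 6, 7], dims='x', coords={'x': [10, 20, 30]})

    arr.shift(x=1)                    # keyword form
    arr.shift({'x': 1})               # equivalent mapping form
    arr.roll(x=1, roll_coords=False)  # keeps the 'x' coordinate in place and
                                      # silences the deprecation warning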
Parameters ---------- + roll_coords : bool + Indicates whether to roll the coordinates by the offset + The current default of roll_coords (None, equivalent to True) is + deprecated and will change to False in a future version. + Explicitly pass roll_coords to silence the warning. **shifts : keyword arguments of the form {dim: offset} Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. @@ -2046,7 +2157,8 @@ def roll(self, **shifts): Coordinates: * x (x) int64 2 0 1 """ - ds = self._to_temp_dataset().roll(**shifts) + ds = self._to_temp_dataset().roll( + shifts=shifts, roll_coords=roll_coords, **shifts_kwargs) return self._from_temp_dataset(ds) @property @@ -2243,6 +2355,61 @@ def rank(self, dim, pct=False, keep_attrs=False): ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs) return self._from_temp_dataset(ds) + def differentiate(self, coord, edge_order=1, datetime_unit=None): + """ Differentiate the array with the second order accurate central + differences. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + Parameters + ---------- + coord: str + The coordinate to be used to compute the gradient. + edge_order: 1 or 2. Default 1 + N-th order accurate differences at the boundaries. + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + Unit to compute gradient. Only valid for datetime coordinate. + + Returns + ------- + differentiated: DataArray + + See also + -------- + numpy.gradient: corresponding numpy function + + Examples + -------- + + >>> da = xr.DataArray(np.arange(12).reshape(4, 3), dims=['x', 'y'], + ... coords={'x': [0, 0.1, 1.1, 1.2]}) + >>> da + + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.differentiate('x') + + array([[30. , 30. , 30. ], + [27.545455, 27.545455, 27.545455], + [27.545455, 27.545455, 27.545455], + [30. , 30. , 30. ]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().differentiate( + coord, edge_order, datetime_unit) + return self._from_temp_dataset(ds) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4b52178ad0e..c8586d1d408 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -13,12 +13,14 @@ import xarray as xr from . import ( - alignment, duck_array_ops, formatting, groupby, indexing, ops, resample, - rolling, utils) + alignment, computation, duck_array_ops, formatting, groupby, indexing, ops, + resample, rolling, utils) from .. 
import conventions +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .alignment import align from .common import ( - DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects) + ALL_DIMS, DataWithCoords, ImplementsDatasetReduce, + _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -30,8 +32,9 @@ from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( - Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, - ensure_us_time_resolution, hashable, maybe_wrap_array) + Frozen, SortedKeysDict, datetime_to_numeric, decode_numpy_dict_values, + either_dict_or_kwargs, ensure_us_time_resolution, hashable, + maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -709,16 +712,120 @@ def _replace_indexes(self, indexes): obj = obj.rename(dim_names) return obj - def copy(self, deep=False): + def copy(self, deep=False, data=None): """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. Otherwise, a shallow copy of each of the component variable is made, so that the underlying memory region of the new dataset is the same as in the original dataset. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether each component variable is loaded into memory and copied onto + the new object. Default is True. + data : dict-like, optional + Data to use in the new object. Each item in `data` must have same + shape as corresponding data variable in original. When `data` is + used, `deep` is ignored for the data variables and only used for + coords. + + Returns + ------- + object : Dataset + New object with dimensions, attributes, coordinates, name, encoding, + and optionally data copied from original. + + Examples + -------- + + Shallow copy versus deep copy + + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])}, + coords={'x': ['one', 'two']}) + >>> ds.copy() + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds_0 = ds.copy(deep=False) + >>> ds_0['foo'][0, 0] = 7 + >>> ds_0 + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds.copy(data={'foo': np.arange(6).reshape(2, 3), 'bar': ['a', 'b']}) + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) >> ds + + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Coordinates: + * x (x) 1: + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. 
To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + def reduce_array(ar): return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs) return self.apply(reduce_array, shortcut=shortcut) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None, + keep_attrs=False, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + skipna=skipna, allow_lazy=True, **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, + keep_attrs=False, **kwargs): + return self.reduce(func, dim, axis, keep_attrs=keep_attrs, + allow_lazy=True, **kwargs) + return wrapped_func + + +DEFAULT_DIMS = utils.ReprObject('') ops.inject_reduce_methods(DataArrayGroupBy) ops.inject_binary_ops(DataArrayGroupBy) @@ -649,10 +679,40 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = ALL_DIMS + # TODO change this to dim = self._group_dim after + # the deprecation process. Do not forget to remove _reduce_method + warnings.warn( + "Default reduction dimension will be changed to the " + "grouped dimension after xarray 0.12. To silence this " + "warning, pass dim=xarray.ALL_DIMS explicitly.", + FutureWarning, stacklevel=2) + elif dim is None: + dim = self._group_dim + def reduce_dataset(ds): return ds.reduce(func, dim, keep_attrs, **kwargs) return self.apply(reduce_dataset) + # TODO remove the following class method and DEFAULT_DIMS after the + # deprecation cycle + @classmethod + def _reduce_method(cls, func, include_skipna, numeric_only): + if include_skipna: + def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False, + skipna=None, **kwargs): + return self.reduce(func, dim, keep_attrs, skipna=skipna, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + else: + def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False, + **kwargs): + return self.reduce(func, dim, keep_attrs, + numeric_only=numeric_only, allow_lazy=True, + **kwargs) + return wrapped_func + def assign(self, **kwargs): """Assign data variables by group. diff --git a/xarray/core/merge.py b/xarray/core/merge.py index f823717a8af..984dd2fa204 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -190,10 +190,13 @@ def expand_variable_dicts(list_of_variable_dicts): an input's values. The values of each ordered dictionary are all xarray.Variable objects. """ + from .dataarray import DataArray + from .dataset import Dataset + var_dicts = [] for variables in list_of_variable_dicts: - if hasattr(variables, 'variables'): # duck-type Dataset + if isinstance(variables, Dataset): sanitized_vars = variables.variables else: # append coords to var_dicts before appending sanitized_vars, @@ -201,7 +204,7 @@ def expand_variable_dicts(list_of_variable_dicts): sanitized_vars = OrderedDict() for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): # use private API for speed coords = var._coords.copy() # explicitly overwritten variables should take precedence @@ -232,17 +235,19 @@ def determine_coords(list_of_variable_dicts): All variable found in the input should appear in either the set of coordinate or non-coordinate names. 
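The FutureWarning added above can be silenced by naming the reduction dimensions explicitly; a sketch, assuming ALL_DIMS is re-exported at the package level as the warning message implies:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'a': ('x', np.arange(6.0))},
                    coords={'letters': ('x', ['a', 'a', 'b', 'b', 'c', 'c'])})

    ds.groupby('letters').mean(xr.ALL_DIMS)  # current behaviour, made explicit
    ds.groupby('letters').mean('x')          # the future default: reduce only
                                             # over the grouped dimension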
""" + from .dataarray import DataArray + from .dataset import Dataset + coord_names = set() noncoord_names = set() for variables in list_of_variable_dicts: - if hasattr(variables, 'coords') and hasattr(variables, 'data_vars'): - # duck-type Dataset + if isinstance(variables, Dataset): coord_names.update(variables.coords) noncoord_names.update(variables.data_vars) else: for name, var in variables.items(): - if hasattr(var, '_coords'): # duck-type DataArray + if isinstance(var, DataArray): coords = set(var._coords) # use private API for speed # explicitly overwritten variables should take precedence coords.discard(name) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index bec9e2e1931..3f4e0fc3ac9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +import warnings from collections import Iterable from functools import partial @@ -7,11 +8,12 @@ import pandas as pd from . import rolling +from .common import _contains_datetime_like_objects from .computation import apply_ufunc +from .duck_array_ops import dask_array_type from .pycompat import iteritems -from .utils import is_scalar, OrderedSet +from .utils import OrderedSet, datetime_to_numeric, is_scalar from .variable import Variable, broadcast_variables -from .duck_array_ops import dask_array_type class BaseInterpolator(object): @@ -57,7 +59,7 @@ def __init__(self, xi, yi, method='linear', fill_value=None, **kwargs): if self.cons_kwargs: raise ValueError( - 'recieved invalid kwargs: %r' % self.cons_kwargs.keys()) + 'received invalid kwargs: %r' % self.cons_kwargs.keys()) if fill_value is None: self._left = np.nan @@ -207,13 +209,16 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) - arr = apply_ufunc(interpolator, index, self, - input_core_dims=[[dim], [dim]], - output_core_dims=[[dim]], - output_dtypes=[self.dtype], - dask='parallelized', - vectorize=True, - keep_attrs=True).transpose(*self.dims) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'overflow', RuntimeWarning) + warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning) + arr = apply_ufunc(interpolator, index, self, + input_core_dims=[[dim], [dim]], + output_core_dims=[[dim]], + output_dtypes=[self.dtype], + dask='parallelized', + vectorize=True, + keep_attrs=True).transpose(*self.dims) if limit is not None: arr = arr.where(valids) @@ -402,15 +407,16 @@ def _floatize_x(x, new_x): x = list(x) new_x = list(new_x) for i in range(len(x)): - if x[i].dtype.kind in 'Mm': + if _contains_datetime_like_objects(x[i]): # Scipy casts coordinates to np.float64, which is not accurate # enough for datetime64 (uses 64bit integer). # We assume that the most of the bits are used to represent the # offset (min(x)) and the variation (x - min(x)) can be # represented by float. 
- xmin = np.min(x[i]) - x[i] = (x[i] - xmin).astype(np.float64) - new_x[i] = (new_x[i] - xmin).astype(np.float64) + xmin = x[i].min() + x[i] = datetime_to_numeric(x[i], offset=xmin, dtype=np.float64) + new_x[i] = datetime_to_numeric( + new_x[i], offset=xmin, dtype=np.float64) return x, new_x diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py new file mode 100644 index 00000000000..4d3f03c899e --- /dev/null +++ b/xarray/core/nanops.py @@ -0,0 +1,207 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np + +from . import dtypes, nputils +from .duck_array_ops import ( + _dask_or_eager_func, count, fillna, isnull, where_method) +from .pycompat import dask_array_type + +try: + import dask.array as dask_array +except ImportError: + dask_array = None + + +def _replace_nan(a, val): + """ + replace nan in a by val, and returns the replaced array and the nan + position + """ + mask = isnull(a) + return where_method(val, mask, a), mask + + +def _maybe_null_out(result, axis, mask, min_count=1): + """ + xarray version of pandas.core.nanops._maybe_null_out + """ + if hasattr(axis, '__len__'): # if tuple or list + raise ValueError('min_count is not available for reduction ' + 'with more than one dimensions.') + + if axis is not None and getattr(result, 'ndim', False): + null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + if null_mask.any(): + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = result.astype(dtype) + result[null_mask] = fill_value + + elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES: + null_mask = mask.size - mask.sum() + if null_mask < min_count: + result = np.nan + + return result + + +def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanargmin, nanargmax for object arrays. Always return integer + type + """ + valid_count = count(value, axis=axis) + value = fillna(value, fill_value) + data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) + + # TODO This will evaluate dask arrays and might be costly. 
+ if (valid_count == 0).any(): + raise ValueError('All-NaN slice encountered') + + return data + + +def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): + """ In house nanmin and nanmax for object array """ + valid_count = count(value, axis=axis) + filled_value = fillna(value, fill_value) + data = getattr(np, func)(filled_value, axis=axis, **kwargs) + if not hasattr(data, 'dtype'): # scalar case + data = dtypes.fill_value(value.dtype) if valid_count == 0 else data + return np.array(data, dtype=value.dtype) + return where_method(data, valid_count != 0) + + +def nanmin(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'min', dtypes.get_pos_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmin(a, axis=axis) + + +def nanmax(a, axis=None, out=None): + if a.dtype.kind == 'O': + return _nan_minmax_object( + 'max', dtypes.get_neg_infinity(a.dtype), a, axis) + + module = dask_array if isinstance(a, dask_array_type) else nputils + return module.nanmax(a, axis=axis) + + +def nanargmin(a, axis=None): + fill_value = dtypes.get_pos_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmin', fill_value, a, axis=axis) + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmin(a, axis=axis) + else: + res = np.argmin(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nanargmax(a, axis=None): + fill_value = dtypes.get_neg_infinity(a.dtype) + if a.dtype.kind == 'O': + return _nan_argminmax_object('argmax', fill_value, a, axis=axis) + + a, mask = _replace_nan(a, fill_value) + if isinstance(a, dask_array_type): + res = dask_array.argmax(a, axis=axis) + else: + res = np.argmax(a, axis=axis) + + if mask is not None: + mask = mask.all(axis=axis) + if mask.any(): + raise ValueError("All-NaN slice encountered") + return res + + +def nansum(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 0) + result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def _nanmean_ddof_object(ddof, value, axis=None, **kwargs): + """ In house nanmean. 
ddof argument will be used in _nanvar method """ + from .duck_array_ops import (count, fillna, _dask_or_eager_func, + where_method) + + valid_count = count(value, axis=axis) + value = fillna(value, 0) + # As dtype inference is impossible for object dtype, we assume float + # https://github.com/dask/dask/issues/3162 + dtype = kwargs.pop('dtype', None) + if dtype is None and value.dtype.kind == 'O': + dtype = value.dtype if value.dtype.kind in ['cf'] else float + + data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs) + data = data / (valid_count - ddof) + return where_method(data, valid_count != 0) + + +def nanmean(a, axis=None, dtype=None, out=None): + if a.dtype.kind == 'O': + return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) + + if isinstance(a, dask_array_type): + return dask_array.nanmean(a, axis=axis, dtype=dtype) + + return np.nanmean(a, axis=axis, dtype=dtype) + + +def nanmedian(a, axis=None, out=None): + return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis) + + +def _nanvar_object(value, axis=None, **kwargs): + ddof = kwargs.pop('ddof', 0) + kwargs_mean = kwargs.copy() + kwargs_mean.pop('keepdims', None) + value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis, + keepdims=True, **kwargs_mean) + squared = (value.astype(value_mean.dtype) - value_mean)**2 + return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs) + + +def nanvar(a, axis=None, dtype=None, out=None, ddof=0): + if a.dtype.kind == 'O': + return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) + + return _dask_or_eager_func('nanvar', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanstd(a, axis=None, dtype=None, out=None, ddof=0): + return _dask_or_eager_func('nanstd', eager_module=nputils)( + a, axis=axis, dtype=dtype, ddof=ddof) + + +def nanprod(a, axis=None, dtype=None, out=None, min_count=None): + a, mask = _replace_nan(a, 1) + result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out) + if min_count is not None: + return _maybe_null_out(result, axis, mask, min_count) + else: + return result + + +def nancumsum(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumsum', eager_module=nputils)( + a, axis=axis, dtype=dtype) + + +def nancumprod(a, axis=None, dtype=None, out=None): + return _dask_or_eager_func('nancumprod', eager_module=nputils)( + a, axis=axis, dtype=dtype) diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 6d4db063b98..efa68c8bad5 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +from distutils.version import LooseVersion + import numpy as np try: @@ -97,3 +99,187 @@ def isin(element, test_elements, assume_unique=False, invert=False): element = np.asarray(element) return np.in1d(element, test_elements, assume_unique=assume_unique, invert=invert).reshape(element.shape) + + +if LooseVersion(np.__version__) >= LooseVersion('1.13'): + gradient = np.gradient +else: + def normalize_axis_tuple(axes, N): + if isinstance(axes, int): + axes = (axes, ) + return tuple([N + a if a < 0 else a for a in axes]) + + def gradient(f, *varargs, **kwargs): + f = np.asanyarray(f) + N = f.ndim # number of dimensions + + axes = kwargs.pop('axis', None) + if axes is None: + axes = tuple(range(N)) + else: + axes = normalize_axis_tuple(axes, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and 
np.ndim(varargs[0]) == 0: + # single scalar for all axes + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + if np.ndim(distances) == 0: + continue + elif np.ndim(distances) != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError("when 1d, distances must match the " + "length of the corresponding dimension") + diffx = np.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + edge_order = kwargs.pop('edge_order', 1) + if kwargs: + raise TypeError('"{}" are not valid keyword arguments.'.format( + '", "'.join(kwargs.keys()))) + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype.char + if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: + otype = 'd' + + # Difference of datetime64 elements results in timedelta64 + if otype == 'M': + # Need to use the full dtype name because it contains unit + # information + otype = f.dtype.name.replace('datetime', 'timedelta') + elif otype == 'm': + # Needs to keep the specific units, can't be a general unit + otype = f.dtype + + # Convert datetime64 data into ints. Make dummy variable `y` + # that is a view of ints if the data is datetime64, otherwise + # just set y equal to the array `f`. + if f.dtype.char in ["M", "m"]: + y = f.view('int64') + else: + y = f + + for i, axis in enumerate(axes): + if y.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical " + "gradient, at least (edge_order + 1) elements are " + "required.") + # result allocation + out = np.empty_like(y, dtype=otype) + + uniform_spacing = np.ndim(dx[i]) == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[slice1] = (f[slice4] - f[slice2]) / (2. 
* dx[i]) + else: + dx1 = dx[i][0:-1] + dx2 = dx[i][1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = np.ones(N, dtype=int) + shape[axis] = -1 + a.shape = b.shape = c.shape = shape + # 1D equivalent -- + # out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4] + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = dx[i] if uniform_spacing else dx[i][0] + # 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0]) + out[slice1] = (y[slice2] - y[slice3]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = dx[i] if uniform_spacing else dx[i][-1] + # 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2]) + out[slice1] = (y[slice2] - y[slice3]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / dx[i] + b = 2. / dx[i] + c = -0.5 / dx[i] + else: + dx1 = dx[i][0] + dx2 = dx[i][1] + a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = - dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / dx[i] + b = -2. / dx[i] + c = 1.5 / dx[i] + else: + dx1 = dx[i][-2] + dx2 = dx[i][-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = - (dx2 + dx1) / (dx1 * dx2) + c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] + out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4] + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 6df2d34bfe3..a8d596abd86 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,6 +5,14 @@ import numpy as np import pandas as pd +try: + import bottleneck as bn + _USE_BOTTLENECK = True +except ImportError: + # use numpy methods instead + bn = np + _USE_BOTTLENECK = False + def _validate_axis(data, axis): ndim = data.ndim @@ -195,3 +203,36 @@ def _rolling_window(a, window, axis=-1): rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides, writeable=False) return np.swapaxes(rolling, -2, axis) + + +def _create_bottleneck_method(name, npmodule=np): + def f(values, axis=None, **kwds): + dtype = kwds.get('dtype', None) + bn_func = getattr(bn, name, None) + + if (_USE_BOTTLENECK and bn_func is not None and + not isinstance(axis, tuple) and + values.dtype.kind in 'uifc' and + values.dtype.isnative and + (dtype is None or np.dtype(dtype) == values.dtype)): + # bottleneck does not take care dtype, min_count + kwds.pop('dtype', None) + result = bn_func(values, axis=axis, **kwds) + else: + result = getattr(npmodule, name)(values, axis=axis, **kwds) + + return result + + f.__name__ = name + return f + + +nanmin = _create_bottleneck_method('nanmin') +nanmax = _create_bottleneck_method('nanmax') +nanmean = _create_bottleneck_method('nanmean') +nanmedian = _create_bottleneck_method('nanmedian') +nanvar = 
_create_bottleneck_method('nanvar') +nanstd = _create_bottleneck_method('nanstd') +nanprod = _create_bottleneck_method('nanprod') +nancumsum = _create_bottleneck_method('nancumsum') +nancumprod = _create_bottleneck_method('nancumprod') diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d9e8ceb65d5..a0dd2212a8f 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -86,7 +86,7 @@ If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64). + implemented (object, datetime64 or timedelta64).{min_count_docs} keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -102,6 +102,12 @@ indicated dimension(s) removed. """ +_MINCOUNT_DOCSTRING = """ +min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None.""" + _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this {da_or_ds}'s data windows by applying `{name}` along its dimension. @@ -236,11 +242,15 @@ def inject_reduce_methods(cls): [('count', duck_array_ops.count, False)]) for name, f, include_skipna in methods: numeric_only = getattr(f, 'numeric_only', False) + available_min_count = getattr(f, 'available_min_count', False) + min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else '' + func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format( name=name, cls=cls.__name__, - extra_args=cls._reduce_extra_args_docstring.format(name=name)) + extra_args=cls._reduce_extra_args_docstring.format(name=name), + min_count_docs=min_count_docs) setattr(cls, name, func) diff --git a/xarray/core/options.py b/xarray/core/options.py index 48d4567fc99..04ea0be7172 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,9 +1,43 @@ from __future__ import absolute_import, division, print_function +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' +CMAP_SEQUENTIAL = 'cmap_sequential' +CMAP_DIVERGENT = 'cmap_divergent' + OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', - 'enable_cftimeindex': False + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: False, + FILE_CACHE_MAXSIZE: 128, + CMAP_SEQUENTIAL: 'viridis', + CMAP_DIVERGENT: 'RdBu_r', +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, +} + + +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +_SETTERS = { + FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, } @@ -19,8 +53,18 @@ class set_options(object): - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` for time indexes with non-standard calendars or dates outside the Timestamp-valid range. Default: ``False``. 
+ - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 128. + - ``cmap_sequential``: colormap to use for nondivergent data plots. + Default: ``viridis``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) + - ``cmap_divergent``: colormap to use for divergent data plots. + Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. + Can also be a Colormap object (e.g. mpl.cm.magma) - You can use ``set_options`` either as a context manager: +f You can use ``set_options`` either as a context manager: >>> ds = xr.Dataset({'x': np.arange(1000)}) >>> with xr.set_options(display_width=40): @@ -38,16 +82,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 78c26f1e92f..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -28,6 +28,9 @@ def itervalues(d): import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -50,6 +53,11 @@ def itervalues(d): from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -76,7 +84,6 @@ def itervalues(d): except ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 4933a09b257..bd84e04487e 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function from . import ops -from .groupby import DataArrayGroupBy, DatasetGroupBy +from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy from .pycompat import OrderedDict, dask_array_type RESAMPLE_DIM = '__resample_dim__' @@ -277,15 +277,14 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. - Note that `dim` is by default here and ignored if passed by the user; - this ensures compatibility with the existing reduce interface. 
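A brief usage sketch (not from the patch) of the options machinery added above -- option names and values are validated up front, and ``file_cache_maxsize`` is applied immediately through its ``_SETTERS`` hook:

import xarray as xr

with xr.set_options(file_cache_maxsize=512, cmap_sequential='magma'):
    ...  # work with the temporary settings; they are restored on exit

# Bad names or values are rejected before anything is applied:
# xr.set_options(file_cache_maxsize=0)  -> ValueError (must be a positive int)
# xr.set_options(no_such_option=True)   -> ValueError (unknown option name)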
- Parameters ---------- func : function Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. + dim : str or sequence of str, optional + Dimension(s) over which to apply `func`. keep_attrs : bool, optional If True, the datasets's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -299,8 +298,11 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs): Array with summarized data and the indicated dimension(s) removed. """ + if dim == DEFAULT_DIMS: + dim = None + return super(DatasetResample, self).reduce( - func, self._dim, keep_attrs, **kwargs) + func, dim, keep_attrs, **kwargs) def _interpolate(self, kind='linear'): """Apply scipy.interpolate.interp1d along resampling dimension.""" diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 24ed280b19e..883dbb34dff 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -44,7 +44,7 @@ class Rolling(object): _attributes = ['window', 'min_periods', 'center', 'dim'] - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -52,18 +52,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset or DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -115,7 +115,7 @@ def __len__(self): class DataArrayRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -125,18 +125,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : DataArray Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. 
Returns ------- @@ -149,8 +149,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.rolling Dataset.groupby """ - super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, - center=center, **windows) + super(DataArrayRolling, self).__init__( + obj, windows, min_periods=min_periods, center=center) self.window_labels = self.obj[self.dim] @@ -321,7 +321,7 @@ def wrapped_func(self, **kwargs): class DatasetRolling(Rolling): - def __init__(self, obj, min_periods=None, center=False, **windows): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -331,18 +331,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows): ---------- obj : Dataset Object to window. + windows : A mapping from a dimension name to window size + dim : str + Name of the dimension to create the rolling iterator + along (e.g., `time`). + window : int + Size of the moving window. min_periods : int, default None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. center : boolean, default False Set the labels at the center of the window. - **windows : dim=window - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. Returns ------- @@ -355,8 +355,7 @@ def __init__(self, obj, min_periods=None, center=False, **windows): Dataset.groupby DataArray.groupby """ - super(DatasetRolling, self).__init__(obj, - min_periods, center, **windows) + super(DatasetRolling, self).__init__(obj, windows, min_periods, center) if self.dim not in self.obj.dims: raise KeyError(self.dim) # Keep each Rolling object as an OrderedDict @@ -364,8 +363,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows): for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim if self.dim in da.dims: - self.rollings[key] = DataArrayRolling(da, min_periods, - center, **windows) + self.rollings[key] = DataArrayRolling( + da, windows, min_periods, center) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some diff --git a/xarray/core/utils.py b/xarray/core/utils.py index c3bb747fac5..c39a07e1b5a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -591,3 +591,29 @@ def __iter__(self): def __len__(self): num_hidden = sum([k in self._hidden_keys for k in self._data]) return len(self._data) - num_hidden + + +def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): + """Convert an array containing datetime-like data to an array of floats. 
+ + Parameters + ---------- + da : array + Input data + offset: Scalar with the same type of array or None + If None, subtract minimum values to reduce round off error + datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', + 'us', 'ns', 'ps', 'fs', 'as'} + dtype: target dtype + + Returns + ------- + array + """ + if offset is None: + offset = array.min() + array = array - offset + + if datetime_unit: + return (array / np.timedelta64(1, datetime_unit)).astype(dtype) + return array.astype(dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d9772407b82..c003d52aab2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -64,22 +64,15 @@ def as_variable(obj, name=None): The newly created variable. """ + from .dataarray import DataArray + # TODO: consider extending this method to automatically handle Iris and - # pandas objects. - if hasattr(obj, 'variable'): + if isinstance(obj, DataArray): # extract the primary Variable from DataArrays obj = obj.variable if isinstance(obj, Variable): obj = obj.copy(deep=False) - elif hasattr(obj, 'dims') and (hasattr(obj, 'data') or - hasattr(obj, 'values')): - obj_data = getattr(obj, 'data', None) - if obj_data is None: - obj_data = getattr(obj, 'values') - obj = Variable(obj.dims, obj_data, - getattr(obj, 'attrs', None), - getattr(obj, 'encoding', None)) elif isinstance(obj, tuple): try: obj = Variable(*obj) @@ -728,24 +721,81 @@ def encoding(self, value): except ValueError: raise ValueError('encoding must be castable to a dictionary') - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. If `deep=True`, the data array is loaded into memory and copied onto the new object. Dimensions, attributes and encodings are always copied. - """ - data = self._data - if isinstance(data, indexing.MemoryCachedArray): - # don't share caching between copies - data = indexing.MemoryCachedArray(data.array) + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. - if deep: - if isinstance(data, dask_array_type): - data = data.copy() - elif not isinstance(data, PandasIndexAdapter): - # pandas.Index is immutable - data = np.array(data) + Examples + -------- + + Shallow copy versus deep copy + + >>> var = xr.Variable(data=[1, 2, 3], dims='x') + >>> var.copy() + + array([1, 2, 3]) + >>> var_0 = var.copy(deep=False) + >>> var_0[0] = 7 + >>> var_0 + + array([7, 2, 3]) + >>> var + + array([7, 2, 3]) + + Changing the data using the ``data`` argument maintains the + structure of the original object, but with the new data. Original + object is unaffected. 
+ + >>> var.copy(data=[0.1, 0.2, 0.3]) + + array([ 0.1, 0.2, 0.3]) + >>> var + + array([7, 2, 3]) + + See Also + -------- + pandas.DataFrame.copy + """ + if data is None: + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep: + if isinstance(data, dask_array_type): + data = data.copy() + elif not isinstance(data, PandasIndexAdapter): + # pandas.Index is immutable + data = np.array(data) + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) # note: # dims is already an immutable tuple @@ -877,7 +927,7 @@ def squeeze(self, dim=None): numpy.squeeze """ dims = common.get_squeeze_dims(self, dim) - return self.isel(**{d: 0 for d in dims}) + return self.isel({d: 0 for d in dims}) def _shift_one_dim(self, dim, count): axis = self.get_axis_num(dim) @@ -919,36 +969,46 @@ def _shift_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def shift(self, **shifts): + def shift(self, shifts=None, **shifts_kwargs): """ Return a new Variable with shifted data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but shifted data. """ + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift') result = self for dim, count in shifts.items(): result = result._shift_one_dim(dim, count) return result - def pad_with_fill_value(self, fill_value=dtypes.NA, **pad_widths): + def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA, + **pad_widths_kwargs): """ Return a new Variable with paddings. Parameters ---------- - **pad_width: keyword arguments of the form {dim: (before, after)} + pad_width: Mapping of the form {dim: (before, after)} Number of values padded to the edges of each dimension. + **pad_widths_kwargs: + Keyword argument for pad_widths """ + pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs, + 'pad') + if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) else: @@ -1009,22 +1069,27 @@ def _roll_one_dim(self, dim, count): return type(self)(self.dims, data, self._attrs, fastpath=True) - def roll(self, **shifts): + def roll(self, shifts=None, **shifts_kwargs): """ Return a new Variable with rolld data. Parameters ---------- - **shifts : keyword arguments of the form {dim: offset} + shifts : mapping of the form {dim: offset} Integer offset to roll along each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. + **shifts_kwargs: + The keyword arguments form of ``shifts``. + One of shifts or shifts_kwarg must be provided. Returns ------- shifted : Variable Variable with the same dimensions and attributes but rolled data. 
""" + shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll') + result = self for dim, count in shifts.items(): result = result._roll_one_dim(dim, count) @@ -1142,7 +1207,7 @@ def _stack_once(self, dims, new_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def stack(self, **dimensions): + def stack(self, dimensions=None, **dimensions_kwargs): """ Stack any number of existing dimensions into a single new dimension. @@ -1151,9 +1216,12 @@ def stack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...) + dimensions : Mapping of form new_name=(dim1, dim2, ...) Names of new dimensions, and the existing dimensions that they replace. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1164,6 +1232,8 @@ def stack(self, **dimensions): -------- Variable.unstack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'stack') result = self for new_dim, dims in dimensions.items(): result = result._stack_once(dims, new_dim) @@ -1195,7 +1265,7 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) - def unstack(self, **dimensions): + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. @@ -1204,9 +1274,12 @@ def unstack(self, **dimensions): Parameters ---------- - **dimensions : keyword arguments of the form old_dim={dim1: size1, ...} + dimensions : mapping of the form old_dim={dim1: size1, ...} Names of existing dimensions, and the new dimensions and sizes that they map to. + **dimensions_kwargs: + The keyword arguments form of ``dimensions``. + One of dimensions or dimensions_kwargs must be provided. Returns ------- @@ -1217,6 +1290,8 @@ def unstack(self, **dimensions): -------- Variable.stack """ + dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, + 'unstack') result = self for old_dim, dims in dimensions.items(): result = result._unstack_once(dims, old_dim) @@ -1258,6 +1333,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, Array with summarized data and the indicated dimension(s) removed. """ + if dim is common.ALL_DIMS: + dim = None if dim is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") @@ -1691,14 +1768,37 @@ def concat(cls, variables, dim='concat_dim', positions=None, return cls(first_var.dims, data, attrs) - def copy(self, deep=True): + def copy(self, deep=True, data=None): """Returns a copy of this object. - `deep` is ignored since data is stored in the form of pandas.Index, - which is already immutable. Dimensions, attributes and encodings are - always copied. + `deep` is ignored since data is stored in the form of + pandas.Index, which is already immutable. Dimensions, attributes + and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, optional + Deep is always ignored. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + + Returns + ------- + object : Variable + New object with dimensions, attributes, encodings, and optionally + data copied from original. 
""" - return type(self)(self.dims, self._data, self._attrs, + if data is None: + data = self._data + else: + data = as_compatible_data(data) + if self.shape != data.shape: + raise ValueError("Data shape {} must match shape of object {}" + .format(data.shape, self.shape)) + return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) def equals(self, other, equiv=None): diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index a0d7c4dd5e2..32a954a3fcd 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -5,6 +5,7 @@ import warnings import numpy as np + from ..core.formatting import format_item from ..core.pycompat import getargspec from .utils import ( @@ -188,6 +189,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None, self._y_var = None self._cmap_extend = None self._mappables = [] + self._finalized = False @property def _left_axes(self): @@ -308,13 +310,16 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" - self.set_axis_labels(*axlabels) - self.set_titles() - self.fig.tight_layout() + if not self._finalized: + self.set_axis_labels(*axlabels) + self.set_titles() + self.fig.tight_layout() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): - if namedict is None: - ax.set_visible(False) + for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): + if namedict is None: + ax.set_visible(False) + + self._finalized = True def add_legend(self, **kwargs): figlegend = self.fig.legend( @@ -502,9 +507,12 @@ def map(self, func, *args, **kwargs): data = self.data.loc[namedict] plt.sca(ax) innerargs = [data[a].values for a in args] - # TODO: is it possible to verify that an artist is mappable? - mappable = func(*innerargs, **kwargs) - self._mappables.append(mappable) + maybe_mappable = func(*innerargs, **kwargs) + # TODO: better way to verify that an artist is mappable? + # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if (maybe_mappable and + hasattr(maybe_mappable, 'autoscale_None')): + self._mappables.append(maybe_mappable) self._finalize_grid(*args[:2]) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3e7e5909a70..52be985153a 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -600,9 +600,11 @@ def step(self, *args, **kwargs): def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None + # TODO: remove when min numpy version is bumped to 1.13 # There's a cyclic dependency via DataArray, so we can't import from # xarray.ufuncs in global scope. from xarray.ufuncs import maximum, minimum + # Calculate vmin and vmax automatically for `robust=True` if robust: if vmax is None: @@ -628,7 +630,10 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. 
darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4') - return minimum(maximum(darray, 0), 1) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'xarray.ufuncs', + PendingDeprecationWarning) + return minimum(maximum(darray, 0), 1) def _plot2d(plotfunc): @@ -678,6 +683,9 @@ def _plot2d(plotfunc): Adds colorbar to axis add_labels : Boolean, optional Use xarray metadata to label axes + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -746,7 +754,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, colors=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -847,6 +855,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, 'extend': extend, 'levels': levels, 'filled': plotfunc.__name__ != 'contour', + 'norm': norm, } cmap_params = _determine_cmap_params(**cmap_kwargs) @@ -857,13 +866,15 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # pcolormesh kwargs['extend'] = cmap_params['extend'] kwargs['levels'] = cmap_params['levels'] + # if colors == a single color, matplotlib draws dashed negative + # contours. we lose this feature if we pass cmap and not colors + if isinstance(colors, basestring): + cmap_params['cmap'] = None + kwargs['colors'] = colors if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals - # This allows the user to pass in a custom norm coming via kwargs - kwargs.setdefault('norm', cmap_params['norm']) - if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available " @@ -873,6 +884,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], + norm=cmap_params['norm'], **kwargs) # Label the plot with metadata @@ -890,12 +902,16 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cbar_kwargs.setdefault('cax', cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) if add_labels and 'label' not in cbar_kwargs: - cbar.set_label(label_from_attrs(darray), rotation=90) + cbar.set_label(label_from_attrs(darray)) elif cbar_ax is not None or cbar_kwargs is not None: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False.") + # origin kwarg overrides yincrease + if 'origin' in kwargs: + yincrease = None + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -920,7 +936,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): """ The method should have the same signature as the function. 
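An illustrative sketch (not part of the patch) of the new ``norm`` keyword threaded through the 2-D plot methods above; the array here is invented:

import matplotlib.colors as mcolors
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(10, 20) * 99 + 1, dims=('y', 'x'))

# A norm that carries its own vmin/vmax is passed through to matplotlib;
# supplying a conflicting vmin or vmax alongside it raises ValueError.
da.plot.pcolormesh(norm=mcolors.LogNorm(vmin=1, vmax=100))

# A BoundaryNorm also defines the levels, giving a discrete colorbar.
da.plot.imshow(norm=mcolors.BoundaryNorm([1, 10, 50, 100], ncolors=256))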
@@ -982,10 +998,8 @@ def imshow(x, y, z, ax, **kwargs): left, right = x[0] - xstep, x[-1] + xstep bottom, top = y[-1] + ystep, y[0] - ystep - defaults = {'extent': [left, right, bottom, top], - 'origin': 'upper', - 'interpolation': 'nearest', - } + defaults = {'origin': 'upper', + 'interpolation': 'nearest'} if not hasattr(ax, 'projection'): # not for cartopy geoaxes @@ -994,6 +1008,11 @@ def imshow(x, y, z, ax, **kwargs): # Allow user to override these defaults defaults.update(kwargs) + if defaults['origin'] == 'upper': + defaults['extent'] = [left, right, bottom, top] + else: + defaults['extent'] = [left, right, top, bottom] + if z.ndim == 3: # matplotlib imshow uses black for missing data, but Xarray makes # missing data transparent. We therefore add an alpha channel if diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 4a09b66ca33..f39c989a514 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1,10 +1,11 @@ from __future__ import absolute_import, division, print_function +import textwrap import warnings import numpy as np -import textwrap +from ..core.options import OPTIONS from ..core.pycompat import basestring from ..core.utils import is_scalar @@ -171,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # vlim might be computed below vlim = None + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + if vmin is None: if robust: vmin = np.percentile(calc_data, ROBUST_PERCENTILE) @@ -203,18 +208,42 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, vmin += center vmax += center + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError('Cannot supply vmin and a norm' + + ' with a different vmin.') + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError('Cannot supply vmax and a norm' + + ' with a different vmax.') + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + # Choose default colormaps if not provided if cmap is None: if divergent: - cmap = "RdBu_r" + cmap = OPTIONS['cmap_divergent'] else: - cmap = "viridis" + cmap = OPTIONS['cmap_sequential'] # Handle discrete levels - if levels is not None: + if levels is not None and norm is None: if is_scalar(levels): - if user_minmax or levels == 1: + if user_minmax: levels = np.linspace(vmin, vmax, levels) + elif levels == 1: + levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks ticker = mpl.ticker.MaxNLocator(levels - 1) @@ -224,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if levels is not None: - cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled) + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 33a8da6bbfb..285c1f03a26 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -9,11 +9,10 @@ import numpy as np from numpy.testing import assert_array_equal # noqa: F401 -from 
xarray.core.duck_array_ops import allclose_or_equiv +from xarray.core.duck_array_ops import allclose_or_equiv # noqa import pytest from xarray.core import utils -from xarray.core.pycompat import PY3 from xarray.core.indexing import ExplicitlyIndexed from xarray.testing import (assert_equal, assert_identical, # noqa: F401 assert_allclose) @@ -25,10 +24,6 @@ # old location, for pandas < 0.20 from pandas.util.testing import assert_frame_equal # noqa: F401 -try: - import unittest2 as unittest -except ImportError: - import unittest try: from unittest import mock @@ -117,39 +112,6 @@ def _importorskip(modname, minversion=None): "internet connection") -class TestCase(unittest.TestCase): - """ - These functions are all deprecated. Instead, use functions in xr.testing - """ - if PY3: - # Python 3 assertCountEqual is roughly equivalent to Python 2 - # assertItemsEqual - def assertItemsEqual(self, first, second, msg=None): - __tracebackhide__ = True # noqa: F841 - return self.assertCountEqual(first, second, msg) - - @contextmanager - def assertWarns(self, message): - __tracebackhide__ = True # noqa: F841 - with warnings.catch_warnings(record=True) as w: - warnings.filterwarnings('always', message) - yield - assert len(w) > 0 - assert any(message in str(wi.message) for wi in w) - - def assertVariableNotEqual(self, v1, v2): - __tracebackhide__ = True # noqa: F841 - assert not v1.equals(v2) - - def assertEqual(self, a1, a2): - __tracebackhide__ = True # noqa: F841 - assert a1 == a2 or (a1 != a1 and a2 != a2) - - def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8): - __tracebackhide__ = True # noqa: F841 - assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol) - - @contextmanager def raises_regex(error, pattern): __tracebackhide__ = True # noqa: F841 diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py index e1b3a95b942..38038fc8f65 100644 --- a/xarray/tests/test_accessors.py +++ b/xarray/tests/test_accessors.py @@ -7,12 +7,13 @@ import xarray as xr from . import ( - TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask, - has_cftime, has_dask, has_cftime_or_netCDF4) + assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4, + has_dask, raises_regex, requires_dask) -class TestDatetimeAccessor(TestCase): - def setUp(self): +class TestDatetimeAccessor(object): + @pytest.fixture(autouse=True) + def setup(self): nt = 100 data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e6de50b9dd2..43811942d5f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2,12 +2,12 @@ import contextlib import itertools +import math import os.path import pickle import shutil import sys import tempfile -import unittest import warnings from io import BytesIO @@ -19,22 +19,21 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import (robust_getitem, - PickleByReconstructionWrapper) +from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing from xarray.core.pycompat import ( - PY2, ExitStack, basestring, dask_array_type, iteritems) + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options from xarray.tests import mock from . 
import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, + assert_allclose, assert_array_equal, assert_equal, assert_identical, + has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cftime, requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, - requires_pydap, requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr, requires_pseudonetcdf, - requires_cftime) + requires_pseudonetcdf, requires_pydap, requires_pynio, requires_rasterio, + requires_scipy, requires_scipy_or_netCDF4, requires_zarr) from .test_dataset import create_test_data try: @@ -106,7 +105,7 @@ def create_boolean_data(): return Dataset({'x': ('t', [True, False, False, True], attributes)}) -class TestCommon(TestCase): +class TestCommon(object): def test_robust_getitem(self): class UnreliableArrayFailure(Exception): @@ -126,11 +125,11 @@ def __getitem__(self, key): array = UnreliableArray([0]) with pytest.raises(UnreliableArrayFailure): array[0] - self.assertEqual(array[0], 0) + assert array[0] == 0 actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0) - self.assertEqual(actual, 0) + assert actual == 0 class NetCDF3Only(object): @@ -138,7 +137,6 @@ class NetCDF3Only(object): class DatasetIOTestCases(object): - autoclose = False engine = None file_format = None @@ -172,8 +170,7 @@ def save(self, dataset, path, **kwargs): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -222,11 +219,11 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) with pytest.raises(AssertionError): @@ -252,14 +249,14 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) for v in computed.variables.values(): - self.assertTrue(v._in_memory) + assert v._in_memory assert_identical(expected, actual) assert_identical(expected, computed) @@ -343,12 +340,12 @@ def test_roundtrip_string_encoded_characters(self): expected['x'].encoding['dtype'] = 'S1' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'utf-8') + assert actual['x'].encoding['_Encoding'] == 'utf-8' expected['x'].encoding['_Encoding'] = 'ascii' with self.roundtrip(expected) as actual: assert_identical(expected, actual) - self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii') + assert actual['x'].encoding['_Encoding'] == 'ascii' def test_roundtrip_numpy_datetime_data(self): times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) @@ -434,10 +431,10 @@ def test_roundtrip_coordinates_with_space(self): def test_roundtrip_boolean_dtype(self): 
original = create_boolean_data() - self.assertEqual(original['x'].dtype, 'bool') + assert original['x'].dtype == 'bool' with self.roundtrip(original) as actual: assert_identical(original, actual) - self.assertEqual(actual['x'].dtype, 'bool') + assert actual['x'].dtype == 'bool' def test_orthogonal_indexing(self): in_memory = create_test_data() @@ -626,20 +623,20 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded = create_encoded_unsigned_masked_scaled_data() with self.roundtrip(decoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert (decoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert (encoded.variables[k].dtype == + actual.variables[k].dtype) assert_allclose(encoded, actual, decode_bytes=False) # make sure roundtrip encoding didn't change the # original dataset. @@ -647,14 +644,14 @@ def test_unsigned_roundtrip_mask_and_scale(self): encoded, create_encoded_unsigned_masked_scaled_data()) with self.roundtrip(encoded) as actual: for k in decoded.variables: - self.assertEqual(decoded.variables[k].dtype, - actual.variables[k].dtype) + assert decoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual: for k in encoded.variables: - self.assertEqual(encoded.variables[k].dtype, - actual.variables[k].dtype) + assert encoded.variables[k].dtype == \ + actual.variables[k].dtype assert_allclose(encoded, actual, decode_bytes=False) def test_roundtrip_mask_and_scale(self): @@ -692,12 +689,11 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: original.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds['temp'].attrs['coordinates'])) - self.assertTrue( - equals_latlon(ds['precip'].attrs['coordinates'])) - self.assertNotIn('coordinates', ds.attrs) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds['temp'].attrs['coordinates']) + assert equals_latlon(ds['precip'].attrs['coordinates']) + assert 'coordinates' not in ds.attrs + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs modified = original.drop(['temp', 'precip']) with self.roundtrip(modified) as actual: @@ -705,9 +701,9 @@ def equals_latlon(obj): with create_tmp_file() as tmp_file: modified.to_netcdf(tmp_file) with open_dataset(tmp_file, decode_coords=False) as ds: - self.assertTrue(equals_latlon(ds.attrs['coordinates'])) - self.assertNotIn('coordinates', ds['lat'].attrs) - self.assertNotIn('coordinates', ds['lon'].attrs) + assert equals_latlon(ds.attrs['coordinates']) + assert 'coordinates' not in ds['lat'].attrs + assert 'coordinates' not in ds['lon'].attrs def test_roundtrip_endian(self): ds = Dataset({'x': np.arange(3, 10, dtype='>i2'), @@ -743,8 +739,8 @@ def test_encoding_kwarg(self): ds = Dataset({'x': ('y', 
np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} kwargs = dict(encoding={'x': {'foo': 'bar'}}) with raises_regex(ValueError, 'unexpected encoding'): @@ -766,7 +762,7 @@ def test_encoding_kwarg_dates(self): units = 'days since 1900-01-01' kwargs = dict(encoding={'t': {'units': units}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.t.encoding['units'], units) + assert actual.t.encoding['units'] == units assert_identical(actual, ds) def test_encoding_kwarg_fixed_width_string(self): @@ -778,7 +774,7 @@ def test_encoding_kwarg_fixed_width_string(self): ds = Dataset({'x': strings}) kwargs = dict(encoding={'x': {'dtype': 'S1'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual['x'].encoding['dtype'], 'S1') + assert actual['x'].encoding['dtype'] == 'S1' assert_identical(actual, ds) def test_default_fill_value(self): @@ -786,9 +782,8 @@ def test_default_fill_value(self): ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['_FillValue'], - np.nan) - self.assertEqual(ds.x.encoding, {}) + assert math.isnan(actual.x.encoding['_FillValue']) + assert ds.x.encoding == {} # Test default encoding for int: ds = Dataset({'x': ('y', np.arange(10.0))}) @@ -797,14 +792,14 @@ def test_default_fill_value(self): warnings.filterwarnings( 'ignore', '.*floating point data as an integer') with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} # Test default encoding for implicit int: ds = Dataset({'x': ('y', np.arange(10, dtype='int16'))}) with self.roundtrip(ds) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) - self.assertEqual(ds.x.encoding, {}) + assert '_FillValue' not in actual.x.encoding + assert ds.x.encoding == {} def test_explicitly_omit_fill_value(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}) @@ -817,7 +812,7 @@ def test_explicitly_omit_fill_value_via_encoding_kwarg(self): kwargs = dict(encoding={'x': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.x.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_explicitly_omit_fill_value_in_coord(self): ds = Dataset({'x': ('y', [np.pi, -np.pi])}, coords={'y': [0.0, 1.0]}) @@ -830,14 +825,14 @@ def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self): kwargs = dict(encoding={'y': {'_FillValue': None}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert '_FillValue' not in actual.y.encoding - self.assertEqual(ds.y.encoding, {}) + assert ds.y.encoding == {} def test_encoding_same_dtype(self): ds = Dataset({'x': ('y', np.arange(10.0, dtype='f4'))}) kwargs = dict(encoding={'x': {'dtype': 'f4'}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(ds.x.encoding, {}) + assert actual.x.encoding['dtype'] == 'f4' + assert ds.x.encoding == {} def test_append_write(self): # regression for GH1215 @@ -1015,7 +1010,7 @@ def test_default_to_char_arrays(self): data = 
Dataset({'x': np.array(['foo', 'zzzz'], dtype='S')}) with self.roundtrip(data) as actual: assert_identical(data, actual) - self.assertEqual(actual['x'].dtype, np.dtype('S4')) + assert actual['x'].dtype == np.dtype('S4') def test_open_encodings(self): # Create a netCDF file with explicit time units @@ -1040,15 +1035,15 @@ def test_open_encodings(self): actual_encoding = dict((k, v) for k, v in iteritems(actual['time'].encoding) if k in expected['time'].encoding) - self.assertDictEqual(actual_encoding, - expected['time'].encoding) + assert actual_encoding == \ + expected['time'].encoding def test_dump_encodings(self): # regression test for #709 ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'zlib': True}}) with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue(actual.x.encoding['zlib']) + assert actual.x.encoding['zlib'] def test_dump_and_open_encodings(self): # Create a netCDF file with explicit time units @@ -1066,8 +1061,7 @@ def test_dump_and_open_encodings(self): with create_tmp_file() as tmp_file2: xarray_dataset.to_netcdf(tmp_file2) with nc4.Dataset(tmp_file2, 'r') as ds: - self.assertEqual( - ds.variables['time'].getncattr('units'), units) + assert ds.variables['time'].getncattr('units') == units assert_array_equal( ds.variables['time'], np.arange(10) + 4) @@ -1080,7 +1074,7 @@ def test_compression_encoding(self): 'original_shape': data.var2.shape}) with self.roundtrip(data) as actual: for k, v in iteritems(data['var2'].encoding): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] # regression test for #156 expected = data.isel(dim1=0) @@ -1095,14 +1089,14 @@ def test_encoding_kwarg_compression(self): with self.roundtrip(ds, save_kwargs=kwargs) as actual: assert_equal(actual, ds) - self.assertEqual(actual.x.encoding['dtype'], 'f4') - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) - self.assertEqual(actual.x.encoding['fletcher32'], True) - self.assertEqual(actual.x.encoding['chunksizes'], (5,)) - self.assertEqual(actual.x.encoding['shuffle'], True) + assert actual.x.encoding['dtype'] == 'f4' + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 + assert actual.x.encoding['fletcher32'] + assert actual.x.encoding['chunksizes'] == (5,) + assert actual.x.encoding['shuffle'] - self.assertEqual(ds.x.encoding, {}) + assert ds.x.encoding == {} def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 @@ -1162,10 +1156,10 @@ def test_already_open_dataset(self): v[...] 
= 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: @@ -1183,8 +1177,7 @@ def test_read_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): - autoclose = False +class NetCDF4DataTest(BaseNetCDF4Test): @contextlib.contextmanager def create_store(self): @@ -1201,7 +1194,7 @@ def test_variable_order(self): ds.coords['c'] = 4 with self.roundtrip(ds) as actual: - self.assertEqual(list(ds.variables), list(actual.variables)) + assert list(ds.variables) == list(actual.variables) def test_unsorted_index_raises(self): # should be fixed in netcdf4 v1.2.1 @@ -1220,7 +1213,7 @@ def test_unsorted_index_raises(self): try: ds2.randovar.values except IndexError as err: - self.assertIn('first by calling .load', str(err)) + assert 'first by calling .load' in str(err) def test_88_character_filename_segmentation_fault(self): # should be fixed in netcdf4 v1.3.1 @@ -1250,9 +1243,13 @@ def test_setncattr_string(self): totest.attrs['bar']) assert one_string == totest.attrs['baz'] - -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @@ -1293,10 +1290,6 @@ def test_write_inconsistent_chunks(self): assert actual['y'].encoding['chunksizes'] == (100, 50) -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True - - @requires_zarr class BaseZarrTest(CFEncodedDataTest): @@ -1335,17 +1328,17 @@ def test_auto_chunk(self): original, open_kwargs={'auto_chunk': False}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # there should be no chunks - self.assertEqual(v.chunks, None) + assert v.chunks is None with self.roundtrip( original, open_kwargs={'auto_chunk': True}) as actual: for k, v in actual.variables.items(): # only index variables should be in memory - self.assertEqual(v._in_memory, k in actual.dims) + assert v._in_memory == (k in actual.dims) # chunk size should be the same as original - self.assertEqual(v.chunks, original[k].chunks) + assert v.chunks == original[k].chunks def test_write_uneven_dask_chunks(self): # regression for GH#2225 @@ -1365,7 +1358,7 @@ def test_chunk_encoding(self): data['var2'].encoding.update({'chunks': chunks}) with self.roundtrip(data) as actual: - self.assertEqual(chunks, actual['var2'].encoding['chunks']) + assert chunks == actual['var2'].encoding['chunks'] # expect an error with non-integer chunks data['var2'].encoding.update({'chunks': (5, 4.5)}) @@ -1382,7 +1375,7 @@ def test_chunk_encoding_with_dask(self): # zarr automatically gets chunk information from dask chunks ds_chunk4 = ds.chunk({'x': 4}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) + assert (4,) == actual['var1'].encoding['chunks'] # should fail if dask_chunks are irregular... 
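# zarr requires a uniform chunk shape, so a dask chunking such as
# (5, 4, 3), where an interior chunk differs from the first, cannot be
# mapped onto zarr chunks and is expected to fail; only the final chunk
# may be smaller, which is why the (5, 5, 2) case further below
# round-trips successfully.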
ds_chunk_irreg = ds.chunk({'x': (5, 4, 3)}) @@ -1395,15 +1388,14 @@ def test_chunk_encoding_with_dask(self): # ... except if the last chunk is smaller than the first ds_chunk_irreg = ds.chunk({'x': (5, 5, 2)}) with self.roundtrip(ds_chunk_irreg) as actual: - self.assertEqual((5,), actual['var1'].encoding['chunks']) + assert (5,) == actual['var1'].encoding['chunks'] # - encoding specified - # specify compatible encodings for chunk_enc in 4, (4, ): ds_chunk4['var1'].encoding.update({'chunks': chunk_enc}) with self.roundtrip(ds_chunk4) as actual: - self.assertEqual((4,), actual['var1'].encoding['chunks']) - + assert (4,) == actual['var1'].encoding['chunks'] # TODO: remove this failure once syncronized overlapping writes are # supported by xarray @@ -1533,14 +1525,14 @@ def test_encoding_chunksizes(self): @requires_zarr -class ZarrDictStoreTest(BaseZarrTest, TestCase): +class ZarrDictStoreTest(BaseZarrTest): @contextlib.contextmanager def create_zarr_target(self): yield {} @requires_zarr -class ZarrDirectoryStoreTest(BaseZarrTest, TestCase): +class ZarrDirectoryStoreTest(BaseZarrTest): @contextlib.contextmanager def create_zarr_target(self): with create_tmp_file(suffix='.zarr') as tmp: @@ -1563,7 +1555,7 @@ def test_append_overwrite_values(self): @requires_scipy -class ScipyInMemoryDataTest(ScipyWriteTest, TestCase): +class ScipyInMemoryDataTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1575,21 +1567,16 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - fobj = BytesIO(data.to_netcdf()) - with open_dataset(fobj, autoclose=self.autoclose) as ds: + fobj = data.to_netcdf() + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - @requires_scipy -class ScipyFileObjectTest(ScipyWriteTest, TestCase): +class ScipyFileObjectTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1617,7 +1604,7 @@ def test_pickle_dataarray(self): @requires_scipy -class ScipyFilePathTest(ScipyWriteTest, TestCase): +class ScipyFilePathTest(ScipyWriteTest): engine = 'scipy' @contextlib.contextmanager @@ -1641,7 +1628,7 @@ def test_netcdf3_endianness(self): # regression test for GH416 expected = open_example_dataset('bears.nc', engine='scipy') for var in expected.variables.values(): - self.assertTrue(var.dtype.isnative) + assert var.dtype.isnative @requires_netCDF4 def test_nc4_scipy(self): @@ -1653,12 +1640,8 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only): engine = 'netcdf4' file_format = 'NETCDF3_CLASSIC' @@ -1677,13 +1660,9 @@ def test_encoding_kwarg_vlen_string(self): pass -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True - - @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, - TestCase): + object): engine = 'netcdf4' file_format = 'NETCDF4_CLASSIC' @@ -1695,13 +1674,8 @@ def create_store(self): yield store -class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - 
NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): +class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed file_format = 'netcdf3_64bit' @@ -1755,34 +1729,30 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 with self.roundtrip(ds, save_kwargs=dict(unlimited_dims='y')) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) # Regression test for https://github.com/pydata/xarray/issues/2134 ds.encoding = {'unlimited_dims': 'y'} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True - - @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class H5NetCDFDataTest(BaseNetCDF4Test): engine = 'h5netcdf' @contextlib.contextmanager @@ -1790,10 +1760,14 @@ def create_store(self): with create_tmp_file() as tmp_file: yield backends.H5NetCDFStore(tmp_file, 'w') + @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py') def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): @@ -1822,11 +1796,11 @@ def test_encoding_unlimited_dims(self): ds = Dataset({'x': ('y', np.arange(10.0))}) with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=['y'])) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: - self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert actual.encoding['unlimited_dims'] == set('y') assert_equal(ds, actual) def test_compression_encoding_h5py(self): @@ -1857,7 +1831,7 @@ def test_compression_encoding_h5py(self): compr_out.update(compr_common) with self.roundtrip(data) as actual: for k, v in compr_out.items(): - self.assertEqual(v, actual['var2'].encoding[k]) + assert v == actual['var2'].encoding[k] def test_compression_check_encoding_h5py(self): """When mismatched h5py and NetCDF4-Python encodings are expressed @@ -1898,20 +1872,14 @@ def test_dump_encodings_h5py(self): kwargs = {'encoding': {'x': { 'compression': 'gzip', 'compression_opts': 9}}} with self.roundtrip(ds, 
save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['zlib'], True) - self.assertEqual(actual.x.encoding['complevel'], 9) + assert actual.x.encoding['zlib'] + assert actual.x.encoding['complevel'] == 9 kwargs = {'encoding': {'x': { 'compression': 'lzf', 'compression_opts': None}}} with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertEqual(actual.x.encoding['compression'], 'lzf') - self.assertEqual(actual.x.encoding['compression_opts'], None) - - -# tests pending h5netcdf fix -@unittest.skip -class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True + assert actual.x.encoding['compression'] == 'lzf' + assert actual.x.encoding['compression_opts'] is None @pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) @@ -1919,14 +1887,19 @@ def readengine(request): return request.param -@pytest.fixture(params=[1, 100]) +@pytest.fixture(params=[1, 20]) def nfiles(request): return request.param -@pytest.fixture(params=[True, False]) -def autoclose(request): - return request.param +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize @pytest.fixture(params=[True, False]) @@ -1949,8 +1922,8 @@ def skip_if_not_engine(engine): pytest.importorskip(engine) -def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, - chunks): +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): # skip certain combinations skip_if_not_engine(readengine) @@ -1958,9 +1931,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, if not has_dask and parallel: pytest.skip('parallel requires dask') - if readengine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose yet') - if ON_WINDOWS: pytest.skip('Skipping on Windows') @@ -1976,7 +1946,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, # check that calculation on opened datasets works properly actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, - autoclose=autoclose, chunks=chunks) + chunks=chunks) # check that using open_mfdataset returns dask arrays for variables assert isinstance(actual['foo'].data, dask_array_type) @@ -1985,7 +1955,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, @requires_scipy_or_netCDF4 -class OpenMFDatasetWithDataVarsAndCoordsKwTest(TestCase): +class OpenMFDatasetWithDataVarsAndCoordsKwTest(object): coord_name = 'lon' var_name = 'v1' @@ -2056,9 +2026,9 @@ def test_common_coord_when_datavars_all(self): var_shape = ds[self.var_name].shape - self.assertEqual(var_shape, coord_shape) - self.assertNotEqual(coord_shape1, coord_shape) - self.assertNotEqual(coord_shape2, coord_shape) + assert var_shape == coord_shape + assert coord_shape1 != coord_shape + assert coord_shape2 != coord_shape def test_common_coord_when_datavars_minimal(self): opt = 'minimal' @@ -2073,9 +2043,9 @@ def test_common_coord_when_datavars_minimal(self): var_shape = ds[self.var_name].shape - self.assertNotEqual(var_shape, coord_shape) - self.assertEqual(coord_shape1, coord_shape) - self.assertEqual(coord_shape2, coord_shape) + assert var_shape != coord_shape + assert coord_shape1 == coord_shape + assert coord_shape2 == coord_shape def test_invalid_data_vars_value_should_fail(self): @@ -2093,7 +2063,7 @@ def test_invalid_data_vars_value_should_fail(self): @requires_dask @requires_scipy 
@requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class DaskTest(DatasetIOTestCases): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -2133,10 +2103,10 @@ def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): with xr.set_options(enable_cftimeindex=True): with self.roundtrip(expected) as actual: abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): # Override method in DatasetIOTestCases - remove not applicable @@ -2153,10 +2123,10 @@ def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): with xr.set_options(enable_cftimeindex=False): with self.roundtrip(expected) as actual: abs_diff = abs(actual.t.values - expected_decoded_t) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() abs_diff = abs(actual.t0.values - expected_decoded_t0) - self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + assert (abs_diff <= np.timedelta64(1, 's')).all() def test_write_store(self): # Override method in DatasetIOTestCases - not applicable to dask @@ -2175,22 +2145,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, - ((5, 5),)) + with open_mfdataset([tmp1, tmp2]) as actual: + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == \ + ((5, 5),) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: - self.assertEqual(actual.foo.variable.data.chunks, - ((3, 2, 3, 2),)) + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: + assert actual.foo.variable.data.chunks == \ + ((3, 2, 3, 2),) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') with raises_regex(ValueError, 'wild-card'): - open_mfdataset('http://some/remote/uri', autoclose=self.autoclose) + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -2201,8 +2169,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(original, actual) def test_attrs_mfdataset(self): @@ -2218,7 +2185,7 @@ def test_attrs_mfdataset(self): with open_mfdataset([tmp1, tmp2]) as actual: # presumes that attributes inherited from # first dataset loaded - self.assertEqual(actual.test1, ds1.test1) + assert actual.test1 == ds1.test1 # attributes from ds2 are not retained, e.g., with raises_regex(AttributeError, 'no attribute'): @@ -2233,8 +2200,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with 
open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -2244,8 +2210,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -2271,15 +2236,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -2289,8 +2253,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -2298,13 +2261,13 @@ def test_open_dataset(self): with create_tmp_file() as tmp: original.to_netcdf(tmp) with open_dataset(tmp, chunks={'x': 5}) as actual: - self.assertIsInstance(actual.foo.variable.data, da.Array) - self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) + assert isinstance(actual.foo.variable.data, da.Array) + assert actual.foo.variable.data.chunks == ((5, 5),) assert_identical(original, actual) with open_dataset(tmp, chunks=5) as actual: assert_identical(original, actual) with open_dataset(tmp) as actual: - self.assertIsInstance(actual.foo.variable.data, np.ndarray) + assert isinstance(actual.foo.variable.data, np.ndarray) assert_identical(original, actual) def test_open_single_dataset(self): @@ -2317,8 +2280,7 @@ def test_open_single_dataset(self): {'baz': [100]}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset([tmp], concat_dim=dim, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp], concat_dim=dim) as actual: assert_identical(expected, actual) def test_dask_roundtrip(self): @@ -2337,16 +2299,16 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): - self.assertIn(var_name, dask_name) - self.assertEqual(dask_name[:13], 'open_dataset-') - self.assertEqual(original_names, repeat_names) + assert var_name in dask_name + assert dask_name[:13] == 'open_dataset-' + assert original_names == repeat_names def test_dataarray_compute(self): # Test DataArray.compute() on dask backend. 
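A short aside on the pattern these dask hunks repeat: the removed `autoclose=self.autoclose` arguments are superseded by the global file cache exercised via `set_options` in the `file_cache_maxsize` fixture earlier in this diff. A minimal sketch of the new-style call, assuming an xarray build with that option and two hypothetical netCDF files on disk:

    import xarray as xr

    # Hypothetical paths; any readable netCDF files would do.
    paths = ['data_0.nc', 'data_1.nc']

    # Rather than passing autoclose=True per call, cap the number of files
    # held open at once; files evicted from the cache are transparently
    # reopened when their data is next accessed.
    with xr.set_options(file_cache_maxsize=5):
        with xr.open_mfdataset(paths) as ds:
            result = ds.mean().compute()
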
@@ -2354,48 +2316,29 @@ def test_dataarray_compute(self): # however dask is the only tested backend which supports DataArrays actual = DataArray([1, 2]).chunk() computed = actual.compute() - self.assertFalse(actual._in_memory) - self.assertTrue(computed._in_memory) + assert not actual._in_memory + assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - def test_to_netcdf_compute_false_roundtrip(self): - from dask.delayed import Delayed - - original = create_test_data().chunk() - - with create_tmp_file() as tmp_file: - # dataset, path, **kwargs): - delayed_obj = self.save(original, tmp_file, compute=False) - assert isinstance(delayed_obj, Delayed) - delayed_obj.compute() - - with self.open(tmp_file) as actual: - assert_identical(original, actual) - def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] - with create_tmp_file() as tmp1: - with create_tmp_file() as tmp2: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) delayed_obj.compute() - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) -class DaskTestAutocloseTrue(DaskTest): - autoclose = True - - @requires_scipy_or_netCDF4 @requires_pydap -class PydapTest(TestCase): +class PydapTest(object): def convert_to_pydap_dataset(self, original): from pydap.model import GridType, BaseType, DatasetType ds = DatasetType('bears', **original.attrs) @@ -2427,8 +2370,8 @@ def test_cmp_local_file(self): assert_equal(actual, expected) # global attributes should be global attributes on the dataset - self.assertNotIn('NC_GLOBAL', actual.attrs) - self.assertIn('history', actual.attrs) + assert 'NC_GLOBAL' not in actual.attrs + assert 'history' in actual.attrs # we don't check attributes exactly with assertDatasetIdentical() # because the test DAP server seems to insert some extra @@ -2436,8 +2379,7 @@ def test_cmp_local_file(self): assert actual.attrs.keys() == expected.attrs.keys() with self.create_datasets() as (actual, expected): - assert_equal( - actual.isel(l=2), expected.isel(l=2)) # noqa: E741 + assert_equal(actual.isel(l=2), expected.isel(l=2)) # noqa with self.create_datasets() as (actual, expected): assert_equal(actual.isel(i=0, j=-1), @@ -2497,15 +2439,14 @@ def test_session(self): @requires_scipy @requires_pynio -class PyNioTest(ScipyWriteTest, TestCase): +class PyNioTest(ScipyWriteTest): def test_write_store(self): # pynio is read-only for now pass @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): @@ -2523,18 +2464,12 @@ def test_weakrefs(self): assert_identical(actual, expected) -class PyNioTestAutocloseTrue(PyNioTest): - autoclose = True - - @requires_pseudonetcdf -class PseudoNetCDFFormatTest(TestCase): - autoclose = True +@pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000') +class PseudoNetCDFFormatTest(object): def open(self, path, **kwargs): - return open_dataset(path, engine='pseudonetcdf', - 
autoclose=self.autoclose, - **kwargs) + return open_dataset(path, engine='pseudonetcdf', **kwargs) @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, @@ -2551,7 +2486,6 @@ def test_ict_format(self): """ ictfile = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs={'format': 'ffi1001'}) stdattr = { 'fill_value': -9999.0, @@ -2649,7 +2583,6 @@ def test_ict_format_write(self): fmtkw = {'format': 'ffi1001'} expected = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: @@ -2659,14 +2592,10 @@ def test_uamiv_format_read(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_dataset('example.uamiv', - engine='pseudonetcdf', - autoclose=True, - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, dict(units='ppm', long_name='O3'.ljust(16), @@ -2688,17 +2617,13 @@ def test_uamiv_format_mfread(self): """ Open a CAMx file and test data variables """ - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - camxfile = open_example_mfdataset( - ['example.uamiv', - 'example.uamiv'], - engine='pseudonetcdf', - autoclose=True, - concat_dim='TSTEP', - backend_kwargs={'format': 'uamiv'}) + + camxfile = open_example_mfdataset( + ['example.uamiv', + 'example.uamiv'], + engine='pseudonetcdf', + concat_dim='TSTEP', + backend_kwargs={'format': 'uamiv'}) data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) data = np.concatenate([data1] * 2, axis=0) @@ -2710,30 +2635,28 @@ def test_uamiv_format_mfread(self): data1 = np.array(['2002-06-03'], 'datetime64[ns]') data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) actual = camxfile.variables['time'] assert_allclose(expected, actual) camxfile.close() def test_uamiv_format_write(self): fmtkw = {'format': 'uamiv'} - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) - expected = open_example_dataset('example.uamiv', - engine='pseudonetcdf', - autoclose=False, - backend_kwargs=fmtkw) + + expected = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: assert_identical(expected, actual) + expected.close() + def save(self, dataset, path, **save_kwargs): import PseudoNetCDF as pnc pncf = pnc.PseudoNetCDFFile() @@ -2798,7 +2721,7 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, @requires_rasterio -class TestRasterio(TestCase): +class 
TestRasterio(object): @requires_scipy_or_netCDF4 def test_serialization(self): @@ -2843,7 +2766,8 @@ def test_non_rectilinear(self): assert len(rioda.attrs['transform']) == 6 # See if a warning is raised if we force it - with self.assertWarns("transformation isn't rectilinear"): + with pytest.warns(Warning, + match="transformation isn't rectilinear"): with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda: assert 'x' not in rioda.coords @@ -2934,6 +2858,10 @@ def test_indexing(self): assert_allclose(expected.isel(**ind), actual.isel(**ind)) assert not actual.variable._in_memory + ind = {'band': 0, 'x': np.array([0, 0]), 'y': np.array([1, 1, 1])} + assert_allclose(expected.isel(**ind), actual.isel(**ind)) + assert not actual.variable._in_memory + # minus-stepped slice ind = {'band': np.array([2, 1, 0]), 'x': slice(-1, None, -1), 'y': 0} @@ -3030,7 +2958,7 @@ def test_chunks(self): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert 'open_rasterio' in actual.data.name # do some arithmetic @@ -3111,7 +3039,7 @@ def test_no_mftime(self): with mock.patch('os.path.getmtime', side_effect=OSError): with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) assert_allclose(actual, expected) @network @@ -3124,10 +3052,10 @@ def test_http_url(self): # make sure chunking works with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: import dask.array as da - self.assertIsInstance(actual.data, da.Array) + assert isinstance(actual.data, da.Array) -class TestEncodingInvalid(TestCase): +class TestEncodingInvalid(object): def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'}) @@ -3136,12 +3064,12 @@ def test_extract_nc4_variable_encoding(self): var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)}) encoding = _extract_nc4_variable_encoding(var) - self.assertEqual({}, encoding) + assert {} == encoding # regression test var = xr.Variable(('x',), [1, 2, 3], {}, {'shuffle': True}) encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True) - self.assertEqual({'shuffle': True}, encoding) + assert {'shuffle': True} == encoding def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) @@ -3156,7 +3084,7 @@ class MiscObject: @requires_netCDF4 -class TestValidateAttrs(TestCase): +class TestValidateAttrs(object): def test_validating_attrs(self): def new_dataset(): return Dataset({'data': ('y', np.arange(10.0))}, @@ -3256,7 +3184,7 @@ def new_dataset_and_coord_attrs(): @requires_scipy_or_netCDF4 -class TestDataArrayToNetCDF(TestCase): +class TestDataArrayToNetCDF(object): def test_dataarray_to_netcdf_no_name(self): original_da = DataArray(np.arange(12).reshape((3, 4))) @@ -3317,32 +3245,6 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): assert_identical(original_da, loaded_da) -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = pickle.dumps(obj2) 
- obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines - - @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get(): with create_tmp_file() as tmpfile: diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..591c981cd45 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,114 @@ +import pickle +import threading +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.file_manager import CachingFileManager +from xarray.backends.lru_cache import LRUCache + + +@pytest.fixture(params=[1, 2, 3, None]) +def file_cache(request): + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) + + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') + manager.close() + + assert not file_cache + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) + + +def test_file_manager_write_consecutive(tmpdir, file_cache): + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) + manager1.close() + manager2.close() + + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' + + +def test_file_manager_write_concurrent(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobarbaz' + + +def test_file_manager_write_pickle(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + f2 = manager2.acquire() + f2.write('bar') + manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +def test_file_manager_read(tmpdir, file_cache): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = CachingFileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' + manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e 
--- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + +from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..03eb6dcf208 --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,91 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with pytest.raises(ValueError): + cache.maxsize = -1 diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py new file mode 100644 index 00000000000..7acd764cab3 --- /dev/null +++ b/xarray/tests/test_cftime_offsets.py @@ -0,0 +1,799 @@ +from itertools import product + +import numpy as np +import pytest + +from xarray import CFTimeIndex +from xarray.coding.cftime_offsets import ( + _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin, + MonthEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range, + get_date_type, to_cftime_datetime, to_offset) + +cftime = pytest.importorskip('cftime') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian', 'standard'] + + +def _id_func(param): + """Called on each parameter passed to pytest.mark.parametrize""" + return str(param) + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return 
request.param + + +@pytest.mark.parametrize( + ('offset', 'expected_n'), + [(BaseCFTimeOffset(), 1), + (YearBegin(), 1), + (YearEnd(), 1), + (BaseCFTimeOffset(n=2), 2), + (YearBegin(n=2), 2), + (YearEnd(n=2), 2)], + ids=_id_func +) +def test_cftime_offset_constructor_valid_n(offset, expected_n): + assert offset.n == expected_n + + +@pytest.mark.parametrize( + ('offset', 'invalid_n'), + [(BaseCFTimeOffset, 1.5), + (YearBegin, 1.5), + (YearEnd, 1.5)], + ids=_id_func +) +def test_cftime_offset_constructor_invalid_n(offset, invalid_n): + with pytest.raises(TypeError): + offset(n=invalid_n) + + +@pytest.mark.parametrize( + ('offset', 'expected_month'), + [(YearBegin(), 1), + (YearEnd(), 12), + (YearBegin(month=5), 5), + (YearEnd(month=5), 5)], + ids=_id_func +) +def test_year_offset_constructor_valid_month(offset, expected_month): + assert offset.month == expected_month + + +@pytest.mark.parametrize( + ('offset', 'invalid_month', 'exception'), + [(YearBegin, 0, ValueError), + (YearEnd, 0, ValueError), + (YearBegin, 13, ValueError,), + (YearEnd, 13, ValueError), + (YearBegin, 1.5, TypeError), + (YearEnd, 1.5, TypeError)], + ids=_id_func +) +def test_year_offset_constructor_invalid_month( + offset, invalid_month, exception): + with pytest.raises(exception): + offset(month=invalid_month) + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), None), + (MonthBegin(), 'MS'), + (YearBegin(), 'AS-JAN')], + ids=_id_func +) +def test_rule_code(offset, expected): + assert offset.rule_code() == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), ''), + (YearBegin(), '')], + ids=_id_func +) +def test_str_and_repr(offset, expected): + assert str(offset) == expected + assert repr(offset) == expected + + +@pytest.mark.parametrize( + 'offset', + [BaseCFTimeOffset(), MonthBegin(), YearBegin()], + ids=_id_func +) +def test_to_offset_offset_input(offset): + assert to_offset(offset) == offset + + +@pytest.mark.parametrize( + ('freq', 'expected'), + [('M', MonthEnd()), + ('2M', MonthEnd(n=2)), + ('MS', MonthBegin()), + ('2MS', MonthBegin(n=2)), + ('D', Day()), + ('2D', Day(n=2)), + ('H', Hour()), + ('2H', Hour(n=2)), + ('T', Minute()), + ('2T', Minute(n=2)), + ('min', Minute()), + ('2min', Minute(n=2)), + ('S', Second()), + ('2S', Second(n=2))], + ids=_id_func +) +def test_to_offset_sub_annual(freq, expected): + assert to_offset(freq) == expected + + +_ANNUAL_OFFSET_TYPES = { + 'A': YearEnd, + 'AS': YearBegin +} + + +@pytest.mark.parametrize(('month_int', 'month_label'), + list(_MONTH_ABBREVIATIONS.items()) + [('', '')]) +@pytest.mark.parametrize('multiple', [None, 2]) +@pytest.mark.parametrize('offset_str', ['AS', 'A']) +def test_to_offset_annual(month_label, month_int, multiple, offset_str): + freq = offset_str + offset_type = _ANNUAL_OFFSET_TYPES[offset_str] + if month_label: + freq = '-'.join([freq, month_label]) + if multiple: + freq = '{}'.format(multiple) + freq + result = to_offset(freq) + + if multiple and month_int: + expected = offset_type(n=multiple, month=month_int) + elif multiple: + expected = offset_type(n=multiple) + elif month_int: + expected = offset_type(month=month_int) + else: + expected = offset_type() + assert result == expected + + +@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', '1H1min']) +def test_invalid_to_offset_str(freq): + with pytest.raises(ValueError): + to_offset(freq) + + +@pytest.mark.parametrize( + ('argument', 'expected_date_args'), + [('2000-01-01', (2000, 1, 1)), + ((2000, 1, 1), (2000, 1, 
1))], + ids=_id_func +) +def test_to_cftime_datetime(calendar, argument, expected_date_args): + date_type = get_date_type(calendar) + expected = date_type(*expected_date_args) + if isinstance(argument, tuple): + argument = date_type(*argument) + result = to_cftime_datetime(argument, calendar=calendar) + assert result == expected + + +def test_to_cftime_datetime_error_no_calendar(): + with pytest.raises(ValueError): + to_cftime_datetime('2000') + + +def test_to_cftime_datetime_error_type_error(): + with pytest.raises(TypeError): + to_cftime_datetime(1) + + +_EQ_TESTS_A = [ + BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2), + YearEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(), + Second() +] +_EQ_TESTS_B = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func +) +def test_neq(a, b): + assert a != b + + +_EQ_TESTS_B_COPY = [ + BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2), + YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2), + MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2) +] + + +@pytest.mark.parametrize( + ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func +) +def test_eq(a, b): + assert a == b + + +_MUL_TESTS = [ + (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)), + (YearEnd(), YearEnd(n=3)), + (YearBegin(), YearBegin(n=3)), + (MonthEnd(), MonthEnd(n=3)), + (MonthBegin(), MonthBegin(n=3)), + (Day(), Day(n=3)), + (Hour(), Hour(n=3)), + (Minute(), Minute(n=3)), + (Second(), Second(n=3)) +] + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_mul(offset, expected): + assert offset * 3 == expected + + +@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func) +def test_rmul(offset, expected): + assert 3 * offset == expected + + +@pytest.mark.parametrize( + ('offset', 'expected'), + [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)), + (YearEnd(), YearEnd(n=-1)), + (YearBegin(), YearBegin(n=-1)), + (MonthEnd(), MonthEnd(n=-1)), + (MonthBegin(), MonthBegin(n=-1)), + (Day(), Day(n=-1)), + (Hour(), Hour(n=-1)), + (Minute(), Minute(n=-1)), + (Second(), Second(n=-1))], + ids=_id_func) +def test_neg(offset, expected): + assert -offset == expected + + +_ADD_TESTS = [ + (Day(n=2), (1, 1, 3)), + (Hour(n=2), (1, 1, 1, 2)), + (Minute(n=2), (1, 1, 1, 0, 2)), + (Second(n=2), (1, 1, 1, 0, 0, 2)) +] + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func +) +def test_add_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = offset + initial + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + _ADD_TESTS, + ids=_id_func +) +def test_radd_sub_monthly(offset, expected_date_args, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + expected = date_type(*expected_date_args) + result = initial + offset + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'expected_date_args'), + [(Day(n=2), (1, 1, 1)), + (Hour(n=2), (1, 1, 2, 22)), + (Minute(n=2), (1, 1, 2, 23, 58)), + (Second(n=2), (1, 1, 2, 23, 59, 58))], + ids=_id_func +) +def test_rsub_sub_monthly(offset, expected_date_args, calendar): + date_type = 
get_date_type(calendar) + initial = date_type(1, 1, 3) + expected = date_type(*expected_date_args) + result = initial - offset + assert result == expected + + +@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func) +def test_sub_error(offset, calendar): + date_type = get_date_type(calendar) + initial = date_type(1, 1, 1) + with pytest.raises(TypeError): + offset - initial + + +@pytest.mark.parametrize( + ('a', 'b'), + zip(_EQ_TESTS_A, _EQ_TESTS_B), + ids=_id_func +) +def test_minus_offset(a, b): + result = b - a + expected = a + assert result == expected + + +@pytest.mark.parametrize( + ('a', 'b'), + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + + [(YearEnd(month=1), YearEnd(month=2))], + ids=_id_func +) +def test_minus_offset_error(a, b): + with pytest.raises(TypeError): + b - a + + +def test_days_in_month_non_december(calendar): + date_type = get_date_type(calendar) + reference = date_type(1, 4, 1) + assert _days_in_month(reference) == 30 + + +def test_days_in_month_december(calendar): + if calendar == '360_day': + expected = 30 + else: + expected = 31 + date_type = get_date_type(calendar) + reference = date_type(1, 12, 5) + assert _days_in_month(reference) == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), MonthBegin(), (1, 2, 1)), + ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)), + ((1, 1, 7), MonthBegin(), (1, 2, 1)), + ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)), + ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)), + ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)), + ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)), + ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)), + ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)), + ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)), + ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)), + ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_begin( + calendar, initial_date_args, offset, expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), MonthEnd(), (1, 1), ()), + ((1, 1, 1), MonthEnd(n=2), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()), + ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()), + ((1, 2, 1), MonthEnd(n=14), (2, 3), ()), + ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)), + ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1), (), MonthEnd(), (1, 2), ()), + ((1, 1), (), MonthEnd(n=2), (1, 3), ()), + ((1, 3), (), MonthEnd(n=-1), (1, 2), ()), + ((1, 3), (), MonthEnd(n=-2), 
(1, 1), ()), + ((1, 2), (), MonthEnd(n=14), (2, 4), ()), + ((2, 4), (), MonthEnd(n=-14), (1, 2), ()), + ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)), + ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_month_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_date_args'), + [((1, 1, 1), YearBegin(), (2, 1, 1)), + ((1, 1, 1), YearBegin(n=2), (3, 1, 1)), + ((1, 1, 1), YearBegin(month=2), (1, 2, 1)), + ((1, 1, 7), YearBegin(n=2), (3, 1, 1)), + ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)), + ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)), + ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)), + ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_begin(calendar, initial_date_args, offset, + expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_date_args', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 1, 1), YearEnd(), (1, 12), ()), + ((1, 1, 1), YearEnd(n=2), (2, 12), ()), + ((1, 1, 1), YearEnd(month=1), (1, 1), ()), + ((2, 3, 1), YearEnd(n=-1), (1, 12), ()), + ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)), + ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end( + calendar, initial_date_args, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +@pytest.mark.parametrize( + ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month', + 'expected_sub_day'), + [((1, 12), (), YearEnd(), (2, 12), ()), + ((1, 12), (), YearEnd(n=2), (3, 12), ()), + ((2, 12), (), YearEnd(n=-1), (1, 12), ()), + ((3, 12), (), YearEnd(n=-2), (1, 12), ()), + ((1, 1), (), YearEnd(month=2), (1, 2), ()), + ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)), + ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))], + ids=_id_func +) +def test_add_year_end_onOffset( + calendar, initial_year_month, initial_sub_day, offset, expected_year_month, + expected_sub_day +): + date_type = get_date_type(calendar) + reference_args = initial_year_month + (1,) + reference = date_type(*reference_args) + 
initial_date_args = (initial_year_month + (_days_in_month(reference),) + + initial_sub_day) + initial = date_type(*initial_date_args) + result = initial + offset + reference_args = expected_year_month + (1,) + reference = date_type(*reference_args) + + # Here the days at the end of each month varies based on the calendar used + expected_date_args = (expected_year_month + + (_days_in_month(reference),) + expected_sub_day) + expected = date_type(*expected_date_args) + assert result == expected + + +# Note for all sub-monthly offsets, pandas always returns True for onOffset +@pytest.mark.parametrize( + ('date_args', 'offset', 'expected'), + [((1, 1, 1), MonthBegin(), True), + ((1, 1, 1, 1), MonthBegin(), True), + ((1, 1, 5), MonthBegin(), False), + ((1, 1, 5), MonthEnd(), False), + ((1, 1, 1), YearBegin(), True), + ((1, 1, 1, 1), YearBegin(), True), + ((1, 1, 5), YearBegin(), False), + ((1, 12, 1), YearEnd(), False), + ((1, 1, 1), Day(), True), + ((1, 1, 1, 1), Day(), True), + ((1, 1, 1), Hour(), True), + ((1, 1, 1), Minute(), True), + ((1, 1, 1), Second(), True)], + ids=_id_func +) +def test_onOffset(calendar, date_args, offset, expected): + date_type = get_date_type(calendar) + date = date_type(*date_args) + result = offset.onOffset(date) + assert result == expected + + +@pytest.mark.parametrize( + ('year_month_args', 'sub_day_args', 'offset'), + [((1, 1), (), MonthEnd()), + ((1, 1), (1,), MonthEnd()), + ((1, 12), (), YearEnd()), + ((1, 1), (), YearEnd(month=1))], + ids=_id_func +) +def test_onOffset_month_or_year_end( + calendar, year_month_args, sub_day_args, offset): + date_type = get_date_type(calendar) + reference_args = year_month_args + (1,) + reference = date_type(*reference_args) + date_args = year_month_args + (_days_in_month(reference),) + sub_day_args + date = date_type(*date_args) + result = offset.onOffset(date) + assert result + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (2, 1)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2), (1, 3, 1), (2, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(), (1, 3, 1), (1, 12)), + (YearEnd(n=2), (1, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)), + (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 4)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 4)), + (MonthEnd(), (1, 3, 2), (1, 3)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (MonthEnd(n=2), (1, 3, 2), (1, 3)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollforward(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollforward(initial) + assert result == expected + + +@pytest.mark.parametrize( + ('offset', 'initial_date_args', 'partial_expected_date_args'), + [(YearBegin(), (1, 3, 1), (1, 1)), + 
(YearBegin(n=2), (1, 3, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)), + (YearBegin(), (1, 1, 1), (1, 1)), + (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)), + (YearEnd(), (2, 3, 1), (1, 12)), + (YearEnd(n=2), (2, 3, 1), (1, 12)), + (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)), + (YearEnd(month=4), (1, 4, 30), (1, 4)), + (MonthBegin(), (1, 3, 2), (1, 3)), + (MonthBegin(n=2), (1, 3, 2), (1, 3)), + (MonthBegin(), (1, 3, 1), (1, 3)), + (MonthEnd(), (1, 3, 2), (1, 2)), + (MonthEnd(n=2), (1, 3, 2), (1, 2)), + (MonthEnd(), (1, 4, 30), (1, 4)), + (Day(), (1, 3, 2, 1), (1, 3, 2, 1)), + (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), + (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), + (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))], + ids=_id_func +) +def test_rollback(calendar, offset, initial_date_args, + partial_expected_date_args): + date_type = get_date_type(calendar) + initial = date_type(*initial_date_args) + if isinstance(offset, (MonthBegin, YearBegin)): + expected_date_args = partial_expected_date_args + (1,) + elif isinstance(offset, (MonthEnd, YearEnd)): + reference_args = partial_expected_date_args + (1,) + reference = date_type(*reference_args) + expected_date_args = (partial_expected_date_args + + (_days_in_month(reference),)) + else: + expected_date_args = partial_expected_date_args + expected = date_type(*expected_date_args) + result = offset.rollback(initial) + assert result == expected + + +_CFTIME_RANGE_TESTS = [ + ('0001-01-01', '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', '0001-01-04', None, 'D', 'left', False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3)]), + ('0001-01-01', '0001-01-04', None, 'D', 'right', False, + [(1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False, + [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]), + ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-01', None, 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + (None, '0001-01-04', 4, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), '0001-01-04', None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ((1, 1, 1), (1, 1, 4), None, 'D', None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]), + ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, False, + [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]), + ('0001-01-04', '0001-01-01', None, 'D', None, False, + []), + ('0010', None, 4, YearBegin(n=-2), None, False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]), + ('0001-01-01', '0001-01-04', 4, None, None, False, + [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]) +] + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed', 'normalize', + 'expected_date_args'), + _CFTIME_RANGE_TESTS, ids=_id_func +) +def test_cftime_range( + start, end, periods, freq, closed, normalize, calendar, + expected_date_args): + date_type = get_date_type(calendar) + expected_dates = [date_type(*args) for args in expected_date_args] + + if isinstance(start, tuple): + start = date_type(*start) + if isinstance(end, tuple): + end = date_type(*end) + + result = cftime_range( + start=start, end=end, periods=periods, freq=freq, closed=closed, + normalize=normalize, calendar=calendar) + resulting_dates = result.values + + assert isinstance(result, CFTimeIndex) + + if freq is not None: + np.testing.assert_equal(resulting_dates, expected_dates) + else: + # If we create a 
linear range of dates using cftime.num2date + # we will not get exact round number dates. This is because + # datetime arithmetic in cftime is accurate approximately to + # 1 millisecond (see https://unidata.github.io/cftime/api.html). + deltas = resulting_dates - expected_dates + deltas = np.array([delta.total_seconds() for delta in deltas]) + assert np.max(np.abs(deltas)) < 0.001 + + +def test_cftime_range_name(): + result = cftime_range(start='2000', periods=4, name='foo') + assert result.name == 'foo' + + result = cftime_range(start='2000', periods=4) + assert result.name is None + + +@pytest.mark.parametrize( + ('start', 'end', 'periods', 'freq', 'closed'), + [(None, None, 5, 'A', None), + ('2000', None, None, 'A', None), + (None, '2000', None, 'A', None), + ('2000', '2001', None, None, None), + (None, None, None, None, None), + ('2000', '2001', None, 'A', 'up'), + ('2000', '2001', 5, 'A', None)] +) +def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): + with pytest.raises(ValueError): + cftime_range(start, end, periods, freq, closed=closed) + + +_CALENDAR_SPECIFIC_MONTH_END_TESTS = [ + ('2M', 'noleap', + [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'all_leap', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', '360_day', + [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ('2M', 'standard', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'gregorian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ('2M', 'julian', + [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]) +] + + +@pytest.mark.parametrize( + ('freq', 'calendar', 'expected_month_day'), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func +) +def test_calendar_specific_month_end(freq, calendar, expected_month_day): + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start='2000-02', end='2001', freq=freq, calendar=calendar).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day] + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize( + ('calendar', 'start', 'end', 'expected_number_of_days'), + [('noleap', '2000', '2001', 365), + ('all_leap', '2000', '2001', 366), + ('360_day', '2000', '2001', 360), + ('standard', '2000', '2001', 366), + ('gregorian', '2000', '2001', 366), + ('julian', '2000', '2001', 366), + ('noleap', '2001', '2002', 365), + ('all_leap', '2001', '2002', 366), + ('360_day', '2001', '2002', 360), + ('standard', '2001', '2002', 365), + ('gregorian', '2001', '2002', 365), + ('julian', '2001', '2002', 365)] +) +def test_calendar_year_length( + calendar, start, end, expected_number_of_days): + result = cftime_range(start, end, freq='D', closed='left', + calendar=calendar) + assert len(result) == expected_number_of_days diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 6f102b60b9d..d1726ab3313 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1,17 +1,18 @@ from __future__ import absolute_import -import pytest +from datetime import timedelta +import numpy as np import pandas as pd -import xarray as xr +import pytest -from datetime import timedelta +import xarray as xr from xarray.coding.cftimeindex import ( - parse_iso8601, CFTimeIndex, assert_all_valid_date_type, - _parsed_string_to_bounds, _parse_iso8601_with_reso) + CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso, + 
_parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601) from xarray.tests import assert_array_equal, assert_identical -from . import has_cftime, has_cftime_or_netCDF4 +from . import has_cftime, has_cftime_or_netCDF4, requires_cftime from .test_coding_times import _all_cftime_date_types @@ -121,22 +122,42 @@ def dec_days(date_type): return 31 +@pytest.fixture +def index_with_name(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), + date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates, name='foo') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize( + ('name', 'expected_name'), + [('bar', 'bar'), + (None, 'foo')]) +def test_constructor_with_name(index_with_name, name, expected_name): + result = CFTimeIndex(index_with_name, name=name).name + assert result == expected_name + + @pytest.mark.skipif(not has_cftime, reason='cftime not installed') def test_assert_all_valid_date_type(date_type, index): import cftime if date_type is cftime.DatetimeNoLeap: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeAllLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeAllLeap(1, 2, 1)]) else: - mixed_date_types = [date_type(1, 1, 1), - cftime.DatetimeNoLeap(1, 2, 1)] + mixed_date_types = np.array( + [date_type(1, 1, 1), + cftime.DatetimeNoLeap(1, 2, 1)]) with pytest.raises(TypeError): assert_all_valid_date_type(mixed_date_types) with pytest.raises(TypeError): - assert_all_valid_date_type([1, date_type(1, 1, 1)]) + assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)])) - assert_all_valid_date_type([date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_all_valid_date_type( + np.array([date_type(1, 1, 1), date_type(1, 2, 1)])) @pytest.mark.skipif(not has_cftime, reason='cftime not installed') @@ -589,3 +610,95 @@ def test_concat_cftimeindex(date_type, enable_cftimeindex): else: assert isinstance(da.indexes['time'], pd.Index) assert not isinstance(da.indexes['time'], CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_empty_cftimeindex(): + index = CFTimeIndex([]) + assert index.date_type is None + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_add(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_radd(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = timedelta(days=1) + index + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_sub(index): + date_type = index.date_type + expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2), + date_type(2, 1, 2), date_type(2, 2, 2)] + expected = CFTimeIndex(expected_dates) + result = index + timedelta(days=2) + result = result - timedelta(days=1) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_rsub(index): + with pytest.raises(TypeError): + timedelta(days=1) - 
index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('freq', ['D', timedelta(days=1)]) +def test_cftimeindex_shift(index, freq): + date_type = index.date_type + expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3), + date_type(2, 1, 3), date_type(2, 2, 3)] + expected = CFTimeIndex(expected_dates) + result = index.shift(2, freq) + assert result.equals(expected) + assert isinstance(result, CFTimeIndex) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_n(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift('a', 'D') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_cftimeindex_shift_invalid_freq(): + index = xr.cftime_range('2000', periods=3) + with pytest.raises(TypeError): + index.shift(1, 1) + + +@requires_cftime +def test_parse_array_of_cftime_strings(): + from cftime import DatetimeNoLeap + + strings = np.array([['2000-01-01', '2000-01-02'], + ['2000-01-03', '2000-01-04']]) + expected = np.array( + [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)], + [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]]) + + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) + + # Test scalar array case + strings = np.array('2000-01-01') + expected = np.array(DatetimeNoLeap(2000, 1, 1)) + result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap) + np.testing.assert_array_equal(result, expected) diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index 53d028e164b..ca138ca8362 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -5,13 +5,13 @@ import pytest from xarray import Variable -from xarray.core.pycompat import bytes_type, unicode_type, suppress from xarray.coding import strings from xarray.core import indexing +from xarray.core.pycompat import bytes_type, suppress, unicode_type -from . import (IndexerMaker, assert_array_equal, assert_identical, - raises_regex, requires_dask) - +from . import ( + IndexerMaker, assert_array_equal, assert_identical, raises_regex, + requires_dask) with suppress(ImportError): import dask.array as da diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index e763af4984c..10a1a956b27 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,20 +1,20 @@ from __future__ import absolute_import, division, print_function -from itertools import product import warnings +from itertools import product import numpy as np import pandas as pd import pytest -from xarray import Variable, coding, set_options, DataArray, decode_cf +from xarray import DataArray, Variable, coding, decode_cf, set_options from xarray.coding.times import _import_cftime from xarray.coding.variables import SerializationWarning from xarray.core.common import contains_cftime_datetimes -from . import (assert_array_equal, has_cftime_or_netCDF4, - requires_cftime_or_netCDF4, has_cftime, has_dask) - +from . 
import ( + assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask, + requires_cftime_or_netCDF4) _NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day'} @@ -538,7 +538,8 @@ def test_cf_datetime_nan(num_dates, units, expected_list): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN') actual = coding.times.decode_cf_datetime(num_dates, units) - expected = np.array(expected_list, dtype='datetime64[ns]') + # use pandas because numpy will deprecate timezone-aware conversions + expected = pd.to_datetime(expected_list) assert_array_equal(expected, actual) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 482a280b355..2004b1e660f 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -10,12 +10,12 @@ from xarray.core.pycompat import OrderedDict, iteritems from . import ( - InaccessibleArray, TestCase, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask) + InaccessibleArray, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask) from .test_dataset import create_test_data -class TestConcatDataset(TestCase): +class TestConcatDataset(object): def test_concat(self): # TODO: simplify and split this test case @@ -235,7 +235,7 @@ def test_concat_multiindex(self): assert isinstance(actual.x.to_index(), pd.MultiIndex) -class TestConcatDataArray(TestCase): +class TestConcatDataArray(object): def test_concat(self): ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))), 'bar': (['x', 'y'], np.random.random((2, 3)))}, @@ -295,7 +295,7 @@ def test_concat_lazy(self): assert combined.dims == ('z', 'x', 'y') -class TestAutoCombine(TestCase): +class TestAutoCombine(object): @requires_dask # only for toolz def test_auto_combine(self): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ca8e4e59737..1003c531018 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -15,7 +15,7 @@ join_dict_keys, ordered_set_intersection, ordered_set_union, result_name, unified_dim_sizes) -from . import raises_regex, requires_dask, has_dask +from . import has_dask, raises_regex, requires_dask def assert_identical(a, b): diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 5ed482ed2bd..a067d01a308 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -8,20 +8,20 @@ import pandas as pd import pytest -from xarray import (Dataset, Variable, SerializationWarning, coding, - conventions, open_dataset) +from xarray import ( + Dataset, SerializationWarning, Variable, coding, conventions, open_dataset) from xarray.backends.common import WritableCFDataStore from xarray.backends.memory import InMemoryDataStore from xarray.conventions import decode_cf from xarray.testing import assert_identical from . 
import ( - TestCase, assert_array_equal, raises_regex, requires_netCDF4, - requires_cftime_or_netCDF4, unittest, requires_dask) + assert_array_equal, raises_regex, requires_cftime_or_netCDF4, + requires_dask, requires_netCDF4) from .test_backends import CFEncodedDataTest -class TestBoolTypeArray(TestCase): +class TestBoolTypeArray(object): def test_booltype_array(self): x = np.array([1, 0, 1, 1, 0], dtype='i1') bx = conventions.BoolTypeArray(x) @@ -30,7 +30,7 @@ def test_booltype_array(self): dtype=np.bool)) -class TestNativeEndiannessArray(TestCase): +class TestNativeEndiannessArray(object): def test(self): x = np.arange(5, dtype='>i8') expected = np.arange(5, dtype='int64') @@ -69,7 +69,7 @@ def test_decode_cf_with_conflicting_fill_missing_value(): @requires_cftime_or_netCDF4 -class TestEncodeCFVariable(TestCase): +class TestEncodeCFVariable(object): def test_incompatible_attributes(self): invalid_vars = [ Variable(['t'], pd.date_range('2000-01-01', periods=3), @@ -134,7 +134,7 @@ def test_string_object_warning(self): @requires_cftime_or_netCDF4 -class TestDecodeCF(TestCase): +class TestDecodeCF(object): def test_dataset(self): original = Dataset({ 't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}), @@ -255,7 +255,7 @@ def encode_variable(self, var): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedDataTest): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() @@ -267,9 +267,10 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}, data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) + @pytest.mark.skip('cannot roundtrip coordinates yet for ' + 'CFEncodedInMemoryStore') def test_roundtrip_coordinates(self): - raise unittest.SkipTest('cannot roundtrip coordinates yet for ' - 'CFEncodedInMemoryStore') + pass def test_invalid_dataarray_names_raise(self): # only relevant for on-disk file formats diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f6c47cce8d8..e56f751bef9 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import pickle +from distutils.version import LooseVersion from textwrap import dedent import numpy as np @@ -14,18 +15,22 @@ from xarray.tests import mock from . 
import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_frame_equal, assert_identical, raises_regex) + assert_allclose, assert_array_equal, assert_equal, assert_frame_equal, + assert_identical, raises_regex) dask = pytest.importorskip('dask') da = pytest.importorskip('dask.array') dd = pytest.importorskip('dask.dataframe') -class DaskTestCase(TestCase): +class DaskTestCase(object): def assertLazyAnd(self, expected, actual, test): - with dask.set_options(get=dask.get): + + with (dask.config.set(get=dask.get) + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=dask.get)): test(actual, expected) + if isinstance(actual, Dataset): for k, v in actual.variables.items(): if k in actual.dims: @@ -52,6 +57,7 @@ def assertLazyAndIdentical(self, expected, actual): def assertLazyAndAllClose(self, expected, actual): self.assertLazyAnd(expected, actual, assert_allclose) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.RandomState(0).randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -196,11 +202,13 @@ def test_missing_methods(self): except NotImplementedError as err: assert 'dask' in str(err) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_univariate_ufunc(self): u = self.eager_var v = self.lazy_var self.assertLazyAndAllClose(np.sin(u), xu.sin(v)) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_bivariate_ufunc(self): u = self.eager_var v = self.lazy_var @@ -242,6 +250,7 @@ def assertLazyAndAllClose(self, expected, actual): def assertLazyAndEqual(self, expected, actual): self.assertLazyAnd(expected, actual, assert_equal) + @pytest.fixture(autouse=True) def setUp(self): self.values = np.random.randn(4, 6) self.data = da.from_array(self.values, chunks=(2, 2)) @@ -378,8 +387,8 @@ def test_groupby(self): u = self.eager_array v = self.lazy_array - expected = u.groupby('x').mean() - actual = v.groupby('x').mean() + expected = u.groupby('x').mean(xr.ALL_DIMS) + actual = v.groupby('x').mean(xr.ALL_DIMS) self.assertLazyAndAllClose(expected, actual) def test_groupby_first(self): @@ -421,6 +430,7 @@ def duplicate_and_merge(array): actual = duplicate_and_merge(self.lazy_array) self.assertLazyAndEqual(expected, actual) + @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning') def test_ufuncs(self): u = self.eager_array v = self.lazy_array @@ -573,7 +583,7 @@ def test_from_dask_variable(self): self.assertLazyAndIdentical(self.lazy_array, a) -class TestToDaskDataFrame(TestCase): +class TestToDaskDataFrame(object): def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames @@ -821,7 +831,9 @@ def test_basic_compute(): dask.multiprocessing.get, dask.local.get_sync, None]: - with dask.set_options(get=get): + with (dask.config.set(get=get) + if LooseVersion(dask.__version__) >= LooseVersion('0.18.0') + else dask.set_options(get=get)): ds.compute() ds.foo.compute() ds.foo.variable.compute() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2950e97cc75..d15a0bb6081 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1,9 +1,9 @@ from __future__ import absolute_import, division, print_function import pickle +import warnings from copy import deepcopy from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -12,19 +12,20 @@ import xarray as xr from xarray import ( DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options) -from 
xarray.convert import from_cdms2 from xarray.coding.times import CFDatetimeCoder, _import_cftime -from xarray.core.common import full_like +from xarray.convert import from_cdms2 +from xarray.core.common import ALL_DIMS, full_like from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( - ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, + ReturnItem, assert_allclose, assert_array_equal, assert_equal, assert_identical, raises_regex, requires_bottleneck, requires_cftime, requires_dask, requires_iris, requires_np113, requires_scipy, - source_ndarray, unittest) + source_ndarray) -class TestDataArray(TestCase): - def setUp(self): +class TestDataArray(object): + @pytest.fixture(autouse=True) + def setup(self): self.attrs = {'attr1': 'value1', 'attr2': 2929} self.x = np.random.random((10, 20)) self.v = Variable(['x', 'y'], self.x) @@ -440,7 +441,7 @@ def test_getitem(self): assert_identical(self.ds['x'], x) assert_identical(self.ds['y'], y) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y], I[x.values > -1], I[x.variable > -1], I[x > -1], I[x > -1, y > -1]]: @@ -672,6 +673,7 @@ def test_isel_types(self): assert_identical(da.isel(x=np.array([0], dtype="int64")), da.isel(x=np.array([0]))) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_isel_fancy(self): shape = (10, 7, 6) np_array = np.random.random(shape) @@ -845,6 +847,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_isel_points(self): shape = (10, 5, 6) np_array = np.random.random(shape) @@ -999,7 +1002,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, assert da.dims[0] == renamed_dim da = da.rename({renamed_dim: 'x'}) assert_identical(da.variable, expected_da.variable) - self.assertVariableNotEqual(da['x'], expected_da['x']) + assert not da['x'].equals(expected_da['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -1237,6 +1240,7 @@ def test_reindex_like_no_index(self): ValueError, 'different size for unlabeled'): foo.reindex_like(bar) + @pytest.mark.filterwarnings('ignore:Indexer has dimensions') def test_reindex_regressions(self): # regression test for #279 expected = DataArray(np.random.randn(5), coords=[("time", range(5))]) @@ -1286,7 +1290,7 @@ def test_swap_dims(self): def test_expand_dims_error(self): array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], - coords={'x': np.linspace(0.0, 1.0, 3.0)}, + coords={'x': np.linspace(0.0, 1.0, 3)}, attrs={'key': 'entry'}) with raises_regex(ValueError, 'dim should be str or'): @@ -1660,9 +1664,23 @@ def test_dataset_math(self): def test_stack_unstack(self): orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2}) + assert_identical(orig, orig.unstack()) + actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y']) assert_identical(orig, actual) + dims = ['a', 'b', 'c', 'd', 'e'] + orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims) + stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd']) + + unstacked = stacked.unstack(['ab', 'cd']) + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + + unstacked = stacked.unstack() + roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims) + assert_identical(orig, roundtripped) + def test_stack_unstack_decreasing_coordinate(self): # regression test for 
GH980 orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'), @@ -1983,15 +2001,15 @@ def test_groupby_sum(self): self.x[:, 10:].sum(), self.x[:, 9:10].sum()]).T), 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] - assert_allclose(expected_sum_all, grouped.reduce(np.sum)) - assert_allclose(expected_sum_all, grouped.sum()) + assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=ALL_DIMS)) + assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS)) expected = DataArray([array['y'].values[idx].sum() for idx in [slice(9), slice(10, None), slice(9, 10)]], [['a', 'b', 'c']], ['abc']) actual = array['y'].groupby('abc').apply(np.sum) assert_allclose(expected, actual) - actual = array['y'].groupby('abc').sum() + actual = array['y'].groupby('abc').sum(ALL_DIMS) assert_allclose(expected, actual) expected_sum_axis1 = Dataset( @@ -2002,6 +2020,29 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y')) assert_allclose(expected_sum_axis1, grouped.sum('y')) + def test_groupby_warning(self): + array = self.make_groupby_example_array() + grouped = array.groupby('y') + with pytest.warns(FutureWarning): + grouped.sum() + + # Currently disabled due to https://github.com/pydata/xarray/issues/2468 + # @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'), + # reason="not to forget the behavior change") + @pytest.mark.skip + def test_groupby_sum_default(self): + array = self.make_groupby_example_array() + grouped = array.groupby('abc') + + expected_sum_all = Dataset( + {'foo': Variable(['x', 'abc'], + np.array([self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1)]).T), + 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo'] + + assert_allclose(expected_sum_all, grouped.sum()) + def test_groupby_count(self): array = DataArray( [0, 0, np.nan, np.nan, 0, 0], @@ -2011,7 +2052,7 @@ def test_groupby_count(self): expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])]) assert_identical(actual, expected) - @unittest.skip('needs to be fixed for shortcut=False, keep_attrs=False') + @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False') def test_groupby_reduce_attrs(self): array = self.make_groupby_example_array() array.attrs['foo'] = 'bar' @@ -2082,9 +2123,9 @@ def test_groupby_math(self): assert_identical(expected, actual) grouped = array.groupby('abc') - expected_agg = (grouped.mean() - np.arange(3)).rename(None) + expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None) actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])]) - actual_agg = actual.groupby('abc').mean() + actual_agg = actual.groupby('abc').mean(ALL_DIMS) assert_allclose(expected_agg, actual_agg) with raises_regex(TypeError, 'only support binary ops'): @@ -2158,7 +2199,7 @@ def test_groupby_multidim(self): ('lon', DataArray([5, 28, 23], coords=[('lon', [30., 40., 50.])])), ('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]: - actual_sum = array.groupby(dim).sum() + actual_sum = array.groupby(dim).sum(ALL_DIMS) assert_identical(expected_sum, actual_sum) def test_groupby_multidim_apply(self): @@ -2787,7 +2828,7 @@ def test_to_and_from_series(self): def test_series_categorical_index(self): # regression test for GH700 if not hasattr(pd, 'CategoricalIndex'): - raise unittest.SkipTest('requires pandas with CategoricalIndex') + pytest.skip('requires pandas with CategoricalIndex') s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc'))) arr = DataArray(s) @@ -2915,7 
+2956,6 @@ def test_to_masked_array(self): ma = da.to_masked_array() assert len(ma.mask) == N - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? def test_to_and_from_cdms2_classic(self): """Classic with 1D axes""" pytest.importorskip('cdms2') @@ -2928,9 +2968,9 @@ def test_to_and_from_cdms2_classic(self): expected_coords = [IndexVariable('distance', [-2, 2]), IndexVariable('time', [0, 1, 2])] actual = original.to_cdms2() - assert_array_equal(actual, original) + assert_array_equal(actual.asma(), original) assert actual.id == original.name - self.assertItemsEqual(actual.getAxisIds(), original.dims) + assert tuple(actual.getAxisIds()) == original.dims for axis, coord in zip(actual.getAxisList(), expected_coords): assert axis.id == coord.name assert_array_equal(axis, coord.values) @@ -2944,13 +2984,12 @@ def test_to_and_from_cdms2_classic(self): assert_identical(original, roundtripped) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert original.coords.keys() == back.coords.keys() for coord_name in original.coords.keys(): assert_array_equal(original.coords[coord_name], back.coords[coord_name]) - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? def test_to_and_from_cdms2_sgrid(self): """Curvilinear (structured) grid @@ -2967,17 +3006,18 @@ def test_to_and_from_cdms2_sgrid(self): coords=OrderedDict(x=x, y=y, lon=lon, lat=lat), name='sst') actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().asma()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().asma()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert original.dims == back.dims + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) - @pytest.mark.xfail # GH:2332 TODO fix this in upstream? 
def test_to_and_from_cdms2_ugrid(self): """Unstructured grid""" pytest.importorskip('cdms2') @@ -2988,13 +3028,15 @@ def test_to_and_from_cdms2_ugrid(self): original = DataArray(np.arange(5), dims=['cell'], coords={'lon': lon, 'lat': lat, 'cell': cell}) actual = original.to_cdms2() - self.assertItemsEqual(actual.getAxisIds(), original.dims) - assert_array_equal(original.coords['lon'], actual.getLongitude()) - assert_array_equal(original.coords['lat'], actual.getLatitude()) + assert tuple(actual.getAxisIds()) == original.dims + assert_array_equal(original.coords['lon'], + actual.getLongitude().getValue()) + assert_array_equal(original.coords['lat'], + actual.getLatitude().getValue()) back = from_cdms2(actual) - self.assertItemsEqual(original.dims, back.dims) - self.assertItemsEqual(original.coords.keys(), back.coords.keys()) + assert set(original.dims) == set(back.dims) + assert set(original.coords.keys()) == set(back.coords.keys()) assert_array_equal(original.coords['lat'], back.coords['lat']) assert_array_equal(original.coords['lon'], back.coords['lon']) @@ -3087,24 +3129,51 @@ def test_coordinate_diff(self): actual = lon.diff('lon') assert_equal(expected, actual) - def test_shift(self): + @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5]) + def test_shift(self, offset): arr = DataArray([1, 2, 3], dims='x') actual = arr.shift(x=1) expected = DataArray([np.nan, 1, 2], dims='x') assert_identical(expected, actual) arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])]) - for offset in [-5, -2, -1, 0, 1, 2, 5]: - expected = DataArray(arr.to_pandas().shift(offset)) - actual = arr.shift(x=offset) - assert_identical(expected, actual) + expected = DataArray(arr.to_pandas().shift(offset)) + actual = arr.shift(x=offset) + assert_identical(expected, actual) - def test_roll(self): + def test_roll_coords(self): arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') - actual = arr.roll(x=1) + actual = arr.roll(x=1, roll_coords=True) expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) assert_identical(expected, actual) + def test_roll_no_coords(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + actual = arr.roll(x=1, roll_coords=False) + expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])]) + assert_identical(expected, actual) + + def test_roll_coords_none(self): + arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x') + + with pytest.warns(FutureWarning): + actual = arr.roll(x=1, roll_coords=None) + + expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])]) + assert_identical(expected, actual) + + def test_copy_with_data(self): + orig = DataArray(np.random.random(size=(2, 2)), + dims=('x', 'y'), + attrs={'attr1': 'value1'}, + coords={'x': [4, 3]}, + name='helloworld') + new_data = np.arange(4).reshape(2, 2) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + def test_real_and_imag(self): array = DataArray(1 + 2j) assert_identical(array.real, DataArray(1)) @@ -3329,7 +3398,9 @@ def test_isin(da): def test_rolling_iter(da): rolling_obj = da.rolling(time=7) - rolling_obj_mean = rolling_obj.mean() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + rolling_obj_mean = rolling_obj.mean() assert len(rolling_obj.window_labels) == len(da['time']) assert_identical(rolling_obj.window_labels, da['time']) @@ -3337,8 +3408,10 @@ def test_rolling_iter(da): for i, (label, window_da) in enumerate(rolling_obj): assert label == da['time'].isel(time=i) 
- actual = rolling_obj_mean.isel(time=i) - expected = window_da.mean('time') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean('time') # TODO add assert_allclose_with_nan, which compares nan position # as well as the closeness of the values. diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 08d71d462d8..89704653e92 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function +import sys +import warnings from copy import copy, deepcopy from io import StringIO from textwrap import dedent -import warnings -import sys import numpy as np import pandas as pd @@ -13,17 +13,18 @@ import xarray as xr from xarray import ( - DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends, - broadcast, open_dataset, set_options) -from xarray.core import indexing, utils + ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align, + backends, broadcast, open_dataset, set_options) +from xarray.core import indexing, npcompat, utils from xarray.core.common import full_like from xarray.core.pycompat import ( OrderedDict, integer_types, iteritems, unicode_type) from . import ( - InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose, - assert_array_equal, assert_equal, assert_identical, has_dask, raises_regex, - requires_bottleneck, requires_dask, requires_scipy, source_ndarray) + InaccessibleArray, UnexpectedDataAccess, assert_allclose, + assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask, + raises_regex, requires_bottleneck, requires_dask, requires_scipy, + source_ndarray) try: import cPickle as pickle @@ -63,8 +64,8 @@ def create_test_multiindex(): class InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): @@ -85,7 +86,7 @@ def lazy_inaccessible(k, v): k, v in iteritems(self._variables)) -class TestDataset(TestCase): +class TestDataset(object): def test_repr(self): data = create_test_data(seed=123) data.attrs['foo'] = 'bar' @@ -398,7 +399,7 @@ def test_constructor_with_coords(self): ds = Dataset({}, {'a': ('x', [1])}) assert not ds.data_vars - self.assertItemsEqual(ds.coords.keys(), ['a']) + assert list(ds.coords.keys()) == ['a'] mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], names=('level_1', 'level_2')) @@ -420,9 +421,9 @@ def test_properties(self): assert type(ds.dims.mapping.mapping) is dict # noqa with pytest.warns(FutureWarning): - self.assertItemsEqual(ds, list(ds.variables)) + assert list(ds) == list(ds.variables) with pytest.warns(FutureWarning): - self.assertItemsEqual(ds.keys(), list(ds.variables)) + assert list(ds.keys()) == list(ds.variables) assert 'aasldfjalskdfj' not in ds.variables assert 'dim1' in repr(ds.variables) with pytest.warns(FutureWarning): @@ -430,18 +431,18 @@ def test_properties(self): with pytest.warns(FutureWarning): assert bool(ds) - self.assertItemsEqual(ds.data_vars, ['var1', 'var2', 'var3']) - self.assertItemsEqual(ds.data_vars.keys(), ['var1', 'var2', 'var3']) + assert list(ds.data_vars) == ['var1', 'var2', 'var3'] + assert list(ds.data_vars.keys()) == ['var1', 'var2', 'var3'] assert 'var1' in 
ds.data_vars assert 'dim1' not in ds.data_vars assert 'numbers' not in ds.data_vars assert len(ds.data_vars) == 3 - self.assertItemsEqual(ds.indexes, ['dim2', 'dim3', 'time']) + assert set(ds.indexes) == {'dim2', 'dim3', 'time'} assert len(ds.indexes) == 3 assert 'dim2' in repr(ds.indexes) - self.assertItemsEqual(ds.coords, ['time', 'dim2', 'dim3', 'numbers']) + assert list(ds.coords) == ['time', 'dim2', 'dim3', 'numbers'] assert 'dim2' in ds.coords assert 'numbers' in ds.coords assert 'var1' not in ds.coords @@ -534,7 +535,7 @@ def test_coords_properties(self): assert 4 == len(data.coords) - self.assertItemsEqual(['x', 'y', 'a', 'b'], list(data.coords)) + assert ['x', 'y', 'a', 'b'] == list(data.coords) assert_identical(data.coords['x'].variable, data['x'].variable) assert_identical(data.coords['y'].variable, data['y'].variable) @@ -830,7 +831,7 @@ def test_isel(self): ret = data.isel(**slicers) # Verify that only the specified dimension was altered - self.assertItemsEqual(data.dims, ret.dims) + assert list(data.dims) == list(ret.dims) for d in data.dims: if d in slicers: assert ret.dims[d] == \ @@ -856,21 +857,21 @@ def test_isel(self): ret = data.isel(dim1=0) assert {'time': 20, 'dim2': 9, 'dim3': 10} == ret.dims - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, ret.indexes) + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) assert {'time': 2, 'dim2': 5, 'dim3': 10} == ret.dims - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, ret.indexes) + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(ret.indexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) - self.assertItemsEqual({'dim2': 5, 'dim3': 10}, ret.dims) - self.assertItemsEqual(data.data_vars, ret.data_vars) - self.assertItemsEqual(data.coords, ret.coords) - self.assertItemsEqual(data.indexes, list(ret.indexes) + ['time']) + assert {'dim2': 5, 'dim3': 10} == ret.dims + assert set(data.data_vars) == set(ret.data_vars) + assert set(data.coords) == set(ret.coords) + assert set(data.indexes) == set(list(ret.indexes) + ['time']) def test_isel_fancy(self): # isel with fancy indexing. 
@@ -1240,6 +1241,7 @@ def test_isel_drop(self): selected = data.isel(x=0, drop=False) assert_identical(expected, selected) + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_isel_points(self): data = create_test_data() @@ -1317,6 +1319,8 @@ def test_isel_points(self): dim2=stations['dim2s'], dim=np.array([4, 5, 6])) + @pytest.mark.filterwarnings("ignore:Dataset.sel_points") + @pytest.mark.filterwarnings("ignore:Dataset.isel_points") def test_sel_points(self): data = create_test_data() @@ -1347,6 +1351,7 @@ def test_sel_points(self): with pytest.raises(KeyError): data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) + @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_sel_fancy(self): data = create_test_data() @@ -1477,7 +1482,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, ds = ds.rename({renamed_dim: 'x'}) assert_identical(ds['var'].variable, expected_ds['var'].variable) - self.assertVariableNotEqual(ds['x'], expected_ds['x']) + assert not ds['x'].equals(expected_ds['x']) test_sel(('a', 1, -1), 0) test_sel(('b', 2, -2), -1) @@ -1888,6 +1893,27 @@ def test_copy(self): v1 = copied.variables[k] assert v0 is not v1 + def test_copy_with_data(self): + orig = create_test_data() + new_data = {k: np.random.randn(*v.shape) + for k, v in iteritems(orig.data_vars)} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = create_test_data() + new_var1 = np.arange(orig['var1'].size).reshape(orig['var1'].shape) + with raises_regex(ValueError, 'Data must be dict-like'): + orig.copy(data=new_var1) + with raises_regex(ValueError, 'only contain variables in original'): + orig.copy(data={'not_in_original': new_var1}) + with raises_regex(ValueError, 'contain all variables in original'): + orig.copy(data={'var1': new_var1}) + def test_rename(self): data = create_test_data() newnames = {'var1': 'renamed_var1', 'dim2': 'renamed_dim2'} @@ -2103,17 +2129,18 @@ def test_unstack(self): expected = Dataset({'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], 'y': ['a', 'b']}) - actual = ds.unstack('z') - assert_identical(actual, expected) + for dim in ['z', ['z'], None]: + actual = ds.unstack(dim) + assert_identical(actual, expected) def test_unstack_errors(self): ds = Dataset({'x': [1, 2, 3]}) - with raises_regex(ValueError, 'invalid dimension'): + with raises_regex(ValueError, 'does not contain the dimensions'): ds.unstack('foo') - with raises_regex(ValueError, 'does not have a MultiIndex'): + with raises_regex(ValueError, 'do not have a MultiIndex'): ds.unstack('x') - def test_stack_unstack(self): + def test_stack_unstack_fast(self): ds = Dataset({'a': ('x', [0, 1]), 'b': (('x', 'y'), [[0, 1], [2, 3]]), 'x': [0, 1], @@ -2124,6 +2151,19 @@ def test_stack_unstack(self): actual = ds[['b']].stack(z=['x', 'y']).unstack('z') assert actual.identical(ds[['b']]) + def test_stack_unstack_slow(self): + ds = Dataset({'a': ('x', [0, 1]), + 'b': (('x', 'y'), [[0, 1], [2, 3]]), + 'x': [0, 1], + 'y': ['a', 'b']}) + stacked = ds.stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.broadcast_equals(ds) + + stacked = ds[['b']].stack(z=['x', 'y']) + actual = stacked.isel(z=slice(None, None, -1)).unstack('z') + assert actual.identical(ds[['b']]) + def test_update(self): data = create_test_data(seed=0) expected = data.copy() @@ -2506,12 +2546,11 @@ def test_setitem_multiindex_level(self): 
def test_delitem(self): data = create_test_data() all_items = set(data.variables) - self.assertItemsEqual(data.variables, all_items) + assert set(data.variables) == all_items del data['var1'] - self.assertItemsEqual(data.variables, all_items - set(['var1'])) + assert set(data.variables) == all_items - set(['var1']) del data['numbers'] - self.assertItemsEqual(data.variables, - all_items - set(['var1', 'numbers'])) + assert set(data.variables) == all_items - set(['var1', 'numbers']) assert 'numbers' not in data.coords def test_squeeze(self): @@ -2609,20 +2648,28 @@ def test_groupby_reduce(self): expected = data.mean('y') expected['yonly'] = expected['yonly'].variable.set_dims({'x': 3}) - actual = data.groupby('x').mean() + actual = data.groupby('x').mean(ALL_DIMS) assert_allclose(expected, actual) actual = data.groupby('x').mean('y') assert_allclose(expected, actual) letters = data['letters'] - expected = Dataset({'xy': data['xy'].groupby(letters).mean(), + expected = Dataset({'xy': data['xy'].groupby(letters).mean(ALL_DIMS), 'xonly': (data['xonly'].mean().variable .set_dims({'letters': 2})), 'yonly': data['yonly'].groupby(letters).mean()}) - actual = data.groupby('letters').mean() + actual = data.groupby('letters').mean(ALL_DIMS) assert_allclose(expected, actual) + def test_groupby_warn(self): + data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)), + 'xonly': ('x', np.random.randn(3)), + 'yonly': ('y', np.random.randn(4)), + 'letters': ('y', ['a', 'a', 'b', 'b'])}) + with pytest.warns(FutureWarning): + data.groupby('x').mean() + def test_groupby_math(self): def reorder_dims(x): return x.transpose('dim1', 'dim2', 'dim3', 'time') @@ -2677,7 +2724,7 @@ def test_groupby_math_virtual(self): ds = Dataset({'x': ('t', [1, 2, 3])}, {'t': pd.date_range('20100101', periods=3)}) grouped = ds.groupby('t.day') - actual = grouped - grouped.mean() + actual = grouped - grouped.mean(ALL_DIMS) expected = Dataset({'x': ('t', [0, 0, 0])}, ds[['t', 't.day']]) assert_identical(actual, expected) @@ -2686,18 +2733,17 @@ def test_groupby_nan(self): # nan should be excluded from groupby ds = Dataset({'foo': ('x', [1, 2, 3, 4])}, {'bar': ('x', [1, 1, 2, np.nan])}) - actual = ds.groupby('bar').mean() + actual = ds.groupby('bar').mean(ALL_DIMS) expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]}) assert_identical(actual, expected) def test_groupby_order(self): # groupby should preserve variables order - ds = Dataset() for vn in ['a', 'b', 'c']: ds[vn] = DataArray(np.arange(10), dims=['t']) data_vars_ref = list(ds.data_vars.keys()) - ds = ds.groupby('t').mean() + ds = ds.groupby('t').mean(ALL_DIMS) data_vars = list(ds.data_vars.keys()) assert data_vars == data_vars_ref # coords are now at the end of the list, so the test below fails @@ -2727,6 +2773,20 @@ def test_resample_and_first(self): result = actual.reduce(method) assert_equal(expected, result) + def test_resample_min_count(self): + times = pd.date_range('2000-01-01', freq='6H', periods=10) + ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), + 'bar': ('time', np.random.randn(10), {'meta': 'data'}), + 'time': times}) + # inject nan + ds['foo'] = xr.where(ds['foo'] > 2.0, np.nan, ds['foo']) + + actual = ds.resample(time='1D').sum(min_count=1) + expected = xr.concat([ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1) + for i in range(3)], dim=actual['time']) + assert_equal(expected, actual) + def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = 
Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), @@ -3364,9 +3424,8 @@ def test_reduce(self): (['dim2', 'time'], ['dim1', 'dim3']), (('dim2', 'time'), ['dim1', 'dim3']), ((), ['dim1', 'dim2', 'dim3', 'time'])]: - actual = data.min(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + actual = list(data.min(dim=reduct).dims) + assert actual == expected assert_equal(data.mean(dim=[]), data) @@ -3420,8 +3479,7 @@ def test_reduce_cumsum_test_dims(self): ('time', ['dim1', 'dim2', 'dim3']) ]: actual = getattr(data, cumfunc)(dim=reduct).dims - print(reduct, actual, expected) - self.assertItemsEqual(actual, expected) + assert list(actual) == expected def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) @@ -3559,14 +3617,14 @@ def test_rank(self): ds = create_test_data(seed=1234) # only ds.var3 depends on dim3 z = ds.rank('dim3') - self.assertItemsEqual(['var3'], list(z.data_vars)) + assert ['var3'] == list(z.data_vars) # same as dataarray version x = z.var3 y = ds.var3.rank('dim3') assert_equal(x, y) # coordinates stick - self.assertItemsEqual(list(z.coords), list(ds.coords)) - self.assertItemsEqual(list(x.coords), list(y.coords)) + assert list(z.coords) == list(ds.coords) + assert list(x.coords) == list(y.coords) # invalid dim with raises_regex(ValueError, 'does not contain'): x.rank('invalid_dim') @@ -3849,18 +3907,52 @@ def test_shift(self): with raises_regex(ValueError, 'dimensions'): ds.shift(foo=123) - def test_roll(self): + def test_roll_coords(self): coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} attrs = {'meta': 'data'} ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) - actual = ds.roll(x=1) + actual = ds.roll(x=1, roll_coords=True) ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) with raises_regex(ValueError, 'dimensions'): - ds.roll(foo=123) + ds.roll(foo=123, roll_coords=True) + + def test_roll_no_coords(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + actual = ds.roll(x=1, roll_coords=False) + + expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs) + assert_identical(expected, actual) + + with raises_regex(ValueError, 'dimensions'): + ds.roll(abc=321, roll_coords=False) + + def test_roll_coords_none(self): + coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]} + attrs = {'meta': 'data'} + ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs) + + with pytest.warns(FutureWarning): + actual = ds.roll(x=1, roll_coords=None) + + ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]} + expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs) + assert_identical(expected, actual) + + def test_roll_multidim(self): + # regression test for 2445 + arr = xr.DataArray( + [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)}, + dims=('y', 'x')) + actual = arr.roll(x=1, roll_coords=True) + expected = xr.DataArray([[3, 1, 2], [6, 4, 5]], + coords=[('y', [0, 1]), ('x', [2, 0, 1])]) + assert_identical(expected, actual) def test_real_and_imag(self): attrs = {'foo': 'bar'} @@ -3920,6 +4012,26 @@ def test_filter_by_attrs(self): for var in new_ds.data_vars: assert new_ds[var].height == '10 m' + # Test return empty Dataset due to conflicting filters + new_ds = ds.filter_by_attrs( + standard_name='convective_precipitation_flux', + height='0 m') + assert not bool(new_ds.data_vars) + + # Test 
return one DataArray with two filter conditions + new_ds = ds.filter_by_attrs( + standard_name='air_potential_temperature', + height='0 m') + for var in new_ds.data_vars: + assert new_ds[var].standard_name == 'air_potential_temperature' + assert new_ds[var].height == '0 m' + assert new_ds[var].height != '10 m' + + # Test return empty Dataset due to conflicting callables + new_ds = ds.filter_by_attrs(standard_name=lambda v: False, + height=lambda v: True) + assert not bool(new_ds.data_vars) + def test_binary_op_join_setting(self): # arithmetic_join applies to data array coordinates missing_2 = xr.Dataset({'x': [0, 1]}) @@ -4418,3 +4530,107 @@ def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0 assert len(record) == 0 + + +@pytest.mark.parametrize('dask', [True, False]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_differentiate(dask, edge_order): + rs = np.random.RandomState(42) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + ds = xr.Dataset({'var': da}) + + # along x + actual = da.differentiate('x', edge_order) + expected_x = xr.DataArray( + npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + assert_equal(ds['var'].differentiate('x', edge_order=edge_order), + ds.differentiate('x', edge_order=edge_order)['var']) + # coordinate should not change + assert_equal(da['x'], actual['x']) + + # along y + actual = da.differentiate('y', edge_order) + expected_y = xr.DataArray( + npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order), + dims=da.dims, coords=da.coords) + assert_equal(expected_y, actual) + assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var']) + assert_equal(ds['var'].differentiate('y', edge_order=edge_order), + ds.differentiate('y', edge_order=edge_order)['var']) + + with pytest.raises(ValueError): + da.differentiate('x2d') + + +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_datetime(dask): + rs = np.random.RandomState(42) + coord = np.array( + ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13', + '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'], + dtype='datetime64') + + da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'], + coords={'x': coord, + 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))}) + if dask and has_dask: + da = da.chunk({'x': 4}) + + # along x + actual = da.differentiate('x', edge_order=1, datetime_unit='D') + expected_x = xr.DataArray( + npcompat.gradient( + da, utils.datetime_to_numeric(da['x'], datetime_unit='D'), + axis=0, edge_order=1), dims=da.dims, coords=da.coords) + assert_equal(expected_x, actual) + + actual2 = da.differentiate('x', edge_order=1, datetime_unit='h') + assert np.allclose(actual, actual2 * 24) + + # for datetime variable + actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D') + assert np.allclose(actual, 1.0) + + # with different date unit + da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'], + coords={'x': coord}) + actual = da.differentiate('x', edge_order=1) + assert np.allclose(actual, 1.0) + + +@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.') +@pytest.mark.parametrize('dask', [True, False]) +def test_differentiate_cftime(dask): + rs = np.random.RandomState(42) + coord = 
xr.cftime_range('2000', periods=8, freq='2M') + + da = xr.DataArray( + rs.randn(8, 6), + coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))}, + dims=['time', 'y']) + + if dask and has_dask: + da = da.chunk({'time': 4}) + + actual = da.differentiate('time', edge_order=1, datetime_unit='D') + expected_data = npcompat.gradient( + da, utils.datetime_to_numeric(da['time'], datetime_unit='D'), + axis=0, edge_order=1) + expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims) + assert_equal(expected, actual) + + actual2 = da.differentiate('time', edge_order=1, datetime_unit='h') + assert_allclose(actual, actual2 * 24) + + # Test the differentiation of datetimes themselves + actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D') + assert_allclose(actual, xr.ones_like(da['time']).astype(float)) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 32035afdc57..7c77a62d3c9 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -15,12 +15,13 @@ from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, create_tmp_geotiff) from xarray.tests.test_dataset import create_test_data -from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, @@ -33,6 +34,11 @@ da = pytest.importorskip('dask.array') +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) + + ENGINES = [] if has_scipy: ENGINES.append('scipy') @@ -45,81 +51,69 @@ 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], 'h5netcdf': ['NETCDF4']} -TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ['netcdf4']) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, - engine, autoclose, nc_format): - monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') - - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + loop, tmp_netcdf_filename, engine, nc_format): - original = create_test_data().chunk(chunks) - original.to_netcdf(filename, engine=engine, format=nc_format) - - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + if engine not in ENGINES: + pytest.skip('engine not available') + chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') 
-@pytest.mark.parametrize('engine', ENGINES) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose, - nc_format): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - if engine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose') + original = create_test_data().chunk(chunks) - if nc_format not in NC_FORMATS[engine]: - pytest.skip('invalid format for engine') + if engine == 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - original = create_test_data() - original.to_netcdf(filename, engine=engine, format=nc_format) - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): + if engine not in ENGINES: + pytest.skip('engine not available') -@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) -def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - original = create_test_data().chunk(chunks) + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - with raises_regex(NotImplementedError, 'distributed'): - original.to_netcdf(filename, engine=engine) @requires_zarr diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 833df85f8af..292c60b4d05 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -50,3 +50,39 @@ def error(): def test_inf(obj): assert dtypes.INF > obj assert dtypes.NINF < obj + + +@pytest.mark.parametrize("kind, expected", [ + ('a', (np.dtype('O'), 'nan')), # dtype('S') + ('b', (np.float32, 'nan')), # dtype('int8') + ('B', (np.float32, 'nan')), # dtype('uint8') + ('c', (np.dtype('O'), 'nan')), # dtype('S1') + ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128') + ('d', (np.float64, 'nan')), # dtype('float64') + ('e', (np.float16, 'nan')), # dtype('float16') + ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64') + ('f', (np.float32, 'nan')), # dtype('float32') + ('h', (np.float32, 'nan')), # dtype('int16') + ('H', (np.float32, 'nan')), # dtype('uint16') + ('i', (np.float64, 'nan')), # 
dtype('int32') + ('I', (np.float64, 'nan')), # dtype('uint32') + ('l', (np.float64, 'nan')), # dtype('int64') + ('L', (np.float64, 'nan')), # dtype('uint64') + ('m', (np.timedelta64, 'NaT')), # dtype('<m8') + ('M', (np.datetime64, 'NaT')), # dtype('<M8') + ('O', (np.dtype('O'), 'nan')), # dtype('O') + ('S', (np.dtype('O'), 'nan')), # dtype('S') + ('U', (np.dtype('O'), 'nan')), # dtype('<U') + ('V', (np.dtype('O'), 'nan')), # dtype('V') +]) +def test_maybe_promote(kind, expected): + actual = dtypes.maybe_promote(np.dtype(kind)) + assert actual[0] == expected[0] + assert str(actual[1]) == expected[1] diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py +def assert_dask_array(da, dask): + if dask and da.ndim > 0: + assert isinstance(da.data, dask_array_type) + + @pytest.mark.parametrize('dim_num', [1, 2]) @pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) @pytest.mark.parametrize('dask', [False, True]) @pytest.mark.parametrize('func', ['sum', 'min', 'max', 'mean', 'var']) +# TODO test cumsum, cumprod @pytest.mark.parametrize('skipna', [False, True]) @pytest.mark.parametrize('aggdim', [None, 'x']) def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): @@ -251,6 +269,9 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): if dask and not has_dask: pytest.skip('requires dask') + if dask and skipna is False and dtype in [np.bool_]: + pytest.skip('dask does not compute object-typed array') + rtol = 1e-04 if dtype == np.float32 else 1e-05 da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) @@ -259,6 +280,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # TODO: remove these after resolving # https://github.com/dask/dask/issues/3245 with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Mean of empty slice') warnings.filterwarnings('ignore', 'All-NaN slice') warnings.filterwarnings('ignore', 'invalid value encountered in') @@ -272,6 +294,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): expected = getattr(np, func)(da.values, axis=axis) actual = getattr(da, func)(skipna=skipna, dim=aggdim) + assert_dask_array(actual, dask) assert np.allclose(actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True) except (TypeError, AttributeError, ZeroDivisionError): @@ -279,14 +302,21 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # nanmean for object dtype pass - # make sure the compatiblility with pandas' results. actual = getattr(da, func)(skipna=skipna, dim=aggdim) - if func == 'var': + + # for dask case, make sure the result is the same for numpy backend + expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) + assert_allclose(actual, expected, rtol=rtol) + + # make sure the result is compatible with pandas' results.
+ if func in ['var', 'std']: expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) assert_allclose(actual, expected, rtol=rtol) # also check ddof!=0 case actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5) + if dask: + assert isinstance(da.data, dask_array_type) expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=5) assert_allclose(actual, expected, rtol=rtol) @@ -297,11 +327,14 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): # make sure the dtype argument if func not in ['max', 'min']: actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float) + assert_dask_array(actual, dask) assert actual.dtype == float # without nan da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask) actual = getattr(da, func)(skipna=skipna) + if dask: + assert isinstance(da.data, dask_array_type) expected = getattr(np, 'nan{}'.format(func))(da.values) if actual.dtype == object: assert actual.values == np.array(expected) @@ -338,13 +371,6 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): with warnings.catch_warnings(): warnings.filterwarnings('ignore', 'All-NaN slice') - if aggdim == 'y' and contains_nan and skipna: - with pytest.raises(ValueError): - actual = da.isel(**{ - aggdim: getattr(da, 'arg' + func)( - dim=aggdim, skipna=skipna).compute()}) - return - actual = da.isel(**{aggdim: getattr(da, 'arg' + func) (dim=aggdim, skipna=skipna).compute()}) expected = getattr(da, func)(dim=aggdim, skipna=skipna) @@ -354,6 +380,7 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): def test_argmin_max_error(): da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False) + da[0] = np.nan with pytest.raises(ValueError): da.argmin(dim='y') @@ -388,3 +415,139 @@ def test_dask_rolling(axis, window, center): with pytest.raises(ValueError): rolling_window(dx, axis=axis, window=100, center=center, fill_value=np.nan) + + +@pytest.mark.skipif(not has_dask, reason='This is for dask.') +@pytest.mark.parametrize('axis', [0, -1, 1]) +@pytest.mark.parametrize('edge_order', [1, 2]) +def test_dask_gradient(axis, edge_order): + import dask.array as da + + array = np.array(np.random.randn(100, 5, 40)) + x = np.exp(np.linspace(0, 1, array.shape[axis])) + + darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8]) + expected = gradient(array, x, axis=axis, edge_order=edge_order) + actual = gradient(darray, x, axis=axis, edge_order=edge_order) + + assert isinstance(actual, da.Array) + assert_array_equal(actual, expected) + + +@pytest.mark.parametrize('dim_num', [1, 2]) +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +@pytest.mark.parametrize('aggdim', [None, 'x']) +def test_min_count(dim_num, dtype, dask, func, aggdim): + if dask and not has_dask: + pytest.skip('requires dask') + + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + min_count = 3 + + actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count) + + if LooseVersion(pd.__version__) >= LooseVersion('0.22.0'): + # min_count is only implemented in pandas >= 0.22 + expected = series_reduce(da, func, skipna=True, dim=aggdim, + min_count=min_count) + assert_allclose(actual, expected) + + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_min_count_dataset(func): + da = construct_dataarray(2, dtype=float, contains_nan=True,
dask=False) + ds = Dataset({'var1': da}, coords={'scalar': 0}) + actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1'] + expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize('dask', [False, True]) +@pytest.mark.parametrize('func', ['sum', 'prod']) +def test_multiple_dims(dtype, dask, func): + if dask and not has_dask: + pytest.skip('requires dask') + da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) + + actual = getattr(da, func)(('x', 'y')) + expected = getattr(getattr(da, func)('x'), func)('y') + assert_allclose(actual, expected) + + +def test_docs(): + # with min_count + actual = DataArray.sum.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `sum` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `sum`. + axis : int or sequence of int, optional + Axis(es) over which to apply `sum`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `sum` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default None + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will + be NA. New in version 0.10.8: Added with the default being None. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `sum` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `sum` applied to its data and the + indicated dimension(s) removed. + """) + assert actual == expected + + # without min_count + actual = DataArray.std.__doc__ + expected = dedent("""\ + Reduce this DataArray's data by applying `std` along some dimension(s). + + Parameters + ---------- + dim : str or sequence of str, optional + Dimension(s) over which to apply `std`. + axis : int or sequence of int, optional + Axis(es) over which to apply `std`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `std` is calculated over axes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to the appropriate array + function for calculating `std` on this object's data. + + Returns + ------- + reduced : DataArray + New DataArray object with `std` applied to its data and the + indicated dimension(s) removed. 
+ """) + assert actual == expected diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 24b710ae223..ffefa78aa34 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -4,7 +4,7 @@ import xarray as xr -from . import TestCase, raises_regex +from . import raises_regex try: import cPickle as pickle @@ -21,7 +21,7 @@ def __init__(self, xarray_obj): self.obj = xarray_obj -class TestAccessor(TestCase): +class TestAccessor(object): def test_register(self): @xr.register_dataset_accessor('demo') diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 8a1003f1ced..024c669bed9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -7,10 +7,10 @@ from xarray.core import formatting from xarray.core.pycompat import PY3 -from . import TestCase, raises_regex +from . import raises_regex -class TestFormatting(TestCase): +class TestFormatting(object): def test_get_indexer_at_least_n_items(self): cases = [ @@ -45,7 +45,7 @@ def test_first_n_items(self): for n in [3, 10, 13, 100, 200]: actual = formatting.first_n_items(array, n) expected = array.flat[:n] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) @@ -55,7 +55,7 @@ def test_last_n_items(self): for n in [3, 10, 13, 100, 200]: actual = formatting.last_n_items(array, n) expected = array.flat[-n:] - self.assertItemsEqual(expected, actual) + assert (expected == actual).all() with raises_regex(ValueError, 'at least one item'): formatting.first_n_items(array, 0) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6dd14f5d6ad..8ace55be66b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -5,9 +5,10 @@ import pytest import xarray as xr -from . import assert_identical from xarray.core.groupby import _consolidate_slices +from . import assert_identical + def test_consolidate_slices(): diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 0d1045d35c0..701eefcb462 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -10,13 +10,12 @@ from xarray.core import indexing, nputils from xarray.core.pycompat import native_int_types -from . import ( - IndexerMaker, ReturnItem, TestCase, assert_array_equal, raises_regex) +from . 
import IndexerMaker, ReturnItem, assert_array_equal, raises_regex B = IndexerMaker(indexing.BasicIndexer) -class TestIndexers(TestCase): +class TestIndexers(object): def set_to_zero(self, x, i): x = x.copy() x[i] = 0 @@ -25,7 +24,7 @@ def set_to_zero(self, x, i): def test_expanded_indexer(self): x = np.random.randn(10, 11, 12, 13, 14) y = np.arange(5) - I = ReturnItem() # noqa: E741 # allow ambiguous name + I = ReturnItem() # noqa for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0], I[..., 0, :], I[y], I[y, y], I[..., y, y], I[..., 0, 1, 2, 3, 4]]: @@ -133,7 +132,7 @@ def test_indexer(data, x, expected_pos, expected_idx=None): pd.MultiIndex.from_product([[1, 2], [-1, -2]])) -class TestLazyArray(TestCase): +class TestLazyArray(object): def test_slice_slice(self): I = ReturnItem() # noqa: E741 # allow ambiguous name for size in [100, 99]: @@ -248,7 +247,7 @@ def check_indexing(v_eager, v_lazy, indexers): check_indexing(v_eager, v_lazy, indexers) -class TestCopyOnWriteArray(TestCase): +class TestCopyOnWriteArray(object): def test_setitem(self): original = np.arange(10) wrapped = indexing.CopyOnWriteArray(original) @@ -272,7 +271,7 @@ def test_index_scalar(self): assert np.array(x[B[0]][B[()]]) == 'foo' -class TestMemoryCachedArray(TestCase): +class TestMemoryCachedArray(object): def test_wrapper(self): original = indexing.LazilyOuterIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) @@ -385,8 +384,9 @@ def test_vectorized_indexer(): np.arange(5, dtype=np.int64))) -class Test_vectorized_indexer(TestCase): - def setUp(self): +class Test_vectorized_indexer(object): + @pytest.fixture(autouse=True) + def setup(self): self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13)) self.indexers = [np.array([[0, 3, 2], ]), np.array([[0, 3, 3], [4, 6, 7]]), diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 4a8f4e6eedf..624879cce1f 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -5,8 +5,11 @@ import pytest import xarray as xr -from xarray.tests import assert_allclose, assert_equal, requires_scipy +from xarray.tests import ( + assert_allclose, assert_equal, requires_cftime, requires_scipy) + from . 
import has_dask, has_scipy +from ..coding.cftimeindex import _parse_array_of_cftime_strings from .test_dataset import create_test_data try: @@ -490,3 +493,83 @@ def test_datetime_single_string(): expected = xr.DataArray(0.5) assert_allclose(actual.drop('time'), expected) + + +@requires_cftime +@requires_scipy +def test_cftime(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D') + actual = da.interp(time=times_new) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_type_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D', + calendar='noleap') + with pytest.raises(TypeError): + da.interp(time=times_new) + + +@requires_cftime +@requires_scipy +def test_cftime_list_of_strings(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00'] + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array], + dims=['time']) + + assert_allclose(actual, expected) + + +@requires_cftime +@requires_scipy +def test_cftime_single_string(): + from cftime import DatetimeProlepticGregorian + + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + times_new = '2000-01-01T12:00' + actual = da.interp(time=times_new) + + times_new_array = _parse_array_of_cftime_strings( + np.array(times_new), DatetimeProlepticGregorian) + expected = xr.DataArray(0.5, coords={'time': times_new_array}) + + assert_allclose(actual, expected) + + +@requires_scipy +def test_datetime_to_non_datetime_error(): + da = xr.DataArray(np.arange(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + with pytest.raises(TypeError): + da.interp(time=0.5) + + +@requires_cftime +@requires_scipy +def test_cftime_to_non_cftime_error(): + times = xr.cftime_range('2000', periods=24, freq='D') + da = xr.DataArray(np.arange(24), coords=[times], dims='time') + + with pytest.raises(TypeError): + da.interp(time=0.5) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 4d89be8ce55..300c490cff6 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -6,11 +6,11 @@ import xarray as xr from xarray.core import merge -from . import TestCase, raises_regex +from . 
import raises_regex from .test_dataset import create_test_data -class TestMergeInternals(TestCase): +class TestMergeInternals(object): def test_broadcast_dimension_size(self): actual = merge.broadcast_dimension_size( [xr.Variable('x', [1]), xr.Variable('y', [2, 1])]) @@ -25,7 +25,7 @@ def test_broadcast_dimension_size(self): [xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])]) -class TestMergeFunction(TestCase): +class TestMergeFunction(object): def test_merge_arrays(self): data = create_test_data() actual = xr.merge([data.var1, data.var2]) @@ -130,7 +130,7 @@ def test_merge_no_conflicts_broadcast(self): assert expected.identical(actual) -class TestMergeMethod(TestCase): +class TestMergeMethod(object): def test_merge(self): data = create_test_data() @@ -195,7 +195,7 @@ def test_merge_compat(self): with pytest.raises(xr.MergeError): ds1.merge(ds2, compat='identical') - with raises_regex(ValueError, 'compat=\S+ invalid'): + with raises_regex(ValueError, 'compat=.* invalid'): ds1.merge(ds2, compat='foobar') def test_merge_auto_align(self): diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 5c7e384c789..47224e55473 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -93,14 +93,14 @@ def test_interpolate_pd_compat(): @requires_scipy -def test_scipy_methods_function(): - for method in ['barycentric', 'krog', 'pchip', 'spline', 'akima']: - kwargs = {} - # Note: Pandas does some wacky things with these methods and the full - # integration tests wont work. - da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim='time', **kwargs) - assert (da.count('time') <= actual.count('time')).all() +@pytest.mark.parametrize('method', ['barycentric', 'krog', + 'pchip', 'spline', 'akima']) +def test_scipy_methods_function(method): + # Note: Pandas does some wacky things with these methods and the full + # integration tests won't work.
+ da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) + actual = da.interpolate_na(method=method, dim='time') + assert (da.count('time') <= actual.count('time')).all() @requires_scipy diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index aed96f1acb6..4441375a1b1 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -4,6 +4,7 @@ import xarray from xarray.core.options import OPTIONS +from xarray.backends.file_manager import FILE_CACHE def test_invalid_option_raises(): @@ -11,6 +12,38 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + def test_nested_options(): original = OPTIONS['display_width'] with xarray.set_options(display_width=1): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 51450c2838f..0af67532416 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -7,6 +7,7 @@ import pandas as pd import pytest +import xarray as xr import xarray.plot as xplt from xarray import DataArray from xarray.coding.times import _import_cftime @@ -16,9 +17,8 @@ import_seaborn, label_from_attrs) from . 
import ( - TestCase, assert_array_equal, assert_equal, raises_regex, - requires_matplotlib, requires_matplotlib2, requires_seaborn, - requires_cftime) + assert_array_equal, assert_equal, raises_regex, requires_cftime, + requires_matplotlib, requires_matplotlib2, requires_seaborn) # import mpl and change the backend before other mpl imports try: @@ -64,8 +64,10 @@ def easy_array(shape, start=0, stop=1): @requires_matplotlib -class PlotTestCase(TestCase): - def tearDown(self): +class PlotTestCase(object): + @pytest.fixture(autouse=True) + def setup(self): + yield # Remove all matplotlib figures plt.close('all') @@ -87,7 +89,8 @@ def contourf_called(self, plotmethod): class TestPlot(PlotTestCase): - def setUp(self): + @pytest.fixture(autouse=True) + def setup_array(self): self.darray = DataArray(easy_array((2, 3, 4))) def test_label_from_attrs(self): @@ -159,8 +162,8 @@ def test_2d_line_accepts_legend_kw(self): self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True) assert plt.gca().get_legend() # check whether legend title is set - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_line_accepts_x_kw(self): self.darray[:, :, 0].plot.line(x='dim_0') @@ -171,12 +174,12 @@ def test_2d_line_accepts_x_kw(self): def test_2d_line_accepts_hue_kw(self): self.darray[:, :, 0].plot.line(hue='dim_0') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_0' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_0') plt.cla() self.darray[:, :, 0].plot.line(hue='dim_1') - assert plt.gca().get_legend().get_title().get_text() \ - == 'dim_1' + assert (plt.gca().get_legend().get_title().get_text() + == 'dim_1') def test_2d_before_squeeze(self): a = DataArray(easy_array((1, 5))) @@ -267,6 +270,7 @@ def test_datetime_dimension(self): assert ax.has_data() @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -328,6 +332,7 @@ def test_plot_size(self): self.darray.plot(aspect=1) @pytest.mark.slow + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -346,6 +351,7 @@ def test_coord_with_interval(self): class TestPlot1D(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = [0, 1.1, 0, 2] self.darray = DataArray( @@ -358,7 +364,7 @@ def test_xlabel_is_index_name(self): def test_no_label_name_on_x_axis(self): self.darray.plot(y='period') - self.assertEqual('', plt.gca().get_xlabel()) + assert '' == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self): self.darray.plot() @@ -431,6 +437,7 @@ def test_coord_with_interval_step(self): class TestPlotHistogram(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(easy_array((2, 3, 4))) @@ -470,7 +477,8 @@ def test_hist_coord_with_interval(self): @requires_matplotlib -class TestDetermineCmapParams(TestCase): +class TestDetermineCmapParams(object): + @pytest.fixture(autouse=True) def setUp(self): self.data = np.linspace(0, 1, num=100) @@ -491,6 +499,21 @@ def test_center(self): assert cmap_params['levels'] is None assert cmap_params['norm'] is None + def test_cmap_sequential_option(self): + with xr.set_options(cmap_sequential='magma'): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == 'magma' + + def 
test_cmap_sequential_explicit_option(self): + with xr.set_options(cmap_sequential=mpl.cm.magma): + cmap_params = _determine_cmap_params(self.data) + assert cmap_params['cmap'] == mpl.cm.magma + + def test_cmap_divergent_option(self): + with xr.set_options(cmap_divergent='magma'): + cmap_params = _determine_cmap_params(self.data, center=0.5) + assert cmap_params['cmap'] == 'magma' + def test_nan_inf_are_ignored(self): cmap_params1 = _determine_cmap_params(self.data) data = self.data @@ -626,9 +649,30 @@ def test_divergentcontrol(self): assert cmap_params['vmax'] == 0.6 assert cmap_params['cmap'] == "viridis" + def test_norm_sets_vmin_vmax(self): + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend in zip([mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None)], + ['neither', 'both', 'max', 'min']): + + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm) + + assert cmap_params['vmin'] == test_min + assert cmap_params['vmax'] == test_max + assert cmap_params['extend'] == extend + assert cmap_params['norm'] == norm + @requires_matplotlib -class TestDiscreteColorMap(TestCase): +class TestDiscreteColorMap(object): + @pytest.fixture(autouse=True) def setUp(self): x = np.arange(start=0, stop=10, step=2) y = np.arange(start=9, stop=-7, step=-3) @@ -662,10 +706,10 @@ def test_build_discrete_cmap(self): @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both', - [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), ('min', - [2, 5, 10, 15])]: + for extend, levels in [('max', [-1, 2, 4, 8, 10]), + ('both', [2, 5, 10, 11]), + ('neither', [0, 5, 10, 15]), + ('min', [2, 5, 10, 15])]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) @@ -679,10 +723,10 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, - None), ('neither', 7, None, 20), - ('both', 7, 4, 8), ('min', 10, 4, - 15)]: + for extend, levels, vmin, vmax in [('neither', 7, None, None), + ('neither', 7, None, 20), + ('both', 7, 4, 8), + ('min', 10, 4, 15)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( levels=levels, vmin=vmin, vmax=vmax) @@ -708,8 +752,13 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) + def test_discrete_colormap_provided_boundary_norm(self): + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(primitive.levels, norm.boundaries) + -class Common2dMixin: +class Common2dMixin(object): """ Common tests for 2d plotting go here. @@ -717,6 +766,7 @@ class Common2dMixin: Should have the same name as the method. 
""" + @pytest.fixture(autouse=True) def setUp(self): da = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x'], @@ -765,6 +815,24 @@ def test_nonnumeric_index_raises_typeerror(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod) + def test_xyincrease_defaults(self): + + # With default settings the axis must be ordered regardless + # of the coords order. + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3], + [1, 2]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + # Inverted coords + self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1], + [2, 1]])) + bounds = plt.gca().get_ylim() + assert bounds[0] < bounds[1] + bounds = plt.gca().get_xlim() + assert bounds[0] < bounds[1] + def test_xyincrease_false_changes_axes(self): self.plotmethod(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() @@ -796,10 +864,13 @@ def test_plot_nans(self): clim2 = self.plotfunc(x2).get_clim() assert clim1 == clim2 + @pytest.mark.filterwarnings('ignore::UserWarning') + @pytest.mark.filterwarnings('ignore:invalid value encountered') def test_can_plot_all_nans(self): # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) + @pytest.mark.filterwarnings('ignore: Attempting to set') def test_can_plot_axis_size_one(self): if self.plotfunc.__name__ not in ('contour', 'contourf'): self.plotfunc(DataArray(np.ones((1, 1)))) @@ -991,6 +1062,7 @@ def test_2d_function_and_method_signature_same(self): del func_sig['darray'] assert func_sig == method_sig + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid(self): a = easy_array((10, 15, 4)) d = DataArray(a, dims=['y', 'x', 'z']) @@ -1022,6 +1094,7 @@ def test_convenient_facetgrid(self): else: assert '' == ax.get_xlabel() + @pytest.mark.filterwarnings('ignore:tight_layout cannot') def test_convenient_facetgrid_4d(self): a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) @@ -1031,6 +1104,19 @@ def test_convenient_facetgrid_4d(self): for ax in g.axes.flat: assert ax.has_data() + @pytest.mark.filterwarnings('ignore:This figure includes') + def test_facetgrid_map_only_appends_mappables(self): + a = easy_array((10, 15, 2, 3)) + d = DataArray(a, dims=['y', 'x', 'columns', 'rows']) + g = self.plotfunc(d, x='x', y='y', col='columns', row='rows') + + expected = g._mappables + + g.map(lambda: plt.plot(1, 1)) + actual = g._mappables + + assert expected == actual + def test_facetgrid_cmap(self): # Regression test for GH592 data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12)) @@ -1051,6 +1137,15 @@ def test_2d_coord_with_interval(self): for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: getattr(gp.plot, kind)() + def test_colormap_error_norm_and_vmin_vmax(self): + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmax=2) + @pytest.mark.slow class TestContourf(Common2dMixin, PlotTestCase): @@ -1113,23 +1208,23 @@ def test_colors(self): def _color_as_tuple(c): return tuple(c[:3]) + # with single color, we don't want rgb array artist = self.plotmethod(colors='k') - assert _color_as_tuple(artist.cmap.colors[0]) == \ - (0.0, 0.0, 0.0) + assert artist.cmap.colors[0] == 'k' artist = self.plotmethod(colors=['k', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (0.0, 0.0, 1.0) + assert 
(_color_as_tuple(artist.cmap.colors[1]) == + (0.0, 0.0, 1.0)) artist = self.darray.plot.contour( levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b']) - assert _color_as_tuple(artist.cmap.colors[1]) == \ - (1.0, 0.0, 0.0) - assert _color_as_tuple(artist.cmap.colors[2]) == \ - (1.0, 1.0, 1.0) + assert (_color_as_tuple(artist.cmap.colors[1]) == + (1.0, 0.0, 0.0)) + assert (_color_as_tuple(artist.cmap.colors[2]) == + (1.0, 1.0, 1.0)) # the last color is now under "over" - assert _color_as_tuple(artist.cmap._rgba_over) == \ - (0.0, 0.0, 1.0) + assert (_color_as_tuple(artist.cmap._rgba_over) == + (0.0, 0.0, 1.0)) def test_cmap_and_color_both(self): with pytest.raises(ValueError): @@ -1306,13 +1401,26 @@ def test_imshow_rgb_values_in_valid_range(self): assert out.dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha + @pytest.mark.filterwarnings('ignore:Several dimensions of this array') def test_regression_rgb_imshow_dim_size_one(self): # Regression: https://github.com/pydata/xarray/issues/1966 da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) da.plot.imshow() + def test_origin_overrides_xyincrease(self): + da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) + da.plot.imshow(origin='upper') + assert plt.xlim()[0] < 0 + assert plt.ylim()[1] < 0 + + plt.clf() + da.plot.imshow(origin='lower') + assert plt.xlim()[0] < 0 + assert plt.ylim()[0] < 0 + class TestFacetGrid(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): d = easy_array((10, 15, 3)) self.darray = DataArray( @@ -1484,7 +1592,9 @@ def test_num_ticks(self): @pytest.mark.slow def test_map(self): + assert self.g._finalized is False self.g.map(plt.contourf, 'x', 'y', Ellipsis) + assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow @@ -1538,7 +1648,9 @@ def test_facetgrid_polar(self): sharey=False) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetGrid4d(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): a = easy_array((10, 15, 3, 2)) darray = DataArray(a, dims=['y', 'x', 'col', 'row']) @@ -1565,7 +1677,9 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetedLinePlots(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): self.darray = DataArray(np.random.randn(10, 6, 3, 4), dims=['hue', 'x', 'col', 'row'], @@ -1646,6 +1760,7 @@ def test_wrong_num_of_dimensions(self): class TestDatetimePlot(PlotTestCase): + @pytest.fixture(autouse=True) def setUp(self): ''' Create a DataArray with a time-axis that contains datetime objects. diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index d550a85e8ce..083ec5ee72f 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -2,15 +2,17 @@ import os +import pytest + from xarray import DataArray, tutorial from xarray.core.pycompat import suppress -from . import TestCase, assert_identical, network +from . import assert_identical, network @network -class TestLoadDataset(TestCase): - +class TestLoadDataset(object): + @pytest.fixture(autouse=True) def setUp(self): self.testfile = 'tiny' self.testfilepath = os.path.expanduser(os.sep.join( diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 195bb36e36e..6941efb1c6e 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -8,9 +8,9 @@ import xarray as xr import xarray.ufuncs as xu -from . 
import ( - assert_array_equal, assert_identical as assert_identical_, mock, - raises_regex, requires_np113) +from . import assert_array_equal +from . import assert_identical as assert_identical_ +from . import mock, raises_regex, requires_np113 def assert_identical(a, b): diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index ed8045b78e4..34f401dd243 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -6,19 +6,21 @@ import pandas as pd import pytest +import xarray as xr from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict from xarray.core.utils import either_dict_or_kwargs +from xarray.testing import assert_identical from . import ( - TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4, + assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime, requires_dask) from .test_coding_times import _all_cftime_date_types -class TestAlias(TestCase): +class TestAlias(object): def test(self): def new_method(): pass @@ -96,7 +98,7 @@ def test_multiindex_from_product_levels_non_unique(): np.testing.assert_array_equal(result.levels[1], [1, 2]) -class TestArrayEquiv(TestCase): +class TestArrayEquiv(object): def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional # object arrays @@ -106,8 +108,9 @@ def test_0d(self): assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object)) -class TestDictionaries(TestCase): - def setUp(self): +class TestDictionaries(object): + @pytest.fixture(autouse=True) + def setup(self): self.x = {'a': 'A', 'b': 'B'} self.y = {'c': 'C', 'b': 'B'} self.z = {'a': 'Z'} @@ -174,7 +177,7 @@ def test_frozen(self): def test_sorted_keys_dict(self): x = {'a': 1, 'b': 2, 'c': 3} y = utils.SortedKeysDict(x) - self.assertItemsEqual(y, ['a', 'b', 'c']) + assert list(y) == ['a', 'b', 'c'] assert repr(utils.SortedKeysDict()) == \ "SortedKeysDict({})" @@ -189,7 +192,7 @@ def test_chain_map(self): m['x'] = 100 assert m['x'] == 100 assert m.maps[0]['x'] == 100 - self.assertItemsEqual(['x', 'y', 'z'], m) + assert set(m) == {'x', 'y', 'z'} def test_repr_object(): @@ -197,7 +200,7 @@ def test_repr_object(): assert repr(obj) == 'foo' -class Test_is_uniform_and_sorted(TestCase): +class Test_is_uniform_and_sorted(object): def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -218,7 +221,7 @@ def test_relative_tolerance(self): assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1) -class Test_hashable(TestCase): +class Test_hashable(object): def test_hashable(self): for v in [False, 1, (2, ), (3, 4), 'four']: @@ -263,3 +266,42 @@ def test_either_dict_or_kwargs(): with pytest.raises(ValueError, match=r'foo'): result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo') + + +def test_datetime_to_numeric_datetime64(): + times = pd.date_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), 
coords=da.coords).astype(dtype) + assert_identical(result, expected) + + +@requires_cftime +def test_datetime_to_numeric_cftime(): + times = xr.cftime_range('2000', periods=5, freq='7D') + da = xr.DataArray(times, coords=[times], dims=['time']) + result = utils.datetime_to_numeric(da, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords) + assert_identical(result, expected) + + offset = da.isel(time=1) + result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h') + expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords) + assert_identical(result, expected) + + dtype = np.float32 + result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype) + expected = 24 * xr.DataArray( + np.arange(0, 35, 7), coords=da.coords).astype(dtype) + assert_identical(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index cdb578aff6c..52289a15d72 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,12 +1,11 @@ from __future__ import absolute_import, division, print_function -from collections import namedtuple +import warnings from copy import copy, deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion from textwrap import dedent -import warnings import numpy as np import pandas as pd @@ -26,11 +25,11 @@ from xarray.tests import requires_bottleneck from . import ( - TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_dask, source_ndarray) + assert_allclose, assert_array_equal, assert_equal, assert_identical, + raises_regex, requires_dask, source_ndarray) -class VariableSubclassTestCases(object): +class VariableSubclassobjects(object): def test_properties(self): data = 0.5 * np.arange(10) v = self.cls(['time'], data, {'foo': 'bar'}) @@ -480,20 +479,20 @@ def test_concat_mixed_dtypes(self): assert_identical(expected, actual) assert actual.dtype == object - def test_copy(self): + @pytest.mark.parametrize('deep', [True, False]) + def test_copy(self, deep): v = self.cls('x', 0.5 * np.arange(10), {'foo': 'bar'}) - for deep in [True, False]: - w = v.copy(deep=deep) - assert type(v) is type(w) - assert_identical(v, w) - assert v.dtype == w.dtype - if self.cls is Variable: - if deep: - assert source_ndarray(v.values) is not \ - source_ndarray(w.values) - else: - assert source_ndarray(v.values) is \ - source_ndarray(w.values) + w = v.copy(deep=deep) + assert type(v) is type(w) + assert_identical(v, w) + assert v.dtype == w.dtype + if self.cls is Variable: + if deep: + assert (source_ndarray(v.values) is not + source_ndarray(w.values)) + else: + assert (source_ndarray(v.values) is + source_ndarray(w.values)) assert_identical(v, copy(v)) def test_copy_index(self): @@ -506,6 +505,34 @@ def test_copy_index(self): assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) + def test_copy_with_data(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = np.array([[2.5, 5.0], [7.1, 43]]) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_with_data_errors(self): + orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'}) + new_data = [2.5, 5.0] + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + + def test_copy_index_with_data(self): + orig = IndexVariable('x', np.arange(5)) + new_data 
= np.arange(5, 10) + actual = orig.copy(data=new_data) + expected = orig.copy() + expected.data = new_data + assert_identical(expected, actual) + + def test_copy_index_with_data_errors(self): + orig = IndexVariable('x', np.arange(5)) + new_data = np.arange(5, 20) + with raises_regex(ValueError, 'must match shape of object'): + orig.copy(data=new_data) + def test_real_and_imag(self): v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'}) expected_re = self.cls('x', np.arange(3), {'foo': 'bar'}) @@ -787,10 +814,11 @@ def test_rolling_window(self): v_loaded[0] = 1.0 -class TestVariable(TestCase, VariableSubclassTestCases): +class TestVariable(VariableSubclassobjects): cls = staticmethod(Variable) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) def test_data_and_values(self): @@ -938,21 +966,6 @@ def test_as_variable(self): assert not isinstance(ds['x'], Variable) assert isinstance(as_variable(ds['x']), Variable) - FakeVariable = namedtuple('FakeVariable', 'values dims') - fake_xarray = FakeVariable(expected.values, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', 'data dims') - fake_xarray = FakeVariable(expected.data, expected.dims) - assert_identical(expected, as_variable(fake_xarray)) - - FakeVariable = namedtuple('FakeVariable', - 'data values dims attrs encoding') - fake_xarray = FakeVariable(expected_extra.data, expected_extra.values, - expected_extra.dims, expected_extra.attrs, - expected_extra.encoding) - assert_identical(expected_extra, as_variable(fake_xarray)) - xarray_tuple = (expected_extra.dims, expected_extra.values, expected_extra.attrs, expected_extra.encoding) assert_identical(expected_extra, as_variable(xarray_tuple)) @@ -1503,8 +1516,8 @@ def test_reduce_funcs(self): assert_identical(v.all(dim='x'), Variable([], False)) v = Variable('t', pd.date_range('2000-01-01', periods=3)) - with pytest.raises(NotImplementedError): - v.argmax(skipna=True) + assert v.argmax(skipna=True) == 2 + assert_identical( v.max(), Variable([], pd.Timestamp('2000-01-03'))) @@ -1639,7 +1652,7 @@ def assert_assigned_2d(array, key_x, key_y, values): @requires_dask -class TestVariableWithDask(TestCase, VariableSubclassTestCases): +class TestVariableWithDask(VariableSubclassobjects): cls = staticmethod(lambda *args: Variable(*args).chunk()) @pytest.mark.xfail @@ -1667,7 +1680,7 @@ def test_getitem_1d_fancy(self): def test_equals_all_dtypes(self): import dask - if '0.18.2' <= LooseVersion(dask.__version__) < '0.18.3': + if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1': pytest.xfail('https://github.com/pydata/xarray/issues/2318') super(TestVariableWithDask, self).test_equals_all_dtypes() @@ -1679,7 +1692,7 @@ def test_getitem_with_mask_nd_indexer(self): self.cls(('x', 'y'), [[0, -1], [-1, 2]])) -class TestIndexVariable(TestCase, VariableSubclassTestCases): +class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): @@ -1792,7 +1805,7 @@ def test_rolling_window(self): super(TestIndexVariable, self).test_rolling_window() -class TestAsCompatibleData(TestCase): +class TestAsCompatibleData(object): def test_unchanged_types(self): types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) for t in types: @@ -1933,9 +1946,10 @@ def test_raise_no_warning_for_nan_in_binary_ops(): assert len(record) == 0 -class TestBackendIndexing(TestCase): +class TestBackendIndexing(object): """ Make sure all the 
array wrappers can be indexed. """ + @pytest.fixture(autouse=True) def setUp(self): self.d = np.random.random((10, 3)).astype(np.float64)
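The `copy(data=...)` tests in the test_variable.py hunks above pin down the new `data` argument: dimensions and attributes are kept from the original, the values are swapped out, and a shape mismatch raises `ValueError`. A minimal sketch of that behaviour, assuming an xarray version that includes this argument; the values are illustrative only:

```python
import numpy as np
import xarray as xr

# copy(data=...) keeps dims and attrs but replaces the underlying values.
orig = xr.Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], attrs={'foo': 'bar'})
copied = orig.copy(data=np.array([[2.5, 5.0], [7.1, 43.0]]))

assert copied.dims == orig.dims
assert copied.attrs == {'foo': 'bar'}

# New data with the wrong shape is rejected, as the *_errors tests assert.
try:
    orig.copy(data=[2.5, 5.0])
except ValueError as err:
    print('rejected:', err)
```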