diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 5e9aa06f507..d1c79953a9b 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,4 +1,3 @@
- [ ] Closes #xxxx (remove if there is no corresponding issue, which should only be the case for minor changes)
- [ ] Tests added (for all bug fixes or enhancements)
- - [ ] Tests passed (for all non-documentation changes)
- [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later)
diff --git a/.pep8speaks.yml b/.pep8speaks.yml
new file mode 100644
index 00000000000..cd610907007
--- /dev/null
+++ b/.pep8speaks.yml
@@ -0,0 +1,11 @@
+# File : .pep8speaks.yml
+
+scanner:
+ diff_only: True # If True, only errors introduced by the patch are shown
+
+pycodestyle:
+ max-line-length: 79
+ ignore: # Errors and warnings to ignore
+ - E402, # module level import not at top of file
+ - E731, # do not assign a lambda expression, use a def
+ - W503 # line break before binary operator
diff --git a/.stickler.yml b/.stickler.yml
deleted file mode 100644
index 79d8b7fb717..00000000000
--- a/.stickler.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-linters:
- flake8:
- max-line-length: 79
- fixer: false
- ignore: I002
- # stickler doesn't support 'exclude' for flake8 properly, so we disable it
- # below with files.ignore:
- # https://github.com/markstory/lint-review/issues/184
-files:
- ignore:
- - doc/**/*.py
diff --git a/.travis.yml b/.travis.yml
index 951b151d829..defb37ec8aa 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,5 +1,5 @@
# Based on http://conda.pydata.org/docs/travis.html
-language: python
+language: minimal
sudo: false # use container based build
notifications:
email: false
@@ -10,72 +10,48 @@ branches:
matrix:
fast_finish: true
include:
- - python: 2.7
- env: CONDA_ENV=py27-min
- - python: 2.7
- env: CONDA_ENV=py27-cdat+iris+pynio
- - python: 3.5
- env: CONDA_ENV=py35
- - python: 3.6
- env: CONDA_ENV=py36
- - python: 3.6
- env:
+ - env: CONDA_ENV=py27-min
+ - env: CONDA_ENV=py27-cdat+iris+pynio
+ - env: CONDA_ENV=py35
+ - env: CONDA_ENV=py36
+ - env: CONDA_ENV=py37
+ - env:
- CONDA_ENV=py36
- EXTRA_FLAGS="--run-flaky --run-network-tests"
- - python: 3.6
- env: CONDA_ENV=py36-netcdf4-dev
+ - env: CONDA_ENV=py36-netcdf4-dev
addons:
apt_packages:
- libhdf5-serial-dev
- netcdf-bin
- libnetcdf-dev
- - python: 3.6
- env: CONDA_ENV=py36-dask-dev
- - python: 3.6
- env: CONDA_ENV=py36-pandas-dev
- - python: 3.6
- env: CONDA_ENV=py36-bottleneck-dev
- - python: 3.6
- env: CONDA_ENV=py36-condaforge-rc
- - python: 3.6
- env: CONDA_ENV=py36-pynio-dev
- - python: 3.6
- env: CONDA_ENV=py36-rasterio-0.36
- - python: 3.6
- env: CONDA_ENV=py36-zarr-dev
- - python: 3.5
- env: CONDA_ENV=docs
- - python: 3.6
- env: CONDA_ENV=py36-hypothesis
+ - env: CONDA_ENV=py36-dask-dev
+ - env: CONDA_ENV=py36-pandas-dev
+ - env: CONDA_ENV=py36-bottleneck-dev
+ - env: CONDA_ENV=py36-condaforge-rc
+ - env: CONDA_ENV=py36-pynio-dev
+ - env: CONDA_ENV=py36-rasterio-0.36
+ - env: CONDA_ENV=py36-zarr-dev
+ - env: CONDA_ENV=docs
+ - env: CONDA_ENV=py36-hypothesis
+
allow_failures:
- - python: 3.6
- env:
+ - env:
- CONDA_ENV=py36
- EXTRA_FLAGS="--run-flaky --run-network-tests"
- - python: 3.6
- env: CONDA_ENV=py36-netcdf4-dev
+ - env: CONDA_ENV=py36-netcdf4-dev
addons:
apt_packages:
- libhdf5-serial-dev
- netcdf-bin
- libnetcdf-dev
- - python: 3.6
- env: CONDA_ENV=py36-pandas-dev
- - python: 3.6
- env: CONDA_ENV=py36-bottleneck-dev
- - python: 3.6
- env: CONDA_ENV=py36-condaforge-rc
- - python: 3.6
- env: CONDA_ENV=py36-pynio-dev
- - python: 3.6
- env: CONDA_ENV=py36-zarr-dev
+ - env: CONDA_ENV=py36-pandas-dev
+ - env: CONDA_ENV=py36-bottleneck-dev
+ - env: CONDA_ENV=py36-condaforge-rc
+ - env: CONDA_ENV=py36-pynio-dev
+ - env: CONDA_ENV=py36-zarr-dev
before_install:
- - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
- wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh;
- else
- wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh;
- fi
+ - wget http://repo.continuum.io/miniconda/Miniconda3-3.16.0-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- hash -r
@@ -95,9 +71,9 @@ install:
- python xarray/util/print_versions.py
script:
- # TODO: restore this check once the upstream pandas issue is fixed:
- # https://github.com/pandas-dev/pandas/issues/21071
- # - python -OO -c "import xarray"
+ - which python
+ - python --version
+ - python -OO -c "import xarray"
- if [[ "$CONDA_ENV" == "docs" ]]; then
conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc;
sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html;
diff --git a/README.rst b/README.rst
index 94beea1dba4..0ac71d33954 100644
--- a/README.rst
+++ b/README.rst
@@ -15,6 +15,8 @@ xarray: N-D labeled arrays and datasets
:target: https://zenodo.org/badge/latestdoi/13221727
.. image:: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat
:target: http://pandas.pydata.org/speed/xarray/
+.. image:: https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A
+ :target: http://numfocus.org
**xarray** (formerly **xray**) is an open source project and Python package that aims to bring the
labeled data power of pandas_ to the physical sciences, by providing
@@ -103,20 +105,36 @@ Get in touch
.. _mailing list: https://groups.google.com/forum/#!forum/xarray
.. _on GitHub: http://github.com/pydata/xarray
+NumFOCUS
+--------
+
+.. image:: https://numfocus.org/wp-content/uploads/2017/07/NumFocus_LRG.png
+ :scale: 25 %
+ :target: https://numfocus.org/
+
+Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated
+to supporting the open source scientific computing community. If you like
+Xarray and want to support our mission, please consider making a donation_
+to support our efforts.
+
+.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU=
+
History
-------
xarray is an evolution of an internal tool developed at `The Climate
Corporation`__. It was originally written by Climate Corp researchers Stephan
Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in
-May 2014. The project was renamed from "xray" in January 2016.
+May 2014. The project was renamed from "xray" in January 2016. Xarray became a
+fiscally sponsored project of NumFOCUS_ in August 2018.
__ http://climate.com/
+.. _NumFOCUS: https://numfocus.org
License
-------
-Copyright 2014-2017, xarray Developers
+Copyright 2014-2018, xarray Developers
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json
index b5953436387..e3933b400e6 100644
--- a/asv_bench/asv.conf.json
+++ b/asv_bench/asv.conf.json
@@ -64,6 +64,7 @@
"scipy": [""],
"bottleneck": ["", null],
"dask": [""],
+ "distributed": [""],
},
diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py
index 54ed9ac9fa2..da18d541a16 100644
--- a/asv_bench/benchmarks/dataset_io.py
+++ b/asv_bench/benchmarks/dataset_io.py
@@ -1,11 +1,13 @@
from __future__ import absolute_import, division, print_function
+import os
+
import numpy as np
import pandas as pd
import xarray as xr
-from . import randn, randint, requires_dask
+from . import randint, randn, requires_dask
try:
import dask
@@ -14,6 +16,9 @@
pass
+os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
+
+
class IOSingleNetCDF(object):
"""
A few examples that benchmark reading/writing a single netCDF file with
@@ -405,3 +410,39 @@ def time_open_dataset_scipy_with_time_chunks(self):
with dask.set_options(get=dask.multiprocessing.get):
xr.open_mfdataset(self.filenames_list, engine='scipy',
chunks=self.time_chunks)
+
+
+def create_delayed_write():
+ import dask.array as da
+ vals = da.random.random(300, chunks=(1,))
+ ds = xr.Dataset({'vals': (['a'], vals)})
+ return ds.to_netcdf('file.nc', engine='netcdf4', compute=False)
+
+
+class IOWriteNetCDFDask(object):
+ timeout = 60
+ repeat = 1
+ number = 5
+
+ def setup(self):
+ requires_dask()
+ self.write = create_delayed_write()
+
+ def time_write(self):
+ self.write.compute()
+
+
+class IOWriteNetCDFDaskDistributed(object):
+ def setup(self):
+ try:
+ import distributed
+ except ImportError:
+ raise NotImplementedError
+ self.client = distributed.Client()
+ self.write = create_delayed_write()
+
+ def cleanup(self):
+ self.client.shutdown()
+
+ def time_write(self):
+ self.write.compute()
diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py
new file mode 100644
index 00000000000..54436b422e9
--- /dev/null
+++ b/asv_bench/benchmarks/unstacking.py
@@ -0,0 +1,26 @@
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+
+import xarray as xr
+
+from . import requires_dask
+
+
+class Unstacking(object):
+ def setup(self):
+ data = np.random.RandomState(0).randn(1, 1000, 500)
+ self.ds = xr.DataArray(data).stack(flat_dim=['dim_1', 'dim_2'])
+
+ def time_unstack_fast(self):
+ self.ds.unstack('flat_dim')
+
+ def time_unstack_slow(self):
+ self.ds[:, ::-1].unstack('flat_dim')
+
+
+class UnstackingDask(Unstacking):
+ def setup(self, *args, **kwargs):
+ requires_dask()
+ super(UnstackingDask, self).setup(**kwargs)
+ self.ds = self.ds.chunk({'flat_dim': 50})
diff --git a/ci/requirements-py37.yml b/ci/requirements-py37.yml
new file mode 100644
index 00000000000..5f973936f63
--- /dev/null
+++ b/ci/requirements-py37.yml
@@ -0,0 +1,13 @@
+name: test_env
+channels:
+ - defaults
+dependencies:
+ - python=3.7
+ - pip:
+ - pytest
+ - flake8
+ - mock
+ - numpy
+ - pandas
+ - coveralls
+ - pytest-cov
diff --git a/doc/_static/numfocus_logo.png b/doc/_static/numfocus_logo.png
new file mode 100644
index 00000000000..af3c84209e0
Binary files /dev/null and b/doc/_static/numfocus_logo.png differ
diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 1826cc86892..0e8143c72ea 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -151,3 +151,5 @@
plot.FacetGrid.set_titles
plot.FacetGrid.set_ticks
plot.FacetGrid.map
+
+ CFTimeIndex.shift
diff --git a/doc/api.rst b/doc/api.rst
index 927c0aa072c..662ef567710 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -150,6 +150,7 @@ Computation
Dataset.resample
Dataset.diff
Dataset.quantile
+ Dataset.differentiate
**Aggregation**:
:py:attr:`~Dataset.all`
@@ -317,6 +318,7 @@ Computation
DataArray.diff
DataArray.dot
DataArray.quantile
+ DataArray.differentiate
**Aggregation**:
:py:attr:`~DataArray.all`
@@ -555,6 +557,13 @@ Custom Indexes
CFTimeIndex
+Creating custom indexes
+-----------------------
+.. autosummary::
+ :toctree: generated/
+
+ cftime_range
+
Plotting
========
@@ -615,3 +624,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods:
backends.H5NetCDFStore
backends.PydapDataStore
backends.ScipyDataStore
+ backends.FileManager
+ backends.CachingFileManager
+ backends.DummyFileManager
diff --git a/doc/computation.rst b/doc/computation.rst
index 6793e667e06..759c87a6cc7 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -200,6 +200,31 @@ You can also use ``construct`` to compute a weighted rolling sum:
To avoid this, use ``skipna=False`` as the above example.
+Computation using Coordinates
+=============================
+
+Xarray objects have some handy methods for computation using their
+coordinates. :py:meth:`~xarray.DataArray.differentiate` computes derivatives by
+central finite differences using coordinate values,
+
+.. ipython:: python
+
+ a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]])
+ a
+ a.differentiate('x')
+
+This method can also be used for multidimensional arrays,
+
+.. ipython:: python
+
+ a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'],
+ coords={'x': [0.1, 0.11, 0.2, 0.3]})
+ a.differentiate('x')
+
+.. note::
+ This method is limited to simple Cartesian geometry. Differentiation along
+ a multidimensional coordinate is not supported.
+
.. _compute.broadcasting:
Broadcasting by dimension name
diff --git a/doc/data-structures.rst b/doc/data-structures.rst
index 10d83ca448f..618ccccff3e 100644
--- a/doc/data-structures.rst
+++ b/doc/data-structures.rst
@@ -408,13 +408,6 @@ operations keep around coordinates:
list(ds[['x']])
list(ds.drop('temperature'))
-If a dimension name is given as an argument to ``drop``, it also drops all
-variables that use that dimension:
-
-.. ipython:: python
-
- list(ds.drop('time'))
-
As an alternate to dictionary-like modifications, you can use
:py:meth:`~xarray.Dataset.assign` and :py:meth:`~xarray.Dataset.assign_coords`.
These methods return a new dataset with additional (or replaced) values:
diff --git a/doc/index.rst b/doc/index.rst
index e66c448f780..45897f4bccb 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -120,12 +120,20 @@ Get in touch
.. _mailing list: https://groups.google.com/forum/#!forum/xarray
.. _on GitHub: http://github.com/pydata/xarray
-License
--------
+NumFOCUS
+--------
-xarray is available under the open source `Apache License`__.
+.. image:: _static/numfocus_logo.png
+ :scale: 50 %
+ :target: https://numfocus.org/
+
+Xarray is a fiscally sponsored project of NumFOCUS_, a nonprofit dedicated
+to supporting the open source scientific computing community. If you like
+Xarray and want to support our mission, please consider making a donation_
+to support our efforts.
+
+.. _donation: https://www.flipcause.com/secure/cause_pdetails/NDE2NTU=
-__ http://www.apache.org/licenses/LICENSE-2.0.html
History
-------
@@ -133,6 +141,15 @@ History
xarray is an evolution of an internal tool developed at `The Climate
Corporation`__. It was originally written by Climate Corp researchers Stephan
Hoyer, Alex Kleeman and Eugene Brevdo and was released as open source in
-May 2014. The project was renamed from "xray" in January 2016.
+May 2014. The project was renamed from "xray" in January 2016. Xarray became a
+fiscally sponsored project of NumFOCUS_ in August 2018.
__ http://climate.com/
+.. _NumFOCUS: https://numfocus.org
+
+License
+-------
+
+xarray is available under the open source `Apache License`__.
+
+__ http://www.apache.org/licenses/LICENSE-2.0.html
diff --git a/doc/installing.rst b/doc/installing.rst
index 85cd5a02568..eb74eb7162b 100644
--- a/doc/installing.rst
+++ b/doc/installing.rst
@@ -6,7 +6,7 @@ Installation
Required dependencies
---------------------
-- Python 2.7 [1]_, 3.5, or 3.6
+- Python 2.7 [1]_, 3.5, 3.6, or 3.7
- `numpy `__ (1.12 or later)
- `pandas `__ (0.19.2 or later)
diff --git a/doc/interpolation.rst b/doc/interpolation.rst
index e5230e95dae..10e46331d0a 100644
--- a/doc/interpolation.rst
+++ b/doc/interpolation.rst
@@ -63,6 +63,9 @@ by specifing the time periods required.
da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3))
+Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also
+allowed. See :ref:`CFTimeIndex` for examples.
+
.. note::
Currently, our interpolation only works for regular grids.
diff --git a/doc/related-projects.rst b/doc/related-projects.rst
index 9b75d0e1b3e..524ea3b9d8d 100644
--- a/doc/related-projects.rst
+++ b/doc/related-projects.rst
@@ -35,7 +35,6 @@ Geosciences
- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids.
- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures.
- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray.
-- `xskillscore `_: Metrics for verifying forecasts.
Machine Learning
~~~~~~~~~~~~~~~~
@@ -48,10 +47,11 @@ Extend xarray capabilities
~~~~~~~~~~~~~~~~~~~~~~~~~~
- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions
- `eofs `_: EOF analysis in Python.
-- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. intergrations/interpolations).
+- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
- `xrft `_: Fourier transforms for xarray data.
- `xr-scipy `_: A lightweight scipy wrapper for xarray.
- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library.
+- `xskillscore `_: Metrics for verifying forecasts.
- `xyzpy `_: Easily generate high dimensional data, including parallelization.
Visualization
diff --git a/doc/roadmap.rst b/doc/roadmap.rst
index 2708cb7cf8f..34d203c3f48 100644
--- a/doc/roadmap.rst
+++ b/doc/roadmap.rst
@@ -1,3 +1,5 @@
+.. _roadmap:
+
Development roadmap
===================
diff --git a/doc/time-series.rst b/doc/time-series.rst
index a7ce9226d4d..c1a686b409f 100644
--- a/doc/time-series.rst
+++ b/doc/time-series.rst
@@ -70,9 +70,9 @@ You can manual decode arrays in this form by passing a dataset to
One unfortunate limitation of using ``datetime64[ns]`` is that it limits the
native representation of dates to those that fall between the years 1678 and
2262. When a netCDF file contains dates outside of these bounds, dates will be
-returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex``
-can be used for indexing. The ``CFTimeIndex`` enables only a subset of
-the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled
+returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex`
+can be used for indexing. The :py:class:`~xarray.CFTimeIndex` enables only a subset of
+the indexing functionality of a :py:class:`pandas.DatetimeIndex` and is only enabled
when using the standalone version of ``cftime`` (not the version packaged with
earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information.
@@ -219,12 +219,12 @@ Non-standard calendars and dates outside the Timestamp-valid range
------------------------------------------------------------------
Through the standalone ``cftime`` library and a custom subclass of
-``pandas.Index``, xarray supports a subset of the indexing functionality enabled
-through the standard ``pandas.DatetimeIndex`` for dates from non-standard
-calendars or dates using a standard calendar, but outside the
-`Timestamp-valid range`_ (approximately between years 1678 and 2262). This
-behavior has not yet been turned on by default; to take advantage of this
-functionality, you must have the ``enable_cftimeindex`` option set to
+:py:class:`pandas.Index`, xarray supports a subset of the indexing
+functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for
+dates from non-standard calendars or dates using a standard calendar, but
+outside the `Timestamp-valid range`_ (approximately between years 1678 and
+2262). This behavior has not yet been turned on by default; to take advantage
+of this functionality, you must have the ``enable_cftimeindex`` option set to
``True`` within your context (see :py:func:`~xarray.set_options` for more
information). It is expected that this will become the default behavior in
xarray version 0.11.
@@ -232,7 +232,7 @@ xarray version 0.11.
For instance, you can create a DataArray indexed by a time
coordinate with a no-leap calendar within a context manager setting the
``enable_cftimeindex`` option, and the time index will be cast to a
-``CFTimeIndex``:
+:py:class:`~xarray.CFTimeIndex`:
.. ipython:: python
@@ -247,19 +247,28 @@ coordinate with a no-leap calendar within a context manager setting the
.. note::
- With the ``enable_cftimeindex`` option activated, a ``CFTimeIndex``
+ With the ``enable_cftimeindex`` option activated, a :py:class:`~xarray.CFTimeIndex`
will be used for time indexing if any of the following are true:
- The dates are from a non-standard calendar
- Any dates are outside the Timestamp-valid range
- Otherwise a ``pandas.DatetimeIndex`` will be used. In addition, if any
+ Otherwise a :py:class:`pandas.DatetimeIndex` will be used. In addition, if any
variable (not just an index variable) is encoded using a non-standard
- calendar, its times will be decoded into ``cftime.datetime`` objects,
+ calendar, its times will be decoded into :py:class:`cftime.datetime` objects,
regardless of whether or not they can be represented using
``np.datetime64[ns]`` objects.
-
-For data indexed by a ``CFTimeIndex`` xarray currently supports:
+
+xarray also includes a :py:func:`~xarray.cftime_range` function, which enables
+creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For instance, we can
+create the same dates and DataArray we created above using:
+
+.. ipython:: python
+
+ dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap')
+ da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo')
+
+For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports:
- `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial
datetime strings:
@@ -285,7 +294,25 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports:
.. ipython:: python
da.groupby('time.month').sum()
-
+
+- Interpolation using :py:class:`cftime.datetime` objects:
+
+.. ipython:: python
+
+ da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)])
+
+- Interpolation using datetime strings:
+
+.. ipython:: python
+
+ da.interp(time=['0001-01-15', '0001-02-15'])
+
+- Differentiation:
+
+.. ipython:: python
+
+ da.differentiate('time')
+
- And serialization:
.. ipython:: python
@@ -296,7 +323,7 @@ For data indexed by a ``CFTimeIndex`` xarray currently supports:
.. note::
Currently resampling along the time dimension for data indexed by a
- ``CFTimeIndex`` is not supported.
+ :py:class:`~xarray.CFTimeIndex` is not supported.
.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations
.. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 718e68af04b..93083c23353 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -25,24 +25,117 @@ What's New
- `Python 3 Statement `__
- `Tips on porting to Python 3 `__
-.. _whats-new.0.10.9:
+.. _whats-new.0.11.0:
-v0.10.9 (unreleased)
+v0.11.0 (unreleased)
--------------------
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+- Xarray's storage backends now automatically open and close files when
+ necessary, rather than requiring opening a file with ``autoclose=True``. A
+ global least-recently-used cache is used to store open files; the default
+ limit of 128 open files should suffice in most cases, but can be adjusted if
+ necessary with
+ ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument
+ to ``open_dataset`` and related functions has been deprecated and is now a
+ no-op.
+
+ This change, along with an internal refactor of xarray's storage backends,
+ should significantly improve performance when reading and writing
+ netCDF files with Dask, especially when working with many files or using
+ Dask Distributed; a brief example follows below. By `Stephan Hoyer `_.
+
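+A minimal sketch of adjusting the cache limit described above (the value
+shown is purely illustrative)::
+
+    import xarray as xr
+
+    # keep up to 512 files open in the global least-recently-used cache
+    xr.set_options(file_cache_maxsize=512)
+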
Documentation
~~~~~~~~~~~~~
+- The default behavior of reductions applied to :py:meth:`DataArray.groupby`
+ and :py:meth:`DataArray.resample` without a dimension argument will change
+ in the next release; a ``FutureWarning`` is now issued to announce this
+ (see the sketch below).
+ By `Keisuke Fujii `_.
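+
+One way the explicit spelling might look, assuming the ``xr.ALL_DIMS``
+sentinel newly exported by this release is accepted as the ``dim`` argument
+(the data below is purely illustrative)::
+
+    import numpy as np
+    import xarray as xr
+
+    da = xr.DataArray(np.arange(6), dims='x',
+                      coords={'letters': ('x', list('ababab'))})
+    # reduce over all dimensions explicitly, matching the current default
+    da.groupby('letters').mean(dim=xr.ALL_DIMS)
+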
Enhancements
~~~~~~~~~~~~
-- :py:meth:`plot()` now accepts the kwargs ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits.
+- Added support for Python 3.7. (:issue:`2271`).
+ By `Joe Hamman `_.
+- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a
+ CFTimeIndex by a specified frequency (:issue:`2244`); see the sketch after
+ this list.
+ By `Spencer Clark `_.
+- Added support for using ``cftime.datetime`` coordinates with
+ :py:meth:`~xarray.DataArray.differentiate`,
+ :py:meth:`~xarray.Dataset.differentiate`,
+ :py:meth:`~xarray.DataArray.interp`, and
+ :py:meth:`~xarray.Dataset.interp`.
+ By `Spencer Clark `_.
+
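+A small sketch of the new ``shift`` method (the dates below are purely
+illustrative)::
+
+    import xarray as xr
+
+    index = xr.cftime_range('2000-01-01', periods=3, freq='D',
+                            calendar='noleap')
+    index.shift(1, 'D')  # every date moves forward by one day
+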
+Bug fixes
+~~~~~~~~~
+
+- Addition and subtraction operators used with a CFTimeIndex now preserve the
+ index's type. (:issue:`2244`).
+ By `Spencer Clark `_.
+- ``xarray.DataArray.roll`` now correctly handles multidimensional arrays.
+ (:issue:`2445`)
+ By `Keisuke Fujii `_.
+- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override
+ the norm's ``vmin`` and ``vmax``. (:issue:`2381`)
+ By `Deepak Cherian `_.
+- ``xarray.DataArray.std()`` now correctly accepts the ``ddof`` keyword argument.
+ (:issue:`2240`)
+ By `Keisuke Fujii `_.
+- Restore matplotlib's default of plotting dashed negative contours when
+ a single color is passed to ``DataArray.contour()``, e.g. ``colors='k'``.
+ By `Deepak Cherian `_.
+
+- Fixed a bug that caused some indexing operations on arrays opened with
+ ``open_rasterio`` to raise an error (:issue:`2454`).
+ By `Stephan Hoyer `_.
+
+.. _whats-new.0.10.9:
+
+v0.10.9 (21 September 2018)
+---------------------------
+
+This minor release contains a number of backwards compatible enhancements.
+
+Announcements of note:
+
+- Xarray is now a NumFOCUS fiscally sponsored project! Read
+ `the announcement `_
+ for more details.
+- We have a new :doc:`roadmap` that outlines our future development plans.
+
+Enhancements
+~~~~~~~~~~~~
+
+- :py:meth:`~xarray.DataArray.differentiate` and
+ :py:meth:`~xarray.Dataset.differentiate` are newly added.
+ (:issue:`1332`)
+ By `Keisuke Fujii `_.
+- Default colormap for sequential and divergent data can now be set via
+ :py:func:`~xarray.set_options()`
+ (:issue:`2394`)
+ By `Julius Busecke `_.
+
+- The ``min_count`` option is newly supported in :py:meth:`~xarray.DataArray.sum`,
+ :py:meth:`~xarray.DataArray.prod`, :py:meth:`~xarray.Dataset.sum`, and
+ :py:meth:`~xarray.Dataset.prod`; see the sketch after this list.
+ (:issue:`2230`)
+ By `Keisuke Fujii `_.
+
+- :py:meth:`plot()` now accepts the kwargs
+ ``xscale, yscale, xlim, ylim, xticks, yticks`` just like Pandas. Also ``xincrease=False, yincrease=False`` now use matplotlib's axis inverting methods instead of setting limits.
By `Deepak Cherian `_. (:issue:`2224`)
- DataArray coordinates and Dataset coordinates and data variables are
now displayed as `a b ... y z` rather than `a b c d ...`.
(:issue:`1186`)
By `Seth P `_.
+- A new CFTimeIndex-enabled :py:func:`cftime_range` function for use in
+ generating dates from standard or non-standard calendars. By `Spencer Clark
+ `_.
- When interpolating over a ``datetime64`` axis, you can now provide a datetime string instead of a ``datetime64`` object. E.g. ``da.interp(time='1991-02-01')``
(:issue:`2284`)
@@ -52,10 +145,30 @@ Enhancements
(:issue:`2331`)
By `Maximilian Roos `_.
+- Applying ``unstack`` to a large DataArray or Dataset is now much faster if the MultiIndex has not been modified after stacking the indices.
+ (:issue:`1560`)
+ By `Maximilian Maahn `_.
+
+- You can now control whether or not the coordinates are rolled when using
+ the ``roll`` method. The current default behavior (rolling the coordinates)
+ raises a deprecation warning unless the keyword argument is set explicitly.
+ (:issue:`1875`)
+ By `Andrew Huang `_.
+
+- You can now call ``unstack`` without arguments to unstack every MultiIndex in a DataArray or Dataset.
+ By `Julia Signell `_.
+
+- Added the ability to pass a ``data`` keyword argument to ``copy`` to create a
+ new object with the same metadata as the original object but using new values.
+ By `Julia Signell `_.
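+
+A brief sketch of the new ``min_count`` option (the array below is purely
+illustrative)::
+
+    import numpy as np
+    import xarray as xr
+
+    da = xr.DataArray([np.nan, np.nan, 1.0], dims='x')
+    da.sum(min_count=1)  # 1.0: at least one valid value is present
+    da.sum(min_count=3)  # NaN: fewer than three valid values are present
+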
Bug fixes
~~~~~~~~~
+- ``xarray.plot.imshow()`` now correctly uses the ``origin`` argument.
+ (:issue:`2379`)
+ By `Deepak Cherian `_.
+
- Fixed ``DataArray.to_iris()`` failure while creating ``DimCoord`` by
falling back to creating ``AuxCoord``. Fixed dependency on ``var_name``
attribute being set.
@@ -69,6 +182,9 @@ Bug fixes
- Tests can be run in parallel with pytest-xdist
By `Tony Tung `_.
+- Follow the renaming in dask from ``dask.ghost`` to ``dask.overlap``.
+ By `Keisuke Fujii `_.
+
- Now raises a ValueError when there is a conflict between dimension names and
level names of MultiIndex. (:issue:`2299`)
By `Keisuke Fujii `_.
@@ -77,10 +193,16 @@ Bug fixes
By `Keisuke Fujii `_.
- Now :py:func:`xr.apply_ufunc` raises a ValueError when the size of
-``input_core_dims`` is inconsistent with the number of arguments.
+ ``input_core_dims`` is inconsistent with the number of arguments.
(:issue:`2341`)
By `Keisuke Fujii `_.
+- Fixed ``Dataset.filter_by_attrs()`` behavior not matching ``netCDF4.Dataset.get_variables_by_attributes()``.
+ When more than one ``key=value`` is passed into ``Dataset.filter_by_attrs()``, it will now return a Dataset with variables that pass
+ all of the filters.
+ (:issue:`2315`)
+ By `Andrew Barna `_.
+
.. _whats-new.0.10.8:
v0.10.8 (18 July 2018)
@@ -112,7 +234,6 @@ Enhancements
:py:meth:`~xarray.DataArray.from_cdms2` (:issue:`2262`).
By `Stephane Raynaud `_.
-
Bug fixes
~~~~~~~~~
diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py
index 8d84c0f6815..13f63f259cf 100644
--- a/properties/test_encode_decode.py
+++ b/properties/test_encode_decode.py
@@ -6,14 +6,15 @@
"""
from __future__ import absolute_import, division, print_function
-from hypothesis import given, settings
-import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
+import hypothesis.strategies as st
+from hypothesis import given, settings
import xarray as xr
# Run for a while - arrays are a bigger search space than usual
-settings.deadline = None
+settings.register_profile("ci", deadline=None)
+settings.load_profile("ci")
an_array = npst.arrays(
diff --git a/setup.py b/setup.py
index 88c27c95118..a7519bac6da 100644
--- a/setup.py
+++ b/setup.py
@@ -1,10 +1,8 @@
#!/usr/bin/env python
import sys
-from setuptools import find_packages, setup
-
import versioneer
-
+from setuptools import find_packages, setup
DISTNAME = 'xarray'
LICENSE = 'Apache'
@@ -69,5 +67,6 @@
install_requires=INSTALL_REQUIRES,
tests_require=TESTS_REQUIRE,
url=URL,
+ python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
packages=find_packages(),
package_data={'xarray': ['tests/data/*']})
diff --git a/versioneer.py b/versioneer.py
index 64fea1c8927..dffd66b69a6 100644
--- a/versioneer.py
+++ b/versioneer.py
@@ -277,10 +277,7 @@
"""
from __future__ import print_function
-try:
- import configparser
-except ImportError:
- import ConfigParser as configparser
+
import errno
import json
import os
@@ -288,6 +285,11 @@
import subprocess
import sys
+try:
+ import configparser
+except ImportError:
+ import ConfigParser as configparser
+
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
diff --git a/xarray/__init__.py b/xarray/__init__.py
index 7cc7811b783..59a961c6b56 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -10,7 +10,7 @@
from .core.alignment import align, broadcast, broadcast_arrays
from .core.common import full_like, zeros_like, ones_like
from .core.combine import concat, auto_combine
-from .core.computation import apply_ufunc, where, dot
+from .core.computation import apply_ufunc, dot, where
from .core.extensions import (register_dataarray_accessor,
register_dataset_accessor)
from .core.variable import as_variable, Variable, IndexVariable, Coordinate
@@ -26,6 +26,7 @@
from .conventions import decode_cf, SerializationWarning
+from .coding.cftime_offsets import cftime_range
from .coding.cftimeindex import CFTimeIndex
from .util.print_versions import show_versions
@@ -33,3 +34,5 @@
from . import tutorial
from . import ufuncs
from . import testing
+
+from .core.common import ALL_DIMS
diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py
index 47a2011a3af..a2f0d79a6d1 100644
--- a/xarray/backends/__init__.py
+++ b/xarray/backends/__init__.py
@@ -4,6 +4,7 @@
formats. They should not be used directly, but rather through Dataset objects.
"""
from .common import AbstractDataStore
+from .file_manager import FileManager, CachingFileManager, DummyFileManager
from .memory import InMemoryDataStore
from .netCDF4_ import NetCDF4DataStore
from .pydap_ import PydapDataStore
@@ -15,6 +16,9 @@
__all__ = [
'AbstractDataStore',
+ 'FileManager',
+ 'CachingFileManager',
+ 'DummyFileManager',
'InMemoryDataStore',
'NetCDF4DataStore',
'PydapDataStore',
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index b2c0df7b01b..65112527045 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -4,6 +4,7 @@
from glob import glob
from io import BytesIO
from numbers import Number
+import warnings
import numpy as np
@@ -12,8 +13,9 @@
from ..core.combine import auto_combine
from ..core.pycompat import basestring, path_type
from ..core.utils import close_on_error, is_remote_uri
-from .common import (
- HDF5_LOCK, ArrayWriter, CombinedLock, _get_scheduler, _get_scheduler_lock)
+from .common import ArrayWriter
+from .locks import _get_scheduler
+
DATAARRAY_NAME = '__xarray_dataarray_name__'
DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'
@@ -52,27 +54,6 @@ def _normalize_path(path):
return os.path.abspath(os.path.expanduser(path))
-def _default_lock(filename, engine):
- if filename.endswith('.gz'):
- lock = False
- else:
- if engine is None:
- engine = _get_default_engine(filename, allow_remote=True)
-
- if engine == 'netcdf4':
- if is_remote_uri(filename):
- lock = False
- else:
- # TODO: identify netcdf3 files and don't use the global lock
- # for them
- lock = HDF5_LOCK
- elif engine in {'h5netcdf', 'pynio'}:
- lock = HDF5_LOCK
- else:
- lock = False
- return lock
-
-
def _validate_dataset_names(dataset):
"""DataArray.name and Dataset keys must be a string or None"""
def check_name(name):
@@ -130,29 +111,14 @@ def _protect_dataset_variables_inplace(dataset, cache):
variable.data = data
-def _get_lock(engine, scheduler, format, path_or_file):
- """ Get the lock(s) that apply to a particular scheduler/engine/format"""
-
- locks = []
- if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']:
- locks.append(HDF5_LOCK)
- locks.append(_get_scheduler_lock(scheduler, path_or_file))
-
- # When we have more than one lock, use the CombinedLock wrapper class
- lock = CombinedLock(locks) if len(locks) > 1 else locks[0]
-
- return lock
-
-
def _finalize_store(write, store):
""" Finalize this store by explicitly syncing and closing"""
del write # ensure writing is done first
- store.sync()
store.close()
def open_dataset(filename_or_obj, group=None, decode_cf=True,
- mask_and_scale=None, decode_times=True, autoclose=False,
+ mask_and_scale=None, decode_times=True, autoclose=None,
concat_characters=True, decode_coords=True, engine=None,
chunks=None, lock=None, cache=None, drop_variables=None,
backend_kwargs=None):
@@ -204,12 +170,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True,
If chunks is provided, it used to load the new dataset into dask
arrays. ``chunks={}`` loads the dataset with dask using a single
chunk for all arrays.
- lock : False, True or threading.Lock, optional
- If chunks is provided, this argument is passed on to
- :py:func:`dask.array.from_array`. By default, a global lock is
- used when reading data from netCDF files with the netcdf4 and h5netcdf
- engines to avoid issues with concurrent access when using dask's
- multithreaded backend.
+ lock : False or a duck-typed ``threading.Lock``, optional
+ Resource lock to use when reading data from disk. Only relevant when
+ using dask or another form of parallelism. By default, appropriate
+ locks are chosen to safely read and write files with the currently
+ active dask scheduler.
cache : bool, optional
If True, cache data loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
@@ -235,6 +200,14 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True,
--------
open_mfdataset
"""
+ if autoclose is not None:
+ warnings.warn(
+ 'The autoclose argument is no longer used by '
+ 'xarray.open_dataset() and is now ignored; it will be removed in '
+ 'xarray v0.12. If necessary, you can control the maximum number '
+ 'of simultaneous open files with '
+ 'xarray.set_options(file_cache_maxsize=...).',
+ FutureWarning, stacklevel=2)
if mask_and_scale is None:
mask_and_scale = not engine == 'pseudonetcdf'
@@ -272,18 +245,11 @@ def maybe_decode_store(store, lock=False):
mask_and_scale, decode_times, concat_characters,
decode_coords, engine, chunks, drop_variables)
name_prefix = 'open_dataset-%s' % token
- ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token,
- lock=lock)
+ ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token)
ds2._file_obj = ds._file_obj
else:
ds2 = ds
- # protect so that dataset store isn't necessarily closed, e.g.,
- # streams like BytesIO can't be reopened
- # datastore backend is responsible for determining this capability
- if store._autoclose:
- store.close()
-
return ds2
if isinstance(filename_or_obj, path_type):
@@ -314,36 +280,28 @@ def maybe_decode_store(store, lock=False):
engine = _get_default_engine(filename_or_obj,
allow_remote=True)
if engine == 'netcdf4':
- store = backends.NetCDF4DataStore.open(filename_or_obj,
- group=group,
- autoclose=autoclose,
- **backend_kwargs)
+ store = backends.NetCDF4DataStore.open(
+ filename_or_obj, group=group, lock=lock, **backend_kwargs)
elif engine == 'scipy':
- store = backends.ScipyDataStore(filename_or_obj,
- autoclose=autoclose,
- **backend_kwargs)
+ store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs)
elif engine == 'pydap':
- store = backends.PydapDataStore.open(filename_or_obj,
- **backend_kwargs)
+ store = backends.PydapDataStore.open(
+ filename_or_obj, **backend_kwargs)
elif engine == 'h5netcdf':
- store = backends.H5NetCDFStore(filename_or_obj, group=group,
- autoclose=autoclose,
- **backend_kwargs)
+ store = backends.H5NetCDFStore(
+ filename_or_obj, group=group, lock=lock, **backend_kwargs)
elif engine == 'pynio':
- store = backends.NioDataStore(filename_or_obj,
- autoclose=autoclose,
- **backend_kwargs)
+ store = backends.NioDataStore(
+ filename_or_obj, lock=lock, **backend_kwargs)
elif engine == 'pseudonetcdf':
store = backends.PseudoNetCDFDataStore.open(
- filename_or_obj, autoclose=autoclose, **backend_kwargs)
+ filename_or_obj, lock=lock, **backend_kwargs)
else:
raise ValueError('unrecognized engine for open_dataset: %r'
% engine)
- if lock is None:
- lock = _default_lock(filename_or_obj, engine)
with close_on_error(store):
- return maybe_decode_store(store, lock)
+ return maybe_decode_store(store)
else:
if engine is not None and engine != 'scipy':
raise ValueError('can only read file-like objects with '
@@ -355,7 +313,7 @@ def maybe_decode_store(store, lock=False):
def open_dataarray(filename_or_obj, group=None, decode_cf=True,
- mask_and_scale=None, decode_times=True, autoclose=False,
+ mask_and_scale=None, decode_times=True, autoclose=None,
concat_characters=True, decode_coords=True, engine=None,
chunks=None, lock=None, cache=None, drop_variables=None,
backend_kwargs=None):
@@ -390,10 +348,6 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True,
decode_times : bool, optional
If True, decode times encoded in the standard NetCDF datetime format
into datetime objects. Otherwise, leave them encoded as numbers.
- autoclose : bool, optional
- If True, automatically close files to avoid OS Error of too many files
- being open. However, this option doesn't work with streams, e.g.,
- BytesIO.
concat_characters : bool, optional
If True, concatenate along the last dimension of character arrays to
form string arrays. Dimensions will only be concatenated over (and
@@ -409,12 +363,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True,
chunks : int or dict, optional
If chunks is provided, it used to load the new dataset into dask
arrays.
- lock : False, True or threading.Lock, optional
- If chunks is provided, this argument is passed on to
- :py:func:`dask.array.from_array`. By default, a global lock is
- used when reading data from netCDF files with the netcdf4 and h5netcdf
- engines to avoid issues with concurrent access when using dask's
- multithreaded backend.
+ lock : False or a duck-typed ``threading.Lock``, optional
+ Resource lock to use when reading data from disk. Only relevant when
+ using dask or another form of parallelism. By default, appropriate
+ locks are chosen to safely read and write files with the currently
+ active dask scheduler.
cache : bool, optional
If True, cache data loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
@@ -490,7 +443,7 @@ def close(self):
def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
compat='no_conflicts', preprocess=None, engine=None,
lock=None, data_vars='all', coords='different',
- autoclose=False, parallel=False, **kwargs):
+ autoclose=None, parallel=False, **kwargs):
"""Open multiple files as a single dataset.
Requires dask to be installed. See documentation for details on dask [1].
@@ -537,15 +490,11 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
'netcdf4'.
- autoclose : bool, optional
- If True, automatically close files to avoid OS Error of too many files
- being open. However, this option doesn't work with streams, e.g.,
- BytesIO.
- lock : False, True or threading.Lock, optional
- This argument is passed on to :py:func:`dask.array.from_array`. By
- default, a per-variable lock is used when reading data from netCDF
- files with the netcdf4 and h5netcdf engines to avoid issues with
- concurrent access when using dask's multithreaded backend.
+ lock : False or a duck-typed ``threading.Lock``, optional
+ Resource lock to use when reading data from disk. Only relevant when
+ using dask or another form of parallelism. By default, appropriate
+ locks are chosen to safely read and write files with the currently
+ active dask scheduler.
data_vars : {'minimal', 'different', 'all' or list of str}, optional
These data variables will be concatenated together:
* 'minimal': Only data variables in which the dimension already
@@ -604,9 +553,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
if not paths:
raise IOError('no files to open')
- if lock is None:
- lock = _default_lock(paths[0], engine)
-
open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock,
autoclose=autoclose, **kwargs)
@@ -656,19 +602,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT,
def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
- engine=None, writer=None, encoding=None, unlimited_dims=None,
- compute=True):
+ engine=None, encoding=None, unlimited_dims=None, compute=True,
+ multifile=False):
"""This function creates an appropriate datastore for writing a dataset to
disk as a netCDF file
See `Dataset.to_netcdf` for full API docs.
- The ``writer`` argument is only for the private use of save_mfdataset.
+ The ``multifile`` argument is only for the private use of save_mfdataset.
"""
if isinstance(path_or_file, path_type):
path_or_file = str(path_or_file)
+
if encoding is None:
encoding = {}
+
if path_or_file is None:
if engine is None:
engine = 'scipy'
@@ -676,6 +624,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
raise ValueError('invalid engine for creating bytes with '
'to_netcdf: %r. Only the default engine '
"or engine='scipy' is supported" % engine)
+ if not compute:
+ raise NotImplementedError(
+ 'to_netcdf() with compute=False is not yet implemented when '
+ 'returning bytes')
elif isinstance(path_or_file, basestring):
if engine is None:
engine = _get_default_engine(path_or_file)
@@ -695,45 +647,78 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
if format is not None:
format = format.upper()
- # if a writer is provided, store asynchronously
- sync = writer is None
-
# handle scheduler specific logic
scheduler = _get_scheduler()
have_chunks = any(v.chunks for v in dataset.variables.values())
- if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and
- engine != 'netcdf4'):
+
+ autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing']
+ if autoclose and engine == 'scipy':
raise NotImplementedError("Writing netCDF files with the %s backend "
"is not currently supported with dask's %s "
"scheduler" % (engine, scheduler))
- lock = _get_lock(engine, scheduler, format, path_or_file)
- autoclose = (have_chunks and
- scheduler in ['distributed', 'multiprocessing'])
target = path_or_file if path_or_file is not None else BytesIO()
- store = store_open(target, mode, format, group, writer,
- autoclose=autoclose, lock=lock)
+ kwargs = dict(autoclose=True) if autoclose else {}
+ store = store_open(target, mode, format, group, **kwargs)
if unlimited_dims is None:
unlimited_dims = dataset.encoding.get('unlimited_dims', None)
if isinstance(unlimited_dims, basestring):
unlimited_dims = [unlimited_dims]
+ writer = ArrayWriter()
+
+ # TODO: figure out how to refactor this logic (here and in save_mfdataset)
+ # to avoid this mess of conditionals
try:
- dataset.dump_to_store(store, sync=sync, encoding=encoding,
- unlimited_dims=unlimited_dims, compute=compute)
+ # TODO: allow this work (setting up the file for writing array data)
+ # to be parallelized with dask
+ dump_to_store(dataset, store, writer, encoding=encoding,
+ unlimited_dims=unlimited_dims)
+ if autoclose:
+ store.close()
+
+ if multifile:
+ return writer, store
+
+ writes = writer.sync(compute=compute)
+
if path_or_file is None:
+ store.sync()
return target.getvalue()
finally:
- if sync and isinstance(path_or_file, basestring):
+ if not multifile and compute:
store.close()
if not compute:
import dask
- return dask.delayed(_finalize_store)(store.delayed_store, store)
+ return dask.delayed(_finalize_store)(writes, store)
+
+
+def dump_to_store(dataset, store, writer=None, encoder=None,
+ encoding=None, unlimited_dims=None):
+ """Store dataset contents to a backends.*DataStore object."""
+ if writer is None:
+ writer = ArrayWriter()
+
+ if encoding is None:
+ encoding = {}
+
+ variables, attrs = conventions.encode_dataset_coordinates(dataset)
+
+ check_encoding = set()
+ for k, enc in encoding.items():
+ # no need to shallow copy the variable again; that already happened
+ # in encode_dataset_coordinates
+ variables[k].encoding = enc
+ check_encoding.add(k)
+
+ if encoder:
+ variables, attrs = encoder(variables, attrs)
+
+ store.store(variables, attrs, check_encoding, writer,
+ unlimited_dims=unlimited_dims)
- if not sync:
- return store
def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
engine=None, compute=True):
@@ -806,7 +791,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
for obj in datasets:
if not isinstance(obj, Dataset):
raise TypeError('save_mfdataset only supports writing Dataset '
- 'objects, recieved type %s' % type(obj))
+ 'objects, received type %s' % type(obj))
if groups is None:
groups = [None] * len(datasets)
@@ -816,22 +801,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
'datasets, paths and groups arguments to '
'save_mfdataset')
- writer = ArrayWriter() if compute else None
- stores = [to_netcdf(ds, path, mode, format, group, engine, writer,
- compute=compute)
- for ds, path, group in zip(datasets, paths, groups)]
-
- if not compute:
- import dask
- return dask.delayed(stores)
+ writers, stores = zip(*[
+ to_netcdf(ds, path, mode, format, group, engine, compute=compute,
+ multifile=True)
+ for ds, path, group in zip(datasets, paths, groups)])
try:
- delayed = writer.sync(compute=compute)
- for store in stores:
- store.sync()
+ writes = [w.sync(compute=compute) for w in writers]
finally:
- for store in stores:
- store.close()
+ if compute:
+ for store in stores:
+ store.close()
+
+ if not compute:
+ import dask
+ return dask.delayed([dask.delayed(_finalize_store)(w, s)
+ for w, s in zip(writes, stores)])
def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None,
@@ -852,13 +837,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None,
store = backends.ZarrStore.open_group(store=store, mode=mode,
synchronizer=synchronizer,
- group=group, writer=None)
+ group=group)
- # I think zarr stores should always be sync'd immediately
+ writer = ArrayWriter()
# TODO: figure out how to properly handle unlimited_dims
- dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute)
+ dump_to_store(dataset, store, writer, encoding=encoding)
+ writes = writer.sync(compute=compute)
if not compute:
import dask
- return dask.delayed(_finalize_store)(store.delayed_store, store)
+ return dask.delayed(_finalize_store)(writes, store)
return store
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 99f7698ee92..405d989f4af 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -1,14 +1,10 @@
from __future__ import absolute_import, division, print_function
-import contextlib
import logging
-import multiprocessing
-import threading
import time
import traceback
import warnings
from collections import Mapping, OrderedDict
-from functools import partial
import numpy as np
@@ -17,13 +13,6 @@
from ..core.pycompat import dask_array_type, iteritems
from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin
-# Import default lock
-try:
- from dask.utils import SerializableLock
- HDF5_LOCK = SerializableLock()
-except ImportError:
- HDF5_LOCK = threading.Lock()
-
# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)
@@ -31,62 +20,6 @@
NONE_VAR_NAME = '__values__'
-def _get_scheduler(get=None, collection=None):
- """ Determine the dask scheduler that is being used.
-
- None is returned if not dask scheduler is active.
-
- See also
- --------
- dask.base.get_scheduler
- """
- try:
- # dask 0.18.1 and later
- from dask.base import get_scheduler
- actual_get = get_scheduler(get, collection)
- except ImportError:
- try:
- from dask.utils import effective_get
- actual_get = effective_get(get, collection)
- except ImportError:
- return None
-
- try:
- from dask.distributed import Client
- if isinstance(actual_get.__self__, Client):
- return 'distributed'
- except (ImportError, AttributeError):
- try:
- import dask.multiprocessing
- if actual_get == dask.multiprocessing.get:
- return 'multiprocessing'
- else:
- return 'threaded'
- except ImportError:
- return 'threaded'
-
-
-def _get_scheduler_lock(scheduler, path_or_file=None):
- """ Get the appropriate lock for a certain situation based onthe dask
- scheduler used.
-
- See Also
- --------
- dask.utils.get_scheduler_lock
- """
-
- if scheduler == 'distributed':
- from dask.distributed import Lock
- return Lock(path_or_file)
- elif scheduler == 'multiprocessing':
- return multiprocessing.Lock()
- elif scheduler == 'threaded':
- from dask.utils import SerializableLock
- return SerializableLock()
- else:
- return threading.Lock()
-
-
def _encode_variable_name(name):
if name is None:
name = NONE_VAR_NAME
@@ -133,39 +66,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
time.sleep(1e-3 * next_delay)
-class CombinedLock(object):
- """A combination of multiple locks.
-
- Like a locked door, a CombinedLock is locked if any of its constituent
- locks are locked.
- """
-
- def __init__(self, locks):
- self.locks = tuple(set(locks)) # remove duplicates
-
- def acquire(self, *args):
- return all(lock.acquire(*args) for lock in self.locks)
-
- def release(self, *args):
- for lock in self.locks:
- lock.release(*args)
-
- def __enter__(self):
- for lock in self.locks:
- lock.__enter__()
-
- def __exit__(self, *args):
- for lock in self.locks:
- lock.__exit__(*args)
-
- @property
- def locked(self):
- return any(lock.locked for lock in self.locks)
-
- def __repr__(self):
- return "CombinedLock(%r)" % list(self.locks)
-
-
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
def __array__(self, dtype=None):
@@ -174,9 +74,6 @@ def __array__(self, dtype=None):
class AbstractDataStore(Mapping):
- _autoclose = None
- _ds = None
- _isopen = False
def __iter__(self):
return iter(self.variables)
@@ -259,7 +156,7 @@ def __exit__(self, exception_type, exception_value, traceback):
class ArrayWriter(object):
- def __init__(self, lock=HDF5_LOCK):
+ def __init__(self, lock=None):
self.sources = []
self.targets = []
self.lock = lock
@@ -274,6 +171,9 @@ def add(self, source, target):
def sync(self, compute=True):
if self.sources:
import dask.array as da
+ # TODO: consider wrapping targets with dask.delayed, if this makes
+ # for any discernable difference in perforance, e.g.,
+ # targets = [dask.delayed(t) for t in self.targets]
delayed_store = da.store(self.sources, self.targets,
lock=self.lock, compute=compute,
flush=True)
@@ -283,11 +183,6 @@ def sync(self, compute=True):
class AbstractWritableDataStore(AbstractDataStore):
- def __init__(self, writer=None, lock=HDF5_LOCK):
- if writer is None:
- writer = ArrayWriter(lock=lock)
- self.writer = writer
- self.delayed_store = None
def encode(self, variables, attributes):
"""
@@ -329,12 +224,6 @@ def set_attribute(self, k, v): # pragma: no cover
def set_variable(self, k, v): # pragma: no cover
raise NotImplementedError
- def sync(self, compute=True):
- if self._isopen and self._autoclose:
- # datastore will be reopened during write
- self.close()
- self.delayed_store = self.writer.sync(compute=compute)
-
def store_dataset(self, dataset):
"""
in stores, variables are all variables AND coordinates
@@ -345,7 +234,7 @@ def store_dataset(self, dataset):
self.store(dataset, dataset.attrs)
def store(self, variables, attributes, check_encoding_set=frozenset(),
- unlimited_dims=None):
+ writer=None, unlimited_dims=None):
"""
Top level method for putting data on this store, this method:
- encodes variables/attributes
@@ -361,16 +250,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(),
check_encoding_set : list-like
List of variables that should be checked for invalid encoding
values
+ writer : ArrayWriter
unlimited_dims : list-like
List of dimension names that should be treated as unlimited
dimensions.
"""
+ if writer is None:
+ writer = ArrayWriter()
variables, attributes = self.encode(variables, attributes)
self.set_attributes(attributes)
self.set_dimensions(variables, unlimited_dims=unlimited_dims)
- self.set_variables(variables, check_encoding_set,
+ self.set_variables(variables, check_encoding_set, writer,
unlimited_dims=unlimited_dims)
def set_attributes(self, attributes):
@@ -386,7 +278,7 @@ def set_attributes(self, attributes):
for k, v in iteritems(attributes):
self.set_attribute(k, v)
- def set_variables(self, variables, check_encoding_set,
+ def set_variables(self, variables, check_encoding_set, writer,
unlimited_dims=None):
"""
This provides a centralized method to set the variables on the data
@@ -399,6 +291,7 @@ def set_variables(self, variables, check_encoding_set,
check_encoding_set : list-like
List of variables that should be checked for invalid encoding
values
+ writer : ArrayWriter
unlimited_dims : list-like
List of dimension names that should be treated as unlimited
dimensions.
@@ -410,7 +303,7 @@ def set_variables(self, variables, check_encoding_set,
target, source = self.prepare_variable(
name, v, check, unlimited_dims=unlimited_dims)
- self.writer.add(source, target)
+ writer.add(source, target)
def set_dimensions(self, variables, unlimited_dims=None):
"""
@@ -457,87 +350,3 @@ def encode(self, variables, attributes):
attributes = OrderedDict([(k, self.encode_attribute(v))
for k, v in attributes.items()])
return variables, attributes
-
-
-class DataStorePickleMixin(object):
- """Subclasses must define `ds`, `_opener` and `_mode` attributes.
-
- Do not subclass this class: it is not part of xarray's external API.
- """
-
- def __getstate__(self):
- state = self.__dict__.copy()
- del state['_ds']
- del state['_isopen']
- if self._mode == 'w':
- # file has already been created, don't override when restoring
- state['_mode'] = 'a'
- return state
-
- def __setstate__(self, state):
- self.__dict__.update(state)
- self._ds = None
- self._isopen = False
-
- @property
- def ds(self):
- if self._ds is not None and self._isopen:
- return self._ds
- ds = self._opener(mode=self._mode)
- self._isopen = True
- return ds
-
- @contextlib.contextmanager
- def ensure_open(self, autoclose=None):
- """
- Helper function to make sure datasets are closed and opened
- at appropriate times to avoid too many open file errors.
-
- Use requires `autoclose=True` argument to `open_mfdataset`.
- """
-
- if autoclose is None:
- autoclose = self._autoclose
-
- if not self._isopen:
- try:
- self._ds = self._opener()
- self._isopen = True
- yield
- finally:
- if autoclose:
- self.close()
- else:
- yield
-
- def assert_open(self):
- if not self._isopen:
- raise AssertionError('internal failure: file must be open '
- 'if `autoclose=True` is used.')
-
-
-class PickleByReconstructionWrapper(object):
-
- def __init__(self, opener, file, mode='r', **kwargs):
- self.opener = partial(opener, file, mode=mode, **kwargs)
- self.mode = mode
- self._ds = None
-
- @property
- def value(self):
- self._ds = self.opener()
- return self._ds
-
- def __getstate__(self):
- state = self.__dict__.copy()
- del state['_ds']
- if self.mode == 'w':
- # file has already been created, don't override when restoring
- state['mode'] = 'a'
- return state
-
- def __setstate__(self, state):
- self.__dict__.update(state)
-
- def close(self):
- self._ds.close()
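
[Editor's sketch] The hunks above change ``AbstractWritableDataStore.store`` so that an ``ArrayWriter`` is passed in explicitly (or created on the fly) rather than being held on the store. A minimal sketch of the resulting call pattern, assuming an in-memory dataset and ``InMemoryDataStore`` used purely for illustration:

    import xarray as xr
    from xarray.backends.common import ArrayWriter
    from xarray.backends.memory import InMemoryDataStore

    ds = xr.Dataset({'x': ('t', [1, 2, 3])})   # illustrative dataset
    store = InMemoryDataStore()
    writer = ArrayWriter()                     # collects (source, target) pairs
    store.store(ds.variables, ds.attrs, writer=writer)
    writer.sync()                              # flushes dask-backed writes (no-op for numpy data)
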
diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py
new file mode 100644
index 00000000000..a93285370b2
--- /dev/null
+++ b/xarray/backends/file_manager.py
@@ -0,0 +1,206 @@
+import threading
+
+from ..core import utils
+from ..core.options import OPTIONS
+from .lru_cache import LRUCache
+
+
+# Global cache for storing open files.
+FILE_CACHE = LRUCache(
+ OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close())
+assert FILE_CACHE.maxsize, 'file cache must be at least size one'
+
+
+_DEFAULT_MODE = utils.ReprObject('')
+
+
+class FileManager(object):
+ """Manager for acquiring and closing a file object.
+
+ Use FileManager subclasses (CachingFileManager in particular) on backend
+ storage classes to automatically handle issues related to keeping track of
+ many open files and transferring them between multiple processes.
+ """
+
+ def acquire(self):
+ """Acquire the file object from this manager."""
+ raise NotImplementedError
+
+ def close(self, needs_lock=True):
+ """Close the file object associated with this manager, if needed."""
+ raise NotImplementedError
+
+
+class CachingFileManager(FileManager):
+ """Wrapper for automatically opening and closing file objects.
+
+ Unlike files, CachingFileManager objects can be safely pickled and passed
+ between processes. They should be explicitly closed to release resources,
+ but a per-process least-recently-used cache for open files ensures that you
+ can safely create arbitrarily large numbers of FileManager objects.
+
+ Don't directly close files acquired from a FileManager. Instead, call
+ FileManager.close(), which ensures that closed files are removed from the
+ cache as well.
+
+ Example usage:
+
+ manager = CachingFileManager(open, 'example.txt', mode='w')
+ f = manager.acquire()
+ f.write(...)
+ manager.close() # ensures file is closed
+
+ Note that as long as previous files are still cached, acquiring a file
+ multiple times from the same FileManager is essentially free:
+
+ f1 = manager.acquire()
+ f2 = manager.acquire()
+ assert f1 is f2
+
+ """
+
+ def __init__(self, opener, *args, **keywords):
+ """Initialize a FileManager.
+
+ Parameters
+ ----------
+ opener : callable
+ Function that when called like ``opener(*args, **kwargs)`` returns
+ an open file object. The file object must implement a ``close()``
+ method.
+ *args
+ Positional arguments for opener. A ``mode`` argument should be
+ provided as a keyword argument (see below). All arguments must be
+ hashable.
+ mode : optional
+ If provided, passed as a keyword argument to ``opener`` along with
+ ``**kwargs``. ``mode='w'`` has special treatment: after the first
+ call it is replaced by ``mode='a'`` in all subsequent calls, to
+ avoid overwriting the newly created file.
+ kwargs : dict, optional
+ Keyword arguments for opener, excluding ``mode``. All values must
+ be hashable.
+ lock : duck-compatible threading.Lock, optional
+ Lock to use when modifying the cache inside acquire() and close().
+ By default, uses a new threading.Lock() object. If set, this object
+ should be pickleable.
+ cache : MutableMapping, optional
+ Mapping to use as a cache for open files. By default, uses xarray's
+ global LRU file cache. Because ``cache`` typically points to a
+ global variable and contains non-picklable file objects, an
+ unpickled FileManager object will be restored with the default
+ cache.
+ """
+ # TODO: replace with real keyword arguments when we drop Python 2
+ # support
+ mode = keywords.pop('mode', _DEFAULT_MODE)
+ kwargs = keywords.pop('kwargs', None)
+ lock = keywords.pop('lock', None)
+ cache = keywords.pop('cache', FILE_CACHE)
+ if keywords:
+ raise TypeError('FileManager() got unexpected keyword arguments: '
+ '%s' % list(keywords))
+
+ self._opener = opener
+ self._args = args
+ self._mode = mode
+ self._kwargs = {} if kwargs is None else dict(kwargs)
+ self._default_lock = lock is None or lock is False
+ self._lock = threading.Lock() if self._default_lock else lock
+ self._cache = cache
+ self._key = self._make_key()
+
+ def _make_key(self):
+ """Make a key for caching files in the LRU cache."""
+ value = (self._opener,
+ self._args,
+ self._mode,
+ tuple(sorted(self._kwargs.items())))
+ return _HashedSequence(value)
+
+ def acquire(self):
+ """Acquire a file object from the manager.
+
+ A new file is only opened if it has expired from the
+ least-recently-used cache.
+
+ This method uses a reentrant lock, which ensures that it is
+ thread-safe. You can safely acquire a file in multiple threads at the
+ same time, as long as the underlying file object is thread-safe.
+
+ Returns
+ -------
+ An open file object, as returned by ``opener(*args, **kwargs)``.
+ """
+ with self._lock:
+ try:
+ file = self._cache[self._key]
+ except KeyError:
+ kwargs = self._kwargs
+ if self._mode is not _DEFAULT_MODE:
+ kwargs = kwargs.copy()
+ kwargs['mode'] = self._mode
+ file = self._opener(*self._args, **kwargs)
+ if self._mode == 'w':
+ # ensure file doesn't get overwritten when opened again
+ self._mode = 'a'
+ self._key = self._make_key()
+ self._cache[self._key] = file
+ return file
+
+ def _close(self):
+ default = None
+ file = self._cache.pop(self._key, default)
+ if file is not None:
+ file.close()
+
+ def close(self, needs_lock=True):
+ """Explicitly close any associated file object (if necessary)."""
+ # TODO: remove needs_lock if/when we have a reentrant lock in
+ # dask.distributed: https://github.com/dask/dask/issues/3832
+ if needs_lock:
+ with self._lock:
+ self._close()
+ else:
+ self._close()
+
+ def __getstate__(self):
+ """State for pickling."""
+ lock = None if self._default_lock else self._lock
+ return (self._opener, self._args, self._mode, self._kwargs, lock)
+
+ def __setstate__(self, state):
+ """Restore from a pickle."""
+ opener, args, mode, kwargs, lock = state
+ self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock)
+
+
+class _HashedSequence(list):
+ """Speed up repeated look-ups by caching hash values.
+
+ Based on what Python uses internally in functools.lru_cache.
+
+ Python doesn't perform this optimization automatically:
+ https://bugs.python.org/issue1462796
+ """
+
+ def __init__(self, tuple_value):
+ self[:] = tuple_value
+ self.hashvalue = hash(tuple_value)
+
+ def __hash__(self):
+ return self.hashvalue
+
+
+class DummyFileManager(FileManager):
+ """FileManager that simply wraps an open file in the FileManager interface.
+ """
+ def __init__(self, value):
+ self._value = value
+
+ def acquire(self):
+ return self._value
+
+ def close(self, needs_lock=True):
+ del needs_lock # ignored
+ self._value.close()
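
[Editor's sketch] Because ``CachingFileManager.__getstate__`` drops the cached file object and ``__setstate__`` simply re-runs ``__init__``, a manager can be round-tripped through pickle and will lazily reopen its file on first use. A minimal sketch, assuming a writable path ``example.txt`` that exists only for illustration:

    import pickle
    from xarray.backends.file_manager import CachingFileManager

    manager = CachingFileManager(open, 'example.txt', mode='w')
    manager.acquire().write('hello\n')   # first acquire() opens the file
    manager.close()                      # closes it and evicts it from FILE_CACHE

    clone = pickle.loads(pickle.dumps(manager))
    clone.acquire().write('world\n')     # reopened on demand, in mode 'a'
    clone.close()
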
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index 959cd221734..59cd4e84793 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -8,11 +8,12 @@
from ..core import indexing
from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type
from ..core.utils import FrozenOrderedDict, close_on_error
-from .common import (
- HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root)
+from .common import WritableCFDataStore
+from .file_manager import CachingFileManager
+from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
from .netCDF4_ import (
- BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding,
- _get_datatype, _nc4_require_group)
+ BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable,
+ _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group)
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
@@ -25,8 +26,9 @@ def _getitem(self, key):
# h5py requires using lists for fancy indexing:
# https://github.com/h5py/h5py/issues/992
key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key)
- with self.datastore.ensure_open(autoclose=True):
- return self.get_array()[key]
+ array = self.get_array()
+ with self.datastore.lock:
+ return array[key]
def maybe_decode_bytes(txt):
@@ -61,104 +63,102 @@ def _open_h5netcdf_group(filename, mode, group):
import h5netcdf
ds = h5netcdf.File(filename, mode=mode)
with close_on_error(ds):
- return _nc4_require_group(
+ ds = _nc4_require_group(
ds, group, mode, create_group=_h5netcdf_create_group)
+ return GroupWrapper(ds)
-class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
+class H5NetCDFStore(WritableCFDataStore):
"""Store for reading and writing data via h5netcdf
"""
def __init__(self, filename, mode='r', format=None, group=None,
- writer=None, autoclose=False, lock=HDF5_LOCK):
+ lock=None, autoclose=False):
if format not in [None, 'NETCDF4']:
raise ValueError('invalid format for h5netcdf backend')
- opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
- group=group)
- self._ds = opener()
- if autoclose:
- raise NotImplementedError('autoclose=True is not implemented '
- 'for the h5netcdf backend pending '
- 'further exploration, e.g., bug fixes '
- '(in h5netcdf?)')
- self._autoclose = False
- self._isopen = True
+ self._manager = CachingFileManager(
+ _open_h5netcdf_group, filename, mode=mode,
+ kwargs=dict(group=group))
+
+ if lock is None:
+ if mode == 'r':
+ lock = HDF5_LOCK
+ else:
+ lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
+
self.format = format
- self._opener = opener
self._filename = filename
self._mode = mode
- super(H5NetCDFStore, self).__init__(writer, lock=lock)
+ self.lock = ensure_lock(lock)
+ self.autoclose = autoclose
+
+ @property
+ def ds(self):
+ return self._manager.acquire().value
def open_store_variable(self, name, var):
import h5py
- with self.ensure_open(autoclose=False):
- dimensions = var.dimensions
- data = indexing.LazilyOuterIndexedArray(
- H5NetCDFArrayWrapper(name, self))
- attrs = _read_attributes(var)
-
- # netCDF4 specific encoding
- encoding = {
- 'chunksizes': var.chunks,
- 'fletcher32': var.fletcher32,
- 'shuffle': var.shuffle,
- }
- # Convert h5py-style compression options to NetCDF4-Python
- # style, if possible
- if var.compression == 'gzip':
- encoding['zlib'] = True
- encoding['complevel'] = var.compression_opts
- elif var.compression is not None:
- encoding['compression'] = var.compression
- encoding['compression_opts'] = var.compression_opts
-
- # save source so __repr__ can detect if it's local or not
- encoding['source'] = self._filename
- encoding['original_shape'] = var.shape
-
- vlen_dtype = h5py.check_dtype(vlen=var.dtype)
- if vlen_dtype is unicode_type:
- encoding['dtype'] = str
- elif vlen_dtype is not None: # pragma: no cover
- # xarray doesn't support writing arbitrary vlen dtypes yet.
- pass
- else:
- encoding['dtype'] = var.dtype
+ dimensions = var.dimensions
+ data = indexing.LazilyOuterIndexedArray(
+ H5NetCDFArrayWrapper(name, self))
+ attrs = _read_attributes(var)
+
+ # netCDF4 specific encoding
+ encoding = {
+ 'chunksizes': var.chunks,
+ 'fletcher32': var.fletcher32,
+ 'shuffle': var.shuffle,
+ }
+ # Convert h5py-style compression options to NetCDF4-Python
+ # style, if possible
+ if var.compression == 'gzip':
+ encoding['zlib'] = True
+ encoding['complevel'] = var.compression_opts
+ elif var.compression is not None:
+ encoding['compression'] = var.compression
+ encoding['compression_opts'] = var.compression_opts
+
+ # save source so __repr__ can detect if it's local or not
+ encoding['source'] = self._filename
+ encoding['original_shape'] = var.shape
+
+ vlen_dtype = h5py.check_dtype(vlen=var.dtype)
+ if vlen_dtype is unicode_type:
+ encoding['dtype'] = str
+ elif vlen_dtype is not None: # pragma: no cover
+ # xarray doesn't support writing arbitrary vlen dtypes yet.
+ pass
+ else:
+ encoding['dtype'] = var.dtype
return Variable(dimensions, data, attrs, encoding)
def get_variables(self):
- with self.ensure_open(autoclose=False):
- return FrozenOrderedDict((k, self.open_store_variable(k, v))
- for k, v in iteritems(self.ds.variables))
+ return FrozenOrderedDict((k, self.open_store_variable(k, v))
+ for k, v in iteritems(self.ds.variables))
def get_attrs(self):
- with self.ensure_open(autoclose=True):
- return FrozenOrderedDict(_read_attributes(self.ds))
+ return FrozenOrderedDict(_read_attributes(self.ds))
def get_dimensions(self):
- with self.ensure_open(autoclose=True):
- return self.ds.dimensions
+ return self.ds.dimensions
def get_encoding(self):
- with self.ensure_open(autoclose=True):
- encoding = {}
- encoding['unlimited_dims'] = {
- k for k, v in self.ds.dimensions.items() if v is None}
+ encoding = {}
+ encoding['unlimited_dims'] = {
+ k for k, v in self.ds.dimensions.items() if v is None}
return encoding
def set_dimension(self, name, length, is_unlimited=False):
- with self.ensure_open(autoclose=False):
- if is_unlimited:
- self.ds.dimensions[name] = None
- self.ds.resize_dimension(name, length)
- else:
- self.ds.dimensions[name] = length
+ if is_unlimited:
+ self.ds.dimensions[name] = None
+ self.ds.resize_dimension(name, length)
+ else:
+ self.ds.dimensions[name] = length
def set_attribute(self, key, value):
- with self.ensure_open(autoclose=False):
- self.ds.attrs[key] = value
+ self.ds.attrs[key] = value
def encode_variable(self, variable):
return _encode_nc4_variable(variable)
@@ -226,18 +226,11 @@ def prepare_variable(self, name, variable, check_encoding=False,
return target, variable.data
- def sync(self, compute=True):
- if not compute:
- raise NotImplementedError(
- 'compute=False is not supported for the h5netcdf backend yet')
- with self.ensure_open(autoclose=True):
- super(H5NetCDFStore, self).sync(compute=compute)
- self.ds.sync()
-
- def close(self):
- if self._isopen:
- # netCDF4 only allows closing the root group
- ds = find_root(self.ds)
- if not ds._closed:
- ds.close()
- self._isopen = False
+ def sync(self):
+ self.ds.sync()
+ # if self.autoclose:
+ # self.close()
+ # super(H5NetCDFStore, self).sync(compute=compute)
+
+ def close(self, **kwargs):
+ self._manager.close(**kwargs)
diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py
new file mode 100644
index 00000000000..f633280ef1d
--- /dev/null
+++ b/xarray/backends/locks.py
@@ -0,0 +1,191 @@
+import multiprocessing
+import threading
+import weakref
+
+try:
+ from dask.utils import SerializableLock
+except ImportError:
+ # no need to worry about serializing the lock
+ SerializableLock = threading.Lock
+
+
+# Locks used by multiple backends.
+# Neither HDF5 nor the netCDF-C library are thread-safe.
+HDF5_LOCK = SerializableLock()
+NETCDFC_LOCK = SerializableLock()
+
+
+_FILE_LOCKS = weakref.WeakValueDictionary()
+
+
+def _get_threaded_lock(key):
+ try:
+ lock = _FILE_LOCKS[key]
+ except KeyError:
+ lock = _FILE_LOCKS[key] = threading.Lock()
+ return lock
+
+
+def _get_multiprocessing_lock(key):
+ # TODO: make use of the key -- maybe use locket.py?
+ # https://github.com/mwilliamson/locket.py
+ del key # unused
+ return multiprocessing.Lock()
+
+
+def _get_distributed_lock(key):
+ from dask.distributed import Lock
+ return Lock(key)
+
+
+_LOCK_MAKERS = {
+ None: _get_threaded_lock,
+ 'threaded': _get_threaded_lock,
+ 'multiprocessing': _get_multiprocessing_lock,
+ 'distributed': _get_distributed_lock,
+}
+
+
+def _get_lock_maker(scheduler=None):
+ """Returns an appropriate function for creating resource locks.
+
+ Parameters
+ ----------
+ scheduler : str or None
+ Dask scheduler being used.
+
+ See Also
+ --------
+ dask.utils.get_scheduler_lock
+ """
+ return _LOCK_MAKERS[scheduler]
+
+
+def _get_scheduler(get=None, collection=None):
+ """Determine the dask scheduler that is being used.
+
+ None is returned if no dask scheduler is active.
+
+ See also
+ --------
+ dask.base.get_scheduler
+ """
+ try:
+ # dask 0.18.1 and later
+ from dask.base import get_scheduler
+ actual_get = get_scheduler(get, collection)
+ except ImportError:
+ try:
+ from dask.utils import effective_get
+ actual_get = effective_get(get, collection)
+ except ImportError:
+ return None
+
+ try:
+ from dask.distributed import Client
+ if isinstance(actual_get.__self__, Client):
+ return 'distributed'
+ except (ImportError, AttributeError):
+ try:
+ import dask.multiprocessing
+ if actual_get == dask.multiprocessing.get:
+ return 'multiprocessing'
+ else:
+ return 'threaded'
+ except ImportError:
+ return 'threaded'
+
+
+def get_write_lock(key):
+ """Get a scheduler appropriate lock for writing to the given resource.
+
+ Parameters
+ ----------
+ key : str
+ Name of the resource for which to acquire a lock. Typically a filename.
+
+ Returns
+ -------
+ Lock object that can be used like a threading.Lock object.
+ """
+ scheduler = _get_scheduler()
+ lock_maker = _get_lock_maker(scheduler)
+ return lock_maker(key)
+
+
+class CombinedLock(object):
+ """A combination of multiple locks.
+
+ Like a locked door, a CombinedLock is locked if any of its constituent
+ locks are locked.
+ """
+
+ def __init__(self, locks):
+ self.locks = tuple(set(locks)) # remove duplicates
+
+ def acquire(self, *args):
+ return all(lock.acquire(*args) for lock in self.locks)
+
+ def release(self, *args):
+ for lock in self.locks:
+ lock.release(*args)
+
+ def __enter__(self):
+ for lock in self.locks:
+ lock.__enter__()
+
+ def __exit__(self, *args):
+ for lock in self.locks:
+ lock.__exit__(*args)
+
+ @property
+ def locked(self):
+ return any(lock.locked for lock in self.locks)
+
+ def __repr__(self):
+ return "CombinedLock(%r)" % list(self.locks)
+
+
+class DummyLock(object):
+ """DummyLock provides the lock API without any actual locking."""
+
+ def acquire(self, *args):
+ pass
+
+ def release(self, *args):
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
+
+ @property
+ def locked(self):
+ return False
+
+
+def combine_locks(locks):
+ """Combine a sequence of locks into a single lock."""
+ all_locks = []
+ for lock in locks:
+ if isinstance(lock, CombinedLock):
+ all_locks.extend(lock.locks)
+ elif lock is not None:
+ all_locks.append(lock)
+
+ num_locks = len(all_locks)
+ if num_locks > 1:
+ return CombinedLock(all_locks)
+ elif num_locks == 1:
+ return all_locks[0]
+ else:
+ return DummyLock()
+
+
+def ensure_lock(lock):
+ """Ensure that the given object is a lock."""
+ if lock is None or lock is False:
+ return DummyLock()
+ return lock
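
[Editor's sketch] A short sketch of how these helpers are combined in the backends above: ``get_write_lock`` picks a threading, multiprocessing or distributed lock to match the active dask scheduler, and ``combine_locks`` folds it together with the global HDF5 lock. The target file name is hypothetical:

    from xarray.backends.locks import HDF5_LOCK, combine_locks, get_write_lock

    write_lock = get_write_lock('example.nc')       # scheduler-appropriate lock
    lock = combine_locks([HDF5_LOCK, write_lock])   # CombinedLock over both
    with lock:
        pass  # code here holds both constituent locks
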
diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py
new file mode 100644
index 00000000000..321a1ca4da4
--- /dev/null
+++ b/xarray/backends/lru_cache.py
@@ -0,0 +1,91 @@
+import collections
+import threading
+
+from ..core.pycompat import move_to_end
+
+
+class LRUCache(collections.MutableMapping):
+ """Thread-safe LRUCache based on an OrderedDict.
+
+ All dict operations (__getitem__, __setitem__, __contains__) update the
+ priority of the relevant key and take O(1) time. The dict is iterated over
+ in order from the oldest to newest key, which means that a complete pass
+ over the dict should not affect the order of any entries.
+
+ When a new item is set and the maximum size of the cache is exceeded, the
+ oldest item is dropped and passed to ``on_evict(key, value)``.
+
+ The ``maxsize`` property can be used to view or adjust the capacity of
+ the cache, e.g., ``cache.maxsize = new_size``.
+ """
+ def __init__(self, maxsize, on_evict=None):
+ """
+ Parameters
+ ----------
+ maxsize : int
+ Integer maximum number of items to hold in the cache.
+ on_evict : callable, optional
+ Function to call like ``on_evict(key, value)`` when items are
+ evicted.
+ """
+ if not isinstance(maxsize, int):
+ raise TypeError('maxsize must be an integer')
+ if maxsize < 0:
+ raise ValueError('maxsize must be non-negative')
+ self._maxsize = maxsize
+ self._on_evict = on_evict
+ self._cache = collections.OrderedDict()
+ self._lock = threading.RLock()
+
+ def __getitem__(self, key):
+ # record recent use of the key by moving it to the front of the list
+ with self._lock:
+ value = self._cache[key]
+ move_to_end(self._cache, key)
+ return value
+
+ def _enforce_size_limit(self, capacity):
+ """Shrink the cache if necessary, evicting the oldest items."""
+ while len(self._cache) > capacity:
+ key, value = self._cache.popitem(last=False)
+ if self._on_evict is not None:
+ self._on_evict(key, value)
+
+ def __setitem__(self, key, value):
+ with self._lock:
+ if key in self._cache:
+ # insert the new value at the end
+ del self._cache[key]
+ self._cache[key] = value
+ elif self._maxsize:
+ # make room if necessary
+ self._enforce_size_limit(self._maxsize - 1)
+ self._cache[key] = value
+ elif self._on_evict is not None:
+ # not saving, immediately evict
+ self._on_evict(key, value)
+
+ def __delitem__(self, key):
+ del self._cache[key]
+
+ def __iter__(self):
+ # create a list, so accessing the cache during iteration cannot change
+ # the iteration order
+ return iter(list(self._cache))
+
+ def __len__(self):
+ return len(self._cache)
+
+ @property
+ def maxsize(self):
+ """Maximum number of items that can be held in the cache."""
+ return self._maxsize
+
+ @maxsize.setter
+ def maxsize(self, size):
+ """Resize the cache, evicting the oldest items if necessary."""
+ if size < 0:
+ raise ValueError('maxsize must be non-negative')
+ with self._lock:
+ self._enforce_size_limit(size)
+ self._maxsize = size
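
[Editor's sketch] A short usage sketch of the eviction behaviour described above; the file names are purely illustrative, and the ``on_evict`` callback mirrors how ``FILE_CACHE`` is configured in ``file_manager.py``:

    from xarray.backends.lru_cache import LRUCache

    cache = LRUCache(maxsize=2, on_evict=lambda key, value: value.close())
    cache['a'] = open('a.txt', 'w')
    cache['b'] = open('b.txt', 'w')
    _ = cache['a']                    # touching 'a' makes 'b' the oldest entry
    cache['c'] = open('c.txt', 'w')   # over maxsize: 'b' is evicted and closed
    assert list(cache) == ['a', 'c']
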
diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py
index dcf092557b8..195d4647534 100644
--- a/xarray/backends/memory.py
+++ b/xarray/backends/memory.py
@@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore):
This store exists purely for internal testing purposes.
"""
- def __init__(self, variables=None, attributes=None, writer=None):
+ def __init__(self, variables=None, attributes=None):
self._variables = OrderedDict() if variables is None else variables
self._attributes = OrderedDict() if attributes is None else attributes
- super(InMemoryDataStore, self).__init__(writer)
def get_attrs(self):
return self._attributes
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 5c6d82fd126..08ba085b77e 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -10,12 +10,13 @@
from .. import Variable, coding
from ..coding.variables import pop_to
from ..core import indexing
-from ..core.pycompat import (
- PY3, OrderedDict, basestring, iteritems, suppress)
+from ..core.pycompat import PY3, OrderedDict, basestring, iteritems, suppress
from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri
from .common import (
- HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore,
- find_root, robust_getitem)
+ BackendArray, WritableCFDataStore, find_root, robust_getitem)
+from .locks import (NETCDFC_LOCK, HDF5_LOCK,
+ combine_locks, ensure_lock, get_write_lock)
+from .file_manager import CachingFileManager, DummyFileManager
from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable
# This lookup table maps from dtype.byteorder to a readable endian
@@ -26,6 +27,9 @@
'|': 'native'}
+NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])
+
+
class BaseNetCDF4Array(BackendArray):
def __init__(self, variable_name, datastore):
self.datastore = datastore
@@ -43,12 +47,13 @@ def __init__(self, variable_name, datastore):
self.dtype = dtype
def __setitem__(self, key, value):
- with self.datastore.ensure_open(autoclose=True):
+ with self.datastore.lock:
data = self.get_array()
data[key] = value
+ if self.datastore.autoclose:
+ self.datastore.close(needs_lock=False)
def get_array(self):
- self.datastore.assert_open()
return self.datastore.ds.variables[self.variable_name]
@@ -64,20 +69,22 @@ def _getitem(self, key):
else:
getitem = operator.getitem
- with self.datastore.ensure_open(autoclose=True):
- try:
- array = getitem(self.get_array(), key)
- except IndexError:
- # Catch IndexError in netCDF4 and return a more informative
- # error message. This is most often called when an unsorted
- # indexer is used before the data is loaded from disk.
- msg = ('The indexing operation you are attempting to perform '
- 'is not valid on netCDF4.Variable object. Try loading '
- 'your data into memory first by calling .load().')
- if not PY3:
- import traceback
- msg += '\n\nOriginal traceback:\n' + traceback.format_exc()
- raise IndexError(msg)
+ original_array = self.get_array()
+
+ try:
+ with self.datastore.lock:
+ array = getitem(original_array, key)
+ except IndexError:
+ # Catch IndexError in netCDF4 and return a more informative
+ # error message. This is most often called when an unsorted
+ # indexer is used before the data is loaded from disk.
+ msg = ('The indexing operation you are attempting to perform '
+ 'is not valid on netCDF4.Variable object. Try loading '
+ 'your data into memory first by calling .load().')
+ if not PY3:
+ import traceback
+ msg += '\n\nOriginal traceback:\n' + traceback.format_exc()
+ raise IndexError(msg)
return array
@@ -224,7 +231,17 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False,
return encoding
-def _open_netcdf4_group(filename, mode, group=None, **kwargs):
+class GroupWrapper(object):
+ """Wrap netCDF4.Group objects so closing them closes the root group."""
+ def __init__(self, value):
+ self.value = value
+
+ def close(self):
+ # netCDF4 only allows closing the root group
+ find_root(self.value).close()
+
+
+def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs):
import netCDF4 as nc4
ds = nc4.Dataset(filename, mode=mode, **kwargs)
@@ -234,7 +251,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs):
_disable_auto_decode_group(ds)
- return ds
+ return GroupWrapper(ds)
def _disable_auto_decode_variable(var):
@@ -280,40 +297,33 @@ def _set_nc_attribute(obj, key, value):
obj.setncattr(key, value)
-class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin):
+class NetCDF4DataStore(WritableCFDataStore):
"""Store for reading and writing data via the Python-NetCDF4 library.
This store supports NetCDF3, NetCDF4 and OpenDAP datasets.
"""
- def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None,
- autoclose=False, lock=HDF5_LOCK):
-
- if autoclose and opener is None:
- raise ValueError('autoclose requires an opener')
+ def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False):
+ import netCDF4
- _disable_auto_decode_group(netcdf4_dataset)
+ if isinstance(manager, netCDF4.Dataset):
+ _disable_auto_decode_group(manager)
+ manager = DummyFileManager(GroupWrapper(manager))
- self._ds = netcdf4_dataset
- self._autoclose = autoclose
- self._isopen = True
+ self._manager = manager
self.format = self.ds.data_model
self._filename = self.ds.filepath()
self.is_remote = is_remote_uri(self._filename)
- self._mode = mode = 'a' if mode == 'w' else mode
- if opener:
- self._opener = functools.partial(opener, mode=self._mode)
- else:
- self._opener = opener
- super(NetCDF4DataStore, self).__init__(writer, lock=lock)
+ self.lock = ensure_lock(lock)
+ self.autoclose = autoclose
@classmethod
def open(cls, filename, mode='r', format='NETCDF4', group=None,
- writer=None, clobber=True, diskless=False, persist=False,
- autoclose=False, lock=HDF5_LOCK):
- import netCDF4 as nc4
+ clobber=True, diskless=False, persist=False,
+ lock=None, lock_maker=None, autoclose=False):
+ import netCDF4
if (len(filename) == 88 and
- LooseVersion(nc4.__version__) < "1.3.1"):
+ LooseVersion(netCDF4.__version__) < "1.3.1"):
warnings.warn(
'A segmentation fault may occur when the '
'file path has exactly 88 characters as it does '
@@ -324,86 +334,91 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None,
'https://github.com/pydata/xarray/issues/1745')
if format is None:
format = 'NETCDF4'
- opener = functools.partial(_open_netcdf4_group, filename, mode=mode,
- group=group, clobber=clobber,
- diskless=diskless, persist=persist,
- format=format)
- ds = opener()
- return cls(ds, mode=mode, writer=writer, opener=opener,
- autoclose=autoclose, lock=lock)
- def open_store_variable(self, name, var):
- with self.ensure_open(autoclose=False):
- dimensions = var.dimensions
- data = indexing.LazilyOuterIndexedArray(
- NetCDF4ArrayWrapper(name, self))
- attributes = OrderedDict((k, var.getncattr(k))
- for k in var.ncattrs())
- _ensure_fill_value_valid(data, attributes)
- # netCDF4 specific encoding; save _FillValue for later
- encoding = {}
- filters = var.filters()
- if filters is not None:
- encoding.update(filters)
- chunking = var.chunking()
- if chunking is not None:
- if chunking == 'contiguous':
- encoding['contiguous'] = True
- encoding['chunksizes'] = None
+ if lock is None:
+ if mode == 'r':
+ if is_remote_uri(filename):
+ lock = NETCDFC_LOCK
+ else:
+ lock = NETCDF4_PYTHON_LOCK
+ else:
+ if format is None or format.startswith('NETCDF4'):
+ base_lock = NETCDF4_PYTHON_LOCK
else:
- encoding['contiguous'] = False
- encoding['chunksizes'] = tuple(chunking)
- # TODO: figure out how to round-trip "endian-ness" without raising
- # warnings from netCDF4
- # encoding['endian'] = var.endian()
- pop_to(attributes, encoding, 'least_significant_digit')
- # save source so __repr__ can detect if it's local or not
- encoding['source'] = self._filename
- encoding['original_shape'] = var.shape
- encoding['dtype'] = var.dtype
+ base_lock = NETCDFC_LOCK
+ lock = combine_locks([base_lock, get_write_lock(filename)])
+
+ manager = CachingFileManager(
+ _open_netcdf4_group, filename, lock, mode=mode,
+ kwargs=dict(group=group, clobber=clobber, diskless=diskless,
+ persist=persist, format=format))
+ return cls(manager, lock=lock, autoclose=autoclose)
+
+ @property
+ def ds(self):
+ return self._manager.acquire().value
+
+ def open_store_variable(self, name, var):
+ dimensions = var.dimensions
+ data = indexing.LazilyOuterIndexedArray(
+ NetCDF4ArrayWrapper(name, self))
+ attributes = OrderedDict((k, var.getncattr(k))
+ for k in var.ncattrs())
+ _ensure_fill_value_valid(data, attributes)
+ # netCDF4 specific encoding; save _FillValue for later
+ encoding = {}
+ filters = var.filters()
+ if filters is not None:
+ encoding.update(filters)
+ chunking = var.chunking()
+ if chunking is not None:
+ if chunking == 'contiguous':
+ encoding['contiguous'] = True
+ encoding['chunksizes'] = None
+ else:
+ encoding['contiguous'] = False
+ encoding['chunksizes'] = tuple(chunking)
+ # TODO: figure out how to round-trip "endian-ness" without raising
+ # warnings from netCDF4
+ # encoding['endian'] = var.endian()
+ pop_to(attributes, encoding, 'least_significant_digit')
+ # save source so __repr__ can detect if it's local or not
+ encoding['source'] = self._filename
+ encoding['original_shape'] = var.shape
+ encoding['dtype'] = var.dtype
return Variable(dimensions, data, attributes, encoding)
def get_variables(self):
- with self.ensure_open(autoclose=False):
- dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v))
- for k, v in
- iteritems(self.ds.variables))
+ dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v))
+ for k, v in
+ iteritems(self.ds.variables))
return dsvars
def get_attrs(self):
- with self.ensure_open(autoclose=True):
- attrs = FrozenOrderedDict((k, self.ds.getncattr(k))
- for k in self.ds.ncattrs())
+ attrs = FrozenOrderedDict((k, self.ds.getncattr(k))
+ for k in self.ds.ncattrs())
return attrs
def get_dimensions(self):
- with self.ensure_open(autoclose=True):
- dims = FrozenOrderedDict((k, len(v))
- for k, v in iteritems(self.ds.dimensions))
+ dims = FrozenOrderedDict((k, len(v))
+ for k, v in iteritems(self.ds.dimensions))
return dims
def get_encoding(self):
- with self.ensure_open(autoclose=True):
- encoding = {}
- encoding['unlimited_dims'] = {
- k for k, v in self.ds.dimensions.items() if v.isunlimited()}
+ encoding = {}
+ encoding['unlimited_dims'] = {
+ k for k, v in self.ds.dimensions.items() if v.isunlimited()}
return encoding
def set_dimension(self, name, length, is_unlimited=False):
- with self.ensure_open(autoclose=False):
- dim_length = length if not is_unlimited else None
- self.ds.createDimension(name, size=dim_length)
+ dim_length = length if not is_unlimited else None
+ self.ds.createDimension(name, size=dim_length)
def set_attribute(self, key, value):
- with self.ensure_open(autoclose=False):
- if self.format != 'NETCDF4':
- value = encode_nc3_attr_value(value)
- _set_nc_attribute(self.ds, key, value)
-
- def set_variables(self, *args, **kwargs):
- with self.ensure_open(autoclose=False):
- super(NetCDF4DataStore, self).set_variables(*args, **kwargs)
+ if self.format != 'NETCDF4':
+ value = encode_nc3_attr_value(value)
+ _set_nc_attribute(self.ds, key, value)
def encode_variable(self, variable):
variable = _force_native_endianness(variable)
@@ -461,15 +476,8 @@ def prepare_variable(self, name, variable, check_encoding=False,
return target, variable.data
- def sync(self, compute=True):
- with self.ensure_open(autoclose=True):
- super(NetCDF4DataStore, self).sync(compute=compute)
- self.ds.sync()
+ def sync(self):
+ self.ds.sync()
- def close(self):
- if self._isopen:
- # netCDF4 only allows closing the root group
- ds = find_root(self.ds)
- if ds._isopen:
- ds.close()
- self._isopen = False
+ def close(self, **kwargs):
+ self._manager.close(**kwargs)
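
[Editor's sketch] The ``DummyFileManager`` branch in the constructor above keeps the older pattern of wrapping an already-open ``netCDF4.Dataset`` working. A minimal sketch, assuming an ``example.nc`` file that exists only for illustration:

    import netCDF4
    import xarray as xr
    from xarray.backends.netCDF4_ import NetCDF4DataStore

    nc = netCDF4.Dataset('example.nc', mode='r')
    store = NetCDF4DataStore(nc)   # wrapped in GroupWrapper + DummyFileManager
    ds = xr.open_dataset(store)    # xarray reads through the existing handle
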
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index d946c6fa927..e4691d1f7e1 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -1,17 +1,18 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import functools
+from __future__ import absolute_import, division, print_function
import numpy as np
from .. import Variable
-from ..core.pycompat import OrderedDict
-from ..core.utils import (FrozenOrderedDict, Frozen)
from ..core import indexing
+from ..core.pycompat import OrderedDict
+from ..core.utils import Frozen
+from .common import AbstractDataStore, BackendArray
+from .file_manager import CachingFileManager
+from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock
+
-from .common import AbstractDataStore, DataStorePickleMixin, BackendArray
+# PseudoNetCDF can invoke netCDF libraries internally
+PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK])
class PncArrayWrapper(BackendArray):
@@ -24,7 +25,6 @@ def __init__(self, variable_name, datastore):
self.dtype = np.dtype(array.dtype)
def get_array(self):
- self.datastore.assert_open()
return self.datastore.ds.variables[self.variable_name]
def __getitem__(self, key):
@@ -33,57 +33,55 @@ def __getitem__(self, key):
self._getitem)
def _getitem(self, key):
- with self.datastore.ensure_open(autoclose=True):
- return self.get_array()[key]
+ array = self.get_array()
+ with self.datastore.lock:
+ return array[key]
-class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin):
+class PseudoNetCDFDataStore(AbstractDataStore):
"""Store for accessing datasets via PseudoNetCDF
"""
@classmethod
- def open(cls, filename, format=None, writer=None,
- autoclose=False, **format_kwds):
+ def open(cls, filename, lock=None, **format_kwds):
from PseudoNetCDF import pncopen
- opener = functools.partial(pncopen, filename, **format_kwds)
- ds = opener()
- mode = format_kwds.get('mode', 'r')
- return cls(ds, mode=mode, writer=writer, opener=opener,
- autoclose=autoclose)
- def __init__(self, pnc_dataset, mode='r', writer=None, opener=None,
- autoclose=False):
+ keywords = dict(kwargs=format_kwds)
+ # only include mode if explicitly passed
+ mode = format_kwds.pop('mode', None)
+ if mode is not None:
+ keywords['mode'] = mode
+
+ if lock is None:
+ lock = PNETCDF_LOCK
+
+ manager = CachingFileManager(pncopen, filename, lock=lock, **keywords)
+ return cls(manager, lock)
- if autoclose and opener is None:
- raise ValueError('autoclose requires an opener')
+ def __init__(self, manager, lock=None):
+ self._manager = manager
+ self.lock = ensure_lock(lock)
- self._ds = pnc_dataset
- self._autoclose = autoclose
- self._isopen = True
- self._opener = opener
- self._mode = mode
- super(PseudoNetCDFDataStore, self).__init__()
+ @property
+ def ds(self):
+ return self._manager.acquire()
def open_store_variable(self, name, var):
- with self.ensure_open(autoclose=False):
- data = indexing.LazilyOuterIndexedArray(
- PncArrayWrapper(name, self)
- )
+ data = indexing.LazilyOuterIndexedArray(
+ PncArrayWrapper(name, self)
+ )
attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs())
return Variable(var.dimensions, data, attrs)
def get_variables(self):
- with self.ensure_open(autoclose=False):
- return FrozenOrderedDict((k, self.open_store_variable(k, v))
- for k, v in self.ds.variables.items())
+ return ((k, self.open_store_variable(k, v))
+ for k, v in self.ds.variables.items())
def get_attrs(self):
- with self.ensure_open(autoclose=True):
- return Frozen(dict([(k, getattr(self.ds, k))
- for k in self.ds.ncattrs()]))
+ return Frozen(dict([(k, getattr(self.ds, k))
+ for k in self.ds.ncattrs()]))
def get_dimensions(self):
- with self.ensure_open(autoclose=True):
- return Frozen(self.ds.dimensions)
+ return Frozen(self.ds.dimensions)
def get_encoding(self):
encoding = {}
@@ -93,6 +91,4 @@ def get_encoding(self):
return encoding
def close(self):
- if self._isopen:
- self.ds.close()
- self._isopen = False
+ self._manager.close()
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 98b76928597..574fff744e3 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -1,13 +1,20 @@
from __future__ import absolute_import, division, print_function
-import functools
-
import numpy as np
from .. import Variable
from ..core import indexing
from ..core.utils import Frozen, FrozenOrderedDict
-from .common import AbstractDataStore, BackendArray, DataStorePickleMixin
+from .common import AbstractDataStore, BackendArray
+from .file_manager import CachingFileManager
+from .locks import (
+ HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock)
+
+
+# PyNIO can invoke netCDF libraries internally
+# Add a dedicated lock in case NCL itself is also not thread-safe.
+NCL_LOCK = SerializableLock()
+PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK])
class NioArrayWrapper(BackendArray):
@@ -20,7 +27,6 @@ def __init__(self, variable_name, datastore):
self.dtype = np.dtype(array.typecode())
def get_array(self):
- self.datastore.assert_open()
return self.datastore.ds.variables[self.variable_name]
def __getitem__(self, key):
@@ -28,46 +34,45 @@ def __getitem__(self, key):
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem)
def _getitem(self, key):
- with self.datastore.ensure_open(autoclose=True):
- array = self.get_array()
+ array = self.get_array()
+ with self.datastore.lock:
if key == () and self.ndim == 0:
return array.get_value()
-
return array[key]
-class NioDataStore(AbstractDataStore, DataStorePickleMixin):
+class NioDataStore(AbstractDataStore):
"""Store for accessing datasets via PyNIO
"""
- def __init__(self, filename, mode='r', autoclose=False):
+ def __init__(self, filename, mode='r', lock=None):
import Nio
- opener = functools.partial(Nio.open_file, filename, mode=mode)
- self._ds = opener()
- self._autoclose = autoclose
- self._isopen = True
- self._opener = opener
- self._mode = mode
+ if lock is None:
+ lock = PYNIO_LOCK
+ self.lock = ensure_lock(lock)
+ self._manager = CachingFileManager(
+ Nio.open_file, filename, lock=lock, mode=mode)
# xarray provides its own support for FillValue,
# so turn off PyNIO's support for the same.
self.ds.set_option('MaskedArrayMode', 'MaskedNever')
+ @property
+ def ds(self):
+ return self._manager.acquire()
+
def open_store_variable(self, name, var):
data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self))
return Variable(var.dimensions, data, var.attributes)
def get_variables(self):
- with self.ensure_open(autoclose=False):
- return FrozenOrderedDict((k, self.open_store_variable(k, v))
- for k, v in self.ds.variables.items())
+ return FrozenOrderedDict((k, self.open_store_variable(k, v))
+ for k, v in self.ds.variables.items())
def get_attrs(self):
- with self.ensure_open(autoclose=True):
- return Frozen(self.ds.attributes)
+ return Frozen(self.ds.attributes)
def get_dimensions(self):
- with self.ensure_open(autoclose=True):
- return Frozen(self.ds.dimensions)
+ return Frozen(self.ds.dimensions)
def get_encoding(self):
encoding = {}
@@ -76,6 +81,4 @@ def get_encoding(self):
return encoding
def close(self):
- if self._isopen:
- self.ds.close()
- self._isopen = False
+ self._manager.close()
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index 5221cf0e913..5746b4e748d 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -1,21 +1,20 @@
import os
+import warnings
from collections import OrderedDict
from distutils.version import LooseVersion
-import warnings
import numpy as np
from .. import DataArray
from ..core import indexing
from ..core.utils import is_scalar
-from .common import BackendArray, PickleByReconstructionWrapper
+from .common import BackendArray
+from .file_manager import CachingFileManager
+from .locks import SerializableLock
-try:
- from dask.utils import SerializableLock as Lock
-except ImportError:
- from threading import Lock
-RASTERIO_LOCK = Lock()
+# TODO: should this be GDAL_LOCK instead?
+RASTERIO_LOCK = SerializableLock()
_ERROR_MSG = ('The kind of indexing operation you are trying to do is not '
'valid on rasterio files. Try to load your data with ds.load()'
@@ -25,18 +24,22 @@
class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""
- def __init__(self, riods):
- self.riods = riods
- self._shape = (riods.value.count, riods.value.height,
- riods.value.width)
- self._ndims = len(self.shape)
+ def __init__(self, manager):
+ self.manager = manager
- @property
- def dtype(self):
- dtypes = self.riods.value.dtypes
+ # cannot save riods as an attribute: this would break pickleability
+ riods = manager.acquire()
+
+ self._shape = (riods.count, riods.height, riods.width)
+
+ dtypes = riods.dtypes
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError('All bands should have the same dtype')
- return np.dtype(dtypes[0])
+ self._dtype = np.dtype(dtypes[0])
+
+ @property
+ def dtype(self):
+ return self._dtype
@property
def shape(self):
@@ -95,7 +98,7 @@ def _get_indexer(self, key):
if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray):
# do outer-style indexing
- np_inds[1:] = np.ix_(*np_inds[1:])
+ np_inds[-2:] = np.ix_(*np_inds[-2:])
return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds)
@@ -108,7 +111,8 @@ def _getitem(self, key):
stop - start for (start, stop) in window)
out = np.zeros(shape, dtype=self.dtype)
else:
- out = self.riods.value.read(band_key, window=window)
+ riods = self.manager.acquire()
+ out = riods.read(band_key, window=window)
if squeeze_axis:
out = np.squeeze(out, axis=squeeze_axis)
@@ -203,7 +207,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
import rasterio
- riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r')
+ manager = CachingFileManager(rasterio.open, filename, mode='r')
+ riods = manager.acquire()
if cache is None:
cache = chunks is None
@@ -211,20 +216,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
coords = OrderedDict()
# Get bands
- if riods.value.count < 1:
+ if riods.count < 1:
raise ValueError('Unknown dims')
- coords['band'] = np.asarray(riods.value.indexes)
+ coords['band'] = np.asarray(riods.indexes)
# Get coordinates
if LooseVersion(rasterio.__version__) < '1.0':
- transform = riods.value.affine
+ transform = riods.affine
else:
- transform = riods.value.transform
+ transform = riods.transform
if transform.is_rectilinear:
# 1d coordinates
parse = True if parse_coordinates is None else parse_coordinates
if parse:
- nx, ny = riods.value.width, riods.value.height
+ nx, ny = riods.width, riods.height
# xarray coordinates are pixel centered
x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform
_, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform
@@ -234,57 +239,60 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
# 2d coordinates
parse = False if (parse_coordinates is None) else parse_coordinates
if parse:
- warnings.warn("The file coordinates' transformation isn't "
- "rectilinear: xarray won't parse the coordinates "
- "in this case. Set `parse_coordinates=False` to "
- "suppress this warning.",
- RuntimeWarning, stacklevel=3)
+ warnings.warn(
+ "The file coordinates' transformation isn't "
+ "rectilinear: xarray won't parse the coordinates "
+ "in this case. Set `parse_coordinates=False` to "
+ "suppress this warning.",
+ RuntimeWarning, stacklevel=3)
# Attributes
attrs = dict()
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
- # always (0, 0, 1) per definition (see https://github.com/sgillies/affine)
+ # always (0, 0, 1) per definition (see
+ # https://github.com/sgillies/affine)
attrs['transform'] = tuple(transform)[:6]
- if hasattr(riods.value, 'crs') and riods.value.crs:
+ if hasattr(riods, 'crs') and riods.crs:
# CRS is a dict-like object specific to rasterio
# If CRS is not None, we convert it back to a PROJ4 string using
# rasterio itself
- attrs['crs'] = riods.value.crs.to_string()
- if hasattr(riods.value, 'res'):
+ attrs['crs'] = riods.crs.to_string()
+ if hasattr(riods, 'res'):
# (width, height) tuple of pixels in units of CRS
- attrs['res'] = riods.value.res
- if hasattr(riods.value, 'is_tiled'):
+ attrs['res'] = riods.res
+ if hasattr(riods, 'is_tiled'):
# Is the TIF tiled? (bool)
# We cast it to an int for netCDF compatibility
- attrs['is_tiled'] = np.uint8(riods.value.is_tiled)
- if hasattr(riods.value, 'nodatavals'):
+ attrs['is_tiled'] = np.uint8(riods.is_tiled)
+ if hasattr(riods, 'nodatavals'):
# The nodata values for the raster bands
- attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval
- for nodataval in riods.value.nodatavals])
+ attrs['nodatavals'] = tuple(
+ np.nan if nodataval is None else nodataval
+ for nodataval in riods.nodatavals)
# Parse extra metadata from tags, if supported
parsers = {'ENVI': _parse_envi}
- driver = riods.value.driver
+ driver = riods.driver
if driver in parsers:
- meta = parsers[driver](riods.value.tags(ns=driver))
+ meta = parsers[driver](riods.tags(ns=driver))
for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if (isinstance(v, (list, np.ndarray)) and
- len(v) == riods.value.count):
+ len(v) == riods.count):
coords[k] = ('band', np.asarray(v))
else:
attrs[k] = v
- data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods))
+ data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager))
# this lets you write arrays loaded with rasterio
data = indexing.CopyOnWriteArray(data)
- if cache and (chunks is None):
+ if cache and chunks is None:
data = indexing.MemoryCachedArray(data)
result = DataArray(data=data, dims=('band', 'y', 'x'),
@@ -306,6 +314,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
lock=lock)
# Make the file closeable
- result._file_obj = riods
+ result._file_obj = manager
return result
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index cd84431f6b7..b009342efb6 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -1,6 +1,5 @@
from __future__ import absolute_import, division, print_function
-import functools
import warnings
from distutils.version import LooseVersion
from io import BytesIO
@@ -11,7 +10,9 @@
from ..core.indexing import NumpyIndexingAdapter
from ..core.pycompat import OrderedDict, basestring, iteritems
from ..core.utils import Frozen, FrozenOrderedDict
-from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore
+from .common import BackendArray, WritableCFDataStore
+from .locks import get_write_lock
+from .file_manager import CachingFileManager, DummyFileManager
from .netcdf3 import (
encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name)
@@ -40,31 +41,26 @@ def __init__(self, variable_name, datastore):
str(array.dtype.itemsize))
def get_array(self):
- self.datastore.assert_open()
return self.datastore.ds.variables[self.variable_name].data
def __getitem__(self, key):
- with self.datastore.ensure_open(autoclose=True):
- data = NumpyIndexingAdapter(self.get_array())[key]
- # Copy data if the source file is mmapped.
- # This makes things consistent
- # with the netCDF4 library by ensuring
- # we can safely read arrays even
- # after closing associated files.
- copy = self.datastore.ds.use_mmap
- return np.array(data, dtype=self.dtype, copy=copy)
+ data = NumpyIndexingAdapter(self.get_array())[key]
+ # Copy data if the source file is mmapped. This makes things consistent
+ # with the netCDF4 library by ensuring we can safely read arrays even
+ # after closing associated files.
+ copy = self.datastore.ds.use_mmap
+ return np.array(data, dtype=self.dtype, copy=copy)
def __setitem__(self, key, value):
- with self.datastore.ensure_open(autoclose=True):
- data = self.datastore.ds.variables[self.variable_name]
- try:
- data[key] = value
- except TypeError:
- if key is Ellipsis:
- # workaround for GH: scipy/scipy#6880
- data[:] = value
- else:
- raise
+ data = self.datastore.ds.variables[self.variable_name]
+ try:
+ data[key] = value
+ except TypeError:
+ if key is Ellipsis:
+ # workaround for GH: scipy/scipy#6880
+ data[:] = value
+ else:
+ raise
def _open_scipy_netcdf(filename, mode, mmap, version):
@@ -106,7 +102,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version):
raise
-class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin):
+class ScipyDataStore(WritableCFDataStore):
"""Store for reading and writing data via scipy.io.netcdf.
This store has the advantage of being able to be initialized with a
@@ -116,7 +112,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin):
"""
def __init__(self, filename_or_obj, mode='r', format=None, group=None,
- writer=None, mmap=None, autoclose=False, lock=None):
+ mmap=None, lock=None):
import scipy
import scipy.io
@@ -140,34 +136,38 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None,
raise ValueError('invalid format for scipy.io.netcdf backend: %r'
% format)
- opener = functools.partial(_open_scipy_netcdf,
- filename=filename_or_obj,
- mode=mode, mmap=mmap, version=version)
- self._ds = opener()
- self._autoclose = autoclose
- self._isopen = True
- self._opener = opener
- self._mode = mode
+ if (lock is None and mode != 'r' and
+ isinstance(filename_or_obj, basestring)):
+ lock = get_write_lock(filename_or_obj)
+
+ if isinstance(filename_or_obj, basestring):
+ manager = CachingFileManager(
+ _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock,
+ kwargs=dict(mmap=mmap, version=version))
+ else:
+ scipy_dataset = _open_scipy_netcdf(
+ filename_or_obj, mode=mode, mmap=mmap, version=version)
+ manager = DummyFileManager(scipy_dataset)
+
+ self._manager = manager
- super(ScipyDataStore, self).__init__(writer, lock=lock)
+ @property
+ def ds(self):
+ return self._manager.acquire()
def open_store_variable(self, name, var):
- with self.ensure_open(autoclose=False):
- return Variable(var.dimensions, ScipyArrayWrapper(name, self),
- _decode_attrs(var._attributes))
+ return Variable(var.dimensions, ScipyArrayWrapper(name, self),
+ _decode_attrs(var._attributes))
def get_variables(self):
- with self.ensure_open(autoclose=False):
- return FrozenOrderedDict((k, self.open_store_variable(k, v))
- for k, v in iteritems(self.ds.variables))
+ return FrozenOrderedDict((k, self.open_store_variable(k, v))
+ for k, v in iteritems(self.ds.variables))
def get_attrs(self):
- with self.ensure_open(autoclose=True):
- return Frozen(_decode_attrs(self.ds._attributes))
+ return Frozen(_decode_attrs(self.ds._attributes))
def get_dimensions(self):
- with self.ensure_open(autoclose=True):
- return Frozen(self.ds.dimensions)
+ return Frozen(self.ds.dimensions)
def get_encoding(self):
encoding = {}
@@ -176,22 +176,20 @@ def get_encoding(self):
return encoding
def set_dimension(self, name, length, is_unlimited=False):
- with self.ensure_open(autoclose=False):
- if name in self.ds.dimensions:
- raise ValueError('%s does not support modifying dimensions'
- % type(self).__name__)
- dim_length = length if not is_unlimited else None
- self.ds.createDimension(name, dim_length)
+ if name in self.ds.dimensions:
+ raise ValueError('%s does not support modifying dimensions'
+ % type(self).__name__)
+ dim_length = length if not is_unlimited else None
+ self.ds.createDimension(name, dim_length)
def _validate_attr_key(self, key):
if not is_valid_nc3_name(key):
raise ValueError("Not a valid attribute name")
def set_attribute(self, key, value):
- with self.ensure_open(autoclose=False):
- self._validate_attr_key(key)
- value = encode_nc3_attr_value(value)
- setattr(self.ds, key, value)
+ self._validate_attr_key(key)
+ value = encode_nc3_attr_value(value)
+ setattr(self.ds, key, value)
def encode_variable(self, variable):
variable = encode_nc3_variable(variable)
@@ -219,27 +217,8 @@ def prepare_variable(self, name, variable, check_encoding=False,
return target, data
- def sync(self, compute=True):
- if not compute:
- raise NotImplementedError(
- 'compute=False is not supported for the scipy backend yet')
- with self.ensure_open(autoclose=True):
- super(ScipyDataStore, self).sync(compute=compute)
- self.ds.flush()
+ def sync(self):
+ self.ds.sync()
def close(self):
- self.ds.close()
- self._isopen = False
-
- def __exit__(self, type, value, tb):
- self.close()
-
- def __setstate__(self, state):
- filename = state['_opener'].keywords['filename']
- if hasattr(filename, 'seek'):
- # it's a file-like object
- # seek to the start of the file so scipy can read it
- filename.seek(0)
- super(ScipyDataStore, self).__setstate__(state)
- self._ds = None
- self._isopen = False
+ self._manager.close()
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 47b90c8a617..5f19c826289 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -217,8 +217,7 @@ class ZarrStore(AbstractWritableDataStore):
"""
@classmethod
- def open_group(cls, store, mode='r', synchronizer=None, group=None,
- writer=None):
+ def open_group(cls, store, mode='r', synchronizer=None, group=None):
import zarr
min_zarr = '2.2'
@@ -230,24 +229,14 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None,
"#installation" % min_zarr)
zarr_group = zarr.open_group(store=store, mode=mode,
synchronizer=synchronizer, path=group)
- return cls(zarr_group, writer=writer)
+ return cls(zarr_group)
- def __init__(self, zarr_group, writer=None):
+ def __init__(self, zarr_group):
self.ds = zarr_group
self._read_only = self.ds.read_only
self._synchronizer = self.ds.synchronizer
self._group = self.ds.path
- if writer is None:
- # by default, we should not need a lock for writing zarr because
- # we do not (yet) allow overlapping chunks during write
- zarr_writer = ArrayWriter(lock=False)
- else:
- zarr_writer = writer
-
- # do we need to define attributes for all of the opener keyword args?
- super(ZarrStore, self).__init__(zarr_writer)
-
def open_store_variable(self, name, zarr_array):
data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self))
dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array,
@@ -334,8 +323,8 @@ def store(self, variables, attributes, *args, **kwargs):
AbstractWritableDataStore.store(self, variables, attributes,
*args, **kwargs)
- def sync(self, compute=True):
- self.delayed_store = self.writer.sync(compute=compute)
+ def sync(self):
+ pass
def open_zarr(store, group=None, synchronizer=None, auto_chunk=True,
diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py
new file mode 100644
index 00000000000..83e8c7a7e4b
--- /dev/null
+++ b/xarray/coding/cftime_offsets.py
@@ -0,0 +1,735 @@
+"""Time offset classes for use with cftime.datetime objects"""
+# The offset classes and mechanisms for generating time ranges defined in
+# this module were copied/adapted from those defined in pandas. See in
+# particular the objects and methods defined in pandas.tseries.offsets
+# and pandas.core.indexes.datetimes.
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import re
+from datetime import timedelta
+from functools import partial
+
+import numpy as np
+
+from ..core.pycompat import basestring
+from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
+from .times import format_cftime_datetime
+
+
+def get_date_type(calendar):
+ """Return the cftime date type for a given calendar name."""
+ try:
+ import cftime
+ except ImportError:
+ raise ImportError(
+ 'cftime is required for dates with non-standard calendars')
+ else:
+ calendars = {
+ 'noleap': cftime.DatetimeNoLeap,
+ '360_day': cftime.Datetime360Day,
+ '365_day': cftime.DatetimeNoLeap,
+ '366_day': cftime.DatetimeAllLeap,
+ 'gregorian': cftime.DatetimeGregorian,
+ 'proleptic_gregorian': cftime.DatetimeProlepticGregorian,
+ 'julian': cftime.DatetimeJulian,
+ 'all_leap': cftime.DatetimeAllLeap,
+ 'standard': cftime.DatetimeProlepticGregorian
+ }
+ return calendars[calendar]
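A minimal usage sketch of the calendar lookup above (not part of the diff; assumes cftime is installed):

from xarray.coding.cftime_offsets import get_date_type

# 'noleap' resolves to cftime.DatetimeNoLeap, which is then called like a
# datetime constructor.
date_type = get_date_type('noleap')
print(date_type(2000, 2, 28))  # 2000-02-28 00:00:00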
+
+
+class BaseCFTimeOffset(object):
+ _freq = None
+
+ def __init__(self, n=1):
+ if not isinstance(n, int):
+ raise TypeError(
+ "The provided multiple 'n' must be an integer. "
+ "Instead a value of type {!r} was provided.".format(type(n)))
+ self.n = n
+
+ def rule_code(self):
+ return self._freq
+
+ def __eq__(self, other):
+ return self.n == other.n and self.rule_code() == other.rule_code()
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __add__(self, other):
+ return self.__apply__(other)
+
+ def __sub__(self, other):
+ import cftime
+
+ if isinstance(other, cftime.datetime):
+ raise TypeError('Cannot subtract a cftime.datetime '
+ 'from a time offset.')
+ elif type(other) == type(self):
+ return type(self)(self.n - other.n)
+ else:
+ return NotImplemented
+
+ def __mul__(self, other):
+ return type(self)(n=other * self.n)
+
+ def __neg__(self):
+ return self * -1
+
+ def __rmul__(self, other):
+ return self.__mul__(other)
+
+ def __radd__(self, other):
+ return self.__add__(other)
+
+ def __rsub__(self, other):
+ if isinstance(other, BaseCFTimeOffset) and type(self) != type(other):
+ raise TypeError('Cannot subtract cftime offsets of differing '
+ 'types')
+ return -self + other
+
+ def __apply__(self):
+ return NotImplemented
+
+ def onOffset(self, date):
+ """Check if the given date is in the set of possible dates created
+ using a length-one version of this offset class."""
+ test_date = (self + date) - self
+ return date == test_date
+
+ def rollforward(self, date):
+ if self.onOffset(date):
+ return date
+ else:
+ return date + type(self)()
+
+ def rollback(self, date):
+ if self.onOffset(date):
+ return date
+ else:
+ return date - type(self)()
+
+ def __str__(self):
+ return '<{}: n={}>'.format(type(self).__name__, self.n)
+
+ def __repr__(self):
+ return str(self)
+
+
+def _days_in_month(date):
+ """The number of days in the month of the given date"""
+ if date.month == 12:
+ reference = type(date)(date.year + 1, 1, 1)
+ else:
+ reference = type(date)(date.year, date.month + 1, 1)
+ return (reference - timedelta(days=1)).day
+
+
+def _adjust_n_months(other_day, n, reference_day):
+ """Adjust the number of times a monthly offset is applied based
+ on the day of a given date, and the reference day provided.
+ """
+ if n > 0 and other_day < reference_day:
+ n = n - 1
+ elif n <= 0 and other_day > reference_day:
+ n = n + 1
+ return n
+
+
+def _adjust_n_years(other, n, month, reference_day):
+ """Adjust the number of times an annual offset is applied based on
+ another date, and the reference day provided"""
+ if n > 0:
+ if other.month < month or (other.month == month and
+ other.day < reference_day):
+ n -= 1
+ else:
+ if other.month > month or (other.month == month and
+ other.day > reference_day):
+ n += 1
+ return n
+
+
+def _shift_months(date, months, day_option='start'):
+ """Shift the date to a month start or end a given number of months away.
+ """
+ delta_year = (date.month + months) // 12
+ month = (date.month + months) % 12
+
+ if month == 0:
+ month = 12
+ delta_year = delta_year - 1
+ year = date.year + delta_year
+
+ if day_option == 'start':
+ day = 1
+ elif day_option == 'end':
+ reference = type(date)(year, month, 1)
+ day = _days_in_month(reference)
+ else:
+ raise ValueError(day_option)
+ return date.replace(year=year, month=month, day=day)
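A quick sketch of the private month-shifting helper above (not part of the diff; assumes cftime is installed):

import cftime
from xarray.coding.cftime_offsets import _shift_months

date = cftime.DatetimeNoLeap(2000, 11, 15)
# Shift three month boundaries forward, wrapping into the next year.
print(_shift_months(date, 3, day_option='start'))  # 2001-02-01 00:00:00
# A shift of zero with day_option='end' snaps to the end of the current month.
print(_shift_months(date, 0, day_option='end'))    # 2000-11-30 00:00:00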
+
+
+class MonthBegin(BaseCFTimeOffset):
+ _freq = 'MS'
+
+ def __apply__(self, other):
+ n = _adjust_n_months(other.day, self.n, 1)
+ return _shift_months(other, n, 'start')
+
+ def onOffset(self, date):
+ """Check if the given date is in the set of possible dates created
+ using a length-one version of this offset class."""
+ return date.day == 1
+
+
+class MonthEnd(BaseCFTimeOffset):
+ _freq = 'M'
+
+ def __apply__(self, other):
+ n = _adjust_n_months(other.day, self.n, _days_in_month(other))
+ return _shift_months(other, n, 'end')
+
+ def onOffset(self, date):
+ """Check if the given date is in the set of possible dates created
+ using a length-one version of this offset class."""
+ return date.day == _days_in_month(date)
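A hedged sketch of the arithmetic these offset classes support (not part of the diff; assumes cftime is installed):

import cftime
from xarray.coding.cftime_offsets import MonthBegin, MonthEnd

date = cftime.DatetimeNoLeap(2000, 1, 15)
print(date + MonthEnd())        # 2000-01-31: first month end on or after Jan 15
print(date + 2 * MonthBegin())  # 2000-03-01: two month starts after Jan 15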
+
+
+_MONTH_ABBREVIATIONS = {
+ 1: 'JAN',
+ 2: 'FEB',
+ 3: 'MAR',
+ 4: 'APR',
+ 5: 'MAY',
+ 6: 'JUN',
+ 7: 'JUL',
+ 8: 'AUG',
+ 9: 'SEP',
+ 10: 'OCT',
+ 11: 'NOV',
+ 12: 'DEC'
+}
+
+
+class YearOffset(BaseCFTimeOffset):
+ _freq = None
+ _day_option = None
+ _default_month = None
+
+ def __init__(self, n=1, month=None):
+ BaseCFTimeOffset.__init__(self, n)
+ if month is None:
+ self.month = self._default_month
+ else:
+ self.month = month
+ if not isinstance(self.month, int):
+ raise TypeError("'self.month' must be an integer value between 1 "
+ "and 12. Instead, it was set to a value of "
+ "{!r}".format(self.month))
+ elif not (1 <= self.month <= 12):
+ raise ValueError("'self.month' must be an integer value between 1 "
+ "and 12. Instead, it was set to a value of "
+ "{!r}".format(self.month))
+
+ def __apply__(self, other):
+ if self._day_option == 'start':
+ reference_day = 1
+ elif self._day_option == 'end':
+ reference_day = _days_in_month(other)
+ else:
+ raise ValueError(self._day_option)
+ years = _adjust_n_years(other, self.n, self.month, reference_day)
+ months = years * 12 + (self.month - other.month)
+ return _shift_months(other, months, self._day_option)
+
+ def __sub__(self, other):
+ import cftime
+
+ if isinstance(other, cftime.datetime):
+ raise TypeError('Cannot subtract cftime.datetime from offset.')
+ elif type(other) == type(self) and other.month == self.month:
+ return type(self)(self.n - other.n, month=self.month)
+ else:
+ return NotImplemented
+
+ def __mul__(self, other):
+ return type(self)(n=other * self.n, month=self.month)
+
+ def rule_code(self):
+ return '{}-{}'.format(self._freq, _MONTH_ABBREVIATIONS[self.month])
+
+ def __str__(self):
+ return '<{}: n={}, month={}>'.format(
+ type(self).__name__, self.n, self.month)
+
+
+class YearBegin(YearOffset):
+ _freq = 'AS'
+ _day_option = 'start'
+ _default_month = 1
+
+ def onOffset(self, date):
+ """Check if the given date is in the set of possible dates created
+ using a length-one version of this offset class."""
+ return date.day == 1 and date.month == self.month
+
+ def rollforward(self, date):
+ """Roll date forward to nearest start of year"""
+ if self.onOffset(date):
+ return date
+ else:
+ return date + YearBegin(month=self.month)
+
+ def rollback(self, date):
+ """Roll date backward to nearest start of year"""
+ if self.onOffset(date):
+ return date
+ else:
+ return date - YearBegin(month=self.month)
+
+
+class YearEnd(YearOffset):
+ _freq = 'A'
+ _day_option = 'end'
+ _default_month = 12
+
+ def onOffset(self, date):
+ """Check if the given date is in the set of possible dates created
+ using a length-one version of this offset class."""
+ return date.day == _days_in_month(date) and date.month == self.month
+
+ def rollforward(self, date):
+ """Roll date forward to nearest end of year"""
+ if self.onOffset(date):
+ return date
+ else:
+ return date + YearEnd(month=self.month)
+
+ def rollback(self, date):
+ """Roll date backward to nearest end of year"""
+ if self.onOffset(date):
+ return date
+ else:
+ return date - YearEnd(month=self.month)
+
+
+class Day(BaseCFTimeOffset):
+ _freq = 'D'
+
+ def __apply__(self, other):
+ return other + timedelta(days=self.n)
+
+
+class Hour(BaseCFTimeOffset):
+ _freq = 'H'
+
+ def __apply__(self, other):
+ return other + timedelta(hours=self.n)
+
+
+class Minute(BaseCFTimeOffset):
+ _freq = 'T'
+
+ def __apply__(self, other):
+ return other + timedelta(minutes=self.n)
+
+
+class Second(BaseCFTimeOffset):
+ _freq = 'S'
+
+ def __apply__(self, other):
+ return other + timedelta(seconds=self.n)
+
+
+_FREQUENCIES = {
+ 'A': YearEnd,
+ 'AS': YearBegin,
+ 'Y': YearEnd,
+ 'YS': YearBegin,
+ 'M': MonthEnd,
+ 'MS': MonthBegin,
+ 'D': Day,
+ 'H': Hour,
+ 'T': Minute,
+ 'min': Minute,
+ 'S': Second,
+ 'AS-JAN': partial(YearBegin, month=1),
+ 'AS-FEB': partial(YearBegin, month=2),
+ 'AS-MAR': partial(YearBegin, month=3),
+ 'AS-APR': partial(YearBegin, month=4),
+ 'AS-MAY': partial(YearBegin, month=5),
+ 'AS-JUN': partial(YearBegin, month=6),
+ 'AS-JUL': partial(YearBegin, month=7),
+ 'AS-AUG': partial(YearBegin, month=8),
+ 'AS-SEP': partial(YearBegin, month=9),
+ 'AS-OCT': partial(YearBegin, month=10),
+ 'AS-NOV': partial(YearBegin, month=11),
+ 'AS-DEC': partial(YearBegin, month=12),
+ 'A-JAN': partial(YearEnd, month=1),
+ 'A-FEB': partial(YearEnd, month=2),
+ 'A-MAR': partial(YearEnd, month=3),
+ 'A-APR': partial(YearEnd, month=4),
+ 'A-MAY': partial(YearEnd, month=5),
+ 'A-JUN': partial(YearEnd, month=6),
+ 'A-JUL': partial(YearEnd, month=7),
+ 'A-AUG': partial(YearEnd, month=8),
+ 'A-SEP': partial(YearEnd, month=9),
+ 'A-OCT': partial(YearEnd, month=10),
+ 'A-NOV': partial(YearEnd, month=11),
+ 'A-DEC': partial(YearEnd, month=12)
+}
+
+
+_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys())
+_PATTERN = '^((?P<multiple>\d+)|())(?P<freq>({0}))$'.format(
+    _FREQUENCY_CONDITION)
+
+
+def to_offset(freq):
+ """Convert a frequency string to the appropriate subclass of
+ BaseCFTimeOffset."""
+ if isinstance(freq, BaseCFTimeOffset):
+ return freq
+ else:
+ try:
+ freq_data = re.match(_PATTERN, freq).groupdict()
+ except AttributeError:
+ raise ValueError('Invalid frequency string provided')
+
+ freq = freq_data['freq']
+ multiples = freq_data['multiple']
+ if multiples is None:
+ multiples = 1
+ else:
+ multiples = int(multiples)
+
+ return _FREQUENCIES[freq](n=multiples)
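A small sketch of the frequency-string parsing above (not part of the diff):

from xarray.coding.cftime_offsets import YearEnd, to_offset

print(to_offset('5H'))       # <Hour: n=5>
print(to_offset('2AS-JUL'))  # <YearBegin: n=2, month=7>
assert to_offset('A-DEC') == YearEnd(month=12)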
+
+
+def to_cftime_datetime(date_str_or_date, calendar=None):
+ import cftime
+
+ if isinstance(date_str_or_date, basestring):
+ if calendar is None:
+ raise ValueError(
+ 'If converting a string to a cftime.datetime object, '
+ 'a calendar type must be provided')
+ date, _ = _parse_iso8601_with_reso(get_date_type(calendar),
+ date_str_or_date)
+ return date
+ elif isinstance(date_str_or_date, cftime.datetime):
+ return date_str_or_date
+ else:
+ raise TypeError("date_str_or_date must be a string or a "
+ 'subclass of cftime.datetime. Instead got '
+ '{!r}.'.format(date_str_or_date))
+
+
+def normalize_date(date):
+ """Round datetime down to midnight."""
+ return date.replace(hour=0, minute=0, second=0, microsecond=0)
+
+
+def _maybe_normalize_date(date, normalize):
+ """Round datetime down to midnight if normalize is True."""
+ if normalize:
+ return normalize_date(date)
+ else:
+ return date
+
+
+def _generate_linear_range(start, end, periods):
+ """Generate an equally-spaced sequence of cftime.datetime objects between
+ and including two dates (whose length equals the number of periods)."""
+ import cftime
+
+ total_seconds = (end - start).total_seconds()
+ values = np.linspace(0., total_seconds, periods, endpoint=True)
+ units = 'seconds since {}'.format(format_cftime_datetime(start))
+ calendar = start.calendar
+ return cftime.num2date(values, units=units, calendar=calendar,
+ only_use_cftime_datetimes=True)
+
+
+def _generate_range(start, end, periods, offset):
+ """Generate a regular range of cftime.datetime objects with a
+ given time offset.
+
+ Adapted from pandas.tseries.offsets.generate_range.
+
+ Parameters
+ ----------
+ start : cftime.datetime, or None
+ Start of range
+ end : cftime.datetime, or None
+ End of range
+ periods : int, or None
+ Number of elements in the sequence
+ offset : BaseCFTimeOffset
+ An offset class designed for working with cftime.datetime objects
+
+ Returns
+ -------
+ A generator object
+ """
+ if start:
+ start = offset.rollforward(start)
+
+ if end:
+ end = offset.rollback(end)
+
+ if periods is None and end < start:
+ end = None
+ periods = 0
+
+ if end is None:
+ end = start + (periods - 1) * offset
+
+ if start is None:
+ start = end - (periods - 1) * offset
+
+ current = start
+ if offset.n >= 0:
+ while current <= end:
+ yield current
+
+ next_date = current + offset
+ if next_date <= current:
+ raise ValueError('Offset {offset} did not increment date'
+ .format(offset=offset))
+ current = next_date
+ else:
+ while current >= end:
+ yield current
+
+ next_date = current + offset
+ if next_date >= current:
+ raise ValueError('Offset {offset} did not decrement date'
+ .format(offset=offset))
+ current = next_date
+
+
+def _count_not_none(*args):
+ """Compute the number of non-None arguments."""
+ return sum([arg is not None for arg in args])
+
+
+def cftime_range(start=None, end=None, periods=None, freq='D',
+ tz=None, normalize=False, name=None, closed=None,
+ calendar='standard'):
+ """Return a fixed frequency CFTimeIndex.
+
+ Parameters
+ ----------
+ start : str or cftime.datetime, optional
+ Left bound for generating dates.
+ end : str or cftime.datetime, optional
+ Right bound for generating dates.
+ periods : integer, optional
+ Number of periods to generate.
+ freq : str, default 'D', BaseCFTimeOffset, or None
+ Frequency strings can have multiples, e.g. '5H'.
+ normalize : bool, default False
+ Normalize start/end dates to midnight before generating date range.
+ name : str, default None
+ Name of the resulting index
+ closed : {None, 'left', 'right'}, optional
+ Make the interval closed with respect to the given frequency to the
+ 'left', 'right', or both sides (None, the default).
+ calendar : str
+ Calendar type for the datetimes (default 'standard').
+
+ Returns
+ -------
+ CFTimeIndex
+
+ Notes
+ -----
+
+ This function is an analog of ``pandas.date_range`` for use in generating
+ sequences of ``cftime.datetime`` objects. It supports most of the
+ features of ``pandas.date_range`` (e.g. specifying how the index is
+ ``closed`` on either side, or whether or not to ``normalize`` the start and
+ end bounds); however, there are some notable exceptions:
+
+ - You cannot specify a ``tz`` (time zone) argument.
+ - Start or end dates specified as partial-datetime strings must use the
+ `ISO-8601 format <https://en.wikipedia.org/wiki/ISO_8601>`_.
+ - It supports many, but not all, frequencies supported by
+ ``pandas.date_range``. For example it does not currently support any of
+ the business-related, semi-monthly, or sub-second frequencies.
+ - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as
+ these can easily be written in terms of the finest common resolution,
+ e.g. '61min'.
+
+ Valid simple frequency strings for use with ``cftime``-calendars include
+ any multiples of the following.
+
+ +--------+-----------------------+
+ | Alias | Description |
+ +========+=======================+
+ | A, Y | Year-end frequency |
+ +--------+-----------------------+
+ | AS, YS | Year-start frequency |
+ +--------+-----------------------+
+ | M | Month-end frequency |
+ +--------+-----------------------+
+ | MS | Month-start frequency |
+ +--------+-----------------------+
+ | D | Day frequency |
+ +--------+-----------------------+
+ | H | Hour frequency |
+ +--------+-----------------------+
+ | T, min | Minute frequency |
+ +--------+-----------------------+
+ | S | Second frequency |
+ +--------+-----------------------+
+
+ Any multiples of the following anchored offsets are also supported.
+
+ +----------+-------------------------------------------------------------------+
+ | Alias | Description |
+ +==========+===================================================================+
+ | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November |
+ +----------+-------------------------------------------------------------------+
+ | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December |
+ +----------+-------------------------------------------------------------------+
+
+ Finally, the following calendar aliases are supported.
+
+ +--------------------------------+---------------------------------------+
+ | Alias | Date type |
+ +================================+=======================================+
+ | standard, proleptic_gregorian | ``cftime.DatetimeProlepticGregorian`` |
+ +--------------------------------+---------------------------------------+
+ | gregorian | ``cftime.DatetimeGregorian`` |
+ +--------------------------------+---------------------------------------+
+ | noleap, 365_day | ``cftime.DatetimeNoLeap`` |
+ +--------------------------------+---------------------------------------+
+ | all_leap, 366_day | ``cftime.DatetimeAllLeap`` |
+ +--------------------------------+---------------------------------------+
+ | 360_day | ``cftime.Datetime360Day`` |
+ +--------------------------------+---------------------------------------+
+ | julian | ``cftime.DatetimeJulian`` |
+ +--------------------------------+---------------------------------------+
+
+ Examples
+ --------
+
+ This function returns a ``CFTimeIndex``, populated with ``cftime.datetime``
+ objects associated with the specified calendar type, e.g.
+
+ >>> xr.cftime_range(start='2000', periods=6, freq='2MS', calendar='noleap')
+ CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00,
+ 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00],
+ dtype='object')
+
+ As in the standard pandas function, three of the ``start``, ``end``,
+ ``periods``, or ``freq`` arguments must be specified at a given time, with
+ the other set to ``None``. See the `pandas documentation
+ <https://pandas.pydata.org/pandas-docs/stable/generated/pandas.date_range.html>`_
+ for more examples of the behavior of ``date_range`` with each of the
+ parameters.
+
+ See Also
+ --------
+ pandas.date_range
+ """ # noqa: E501
+ # Adapted from pandas.core.indexes.datetimes._generate_range.
+ if _count_not_none(start, end, periods, freq) != 3:
+ raise ValueError(
+ "Of the arguments 'start', 'end', 'periods', and 'freq', three "
+ "must be specified at a time.")
+
+ if start is not None:
+ start = to_cftime_datetime(start, calendar)
+ start = _maybe_normalize_date(start, normalize)
+ if end is not None:
+ end = to_cftime_datetime(end, calendar)
+ end = _maybe_normalize_date(end, normalize)
+
+ if freq is None:
+ dates = _generate_linear_range(start, end, periods)
+ else:
+ offset = to_offset(freq)
+ dates = np.array(list(_generate_range(start, end, periods, offset)))
+
+ left_closed = False
+ right_closed = False
+
+ if closed is None:
+ left_closed = True
+ right_closed = True
+ elif closed == 'left':
+ left_closed = True
+ elif closed == 'right':
+ right_closed = True
+ else:
+ raise ValueError("Closed must be either 'left', 'right' or None")
+
+ if (not left_closed and len(dates) and
+ start is not None and dates[0] == start):
+ dates = dates[1:]
+ if (not right_closed and len(dates) and
+ end is not None and dates[-1] == end):
+ dates = dates[:-1]
+
+ return CFTimeIndex(dates, name=name)
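A usage sketch for the new public function (not part of the diff; assumes xarray is built from this branch with cftime installed):

import xarray as xr

index = xr.cftime_range('2000-01-01', '2000-01-10', freq='3D',
                        calendar='360_day', closed='left')
# 2000-01-01, 2000-01-04 and 2000-01-07; closed='left' drops the right bound
# because it falls exactly on a step.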
diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py
index eb8cae2f398..dea896c199a 100644
--- a/xarray/coding/cftimeindex.py
+++ b/xarray/coding/cftimeindex.py
@@ -1,4 +1,46 @@
+"""DatetimeIndex analog for cftime.datetime objects"""
+# The pandas.Index subclass defined here was copied and adapted for
+# use with cftime.datetime objects based on the source code defining
+# pandas.DatetimeIndex.
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
from __future__ import absolute_import
+
import re
from datetime import timedelta
@@ -116,28 +158,43 @@ def f(self):
def get_date_type(self):
- return type(self._data[0])
+ if self._data.size:
+ return type(self._data[0])
+ else:
+ return None
def assert_all_valid_date_type(data):
import cftime
- sample = data[0]
- date_type = type(sample)
- if not isinstance(sample, cftime.datetime):
- raise TypeError(
- 'CFTimeIndex requires cftime.datetime '
- 'objects. Got object of {}.'.format(date_type))
- if not all(isinstance(value, date_type) for value in data):
- raise TypeError(
- 'CFTimeIndex requires using datetime '
- 'objects of all the same type. Got\n{}.'.format(data))
+ if data.size:
+ sample = data[0]
+ date_type = type(sample)
+ if not isinstance(sample, cftime.datetime):
+ raise TypeError(
+ 'CFTimeIndex requires cftime.datetime '
+ 'objects. Got object of {}.'.format(date_type))
+ if not all(isinstance(value, date_type) for value in data):
+ raise TypeError(
+ 'CFTimeIndex requires using datetime '
+ 'objects of all the same type. Got\n{}.'.format(data))
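A brief sketch of what the new size checks allow (not part of the diff): an empty CFTimeIndex can now be constructed without error.

from xarray.coding.cftimeindex import CFTimeIndex

index = CFTimeIndex([])
print(len(index), index.date_type)  # 0 None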
class CFTimeIndex(pd.Index):
"""Custom Index for working with CF calendars and dates
All elements of a CFTimeIndex must be cftime.datetime objects.
+
+ Parameters
+ ----------
+ data : array or CFTimeIndex
+ Sequence of cftime.datetime objects to use in index
+ name : str, default None
+ Name of the resulting index
+
+ See Also
+ --------
+ cftime_range
"""
year = _field_accessor('year', 'The year of the datetime')
month = _field_accessor('month', 'The month of the datetime')
@@ -149,10 +206,14 @@ class CFTimeIndex(pd.Index):
'The microseconds of the datetime')
date_type = property(get_date_type)
- def __new__(cls, data):
+ def __new__(cls, data, name=None):
+ if name is None and hasattr(data, 'name'):
+ name = data.name
+
result = object.__new__(cls)
- assert_all_valid_date_type(data)
- result._data = np.array(data)
+ result._data = np.array(data, dtype='O')
+ assert_all_valid_date_type(result._data)
+ result.name = name
return result
def _partial_date_slice(self, resolution, parsed):
@@ -254,3 +315,80 @@ def __contains__(self, key):
def contains(self, key):
"""Needed for .loc based partial-string indexing"""
return self.__contains__(key)
+
+ def shift(self, n, freq):
+ """Shift the CFTimeIndex a multiple of the given frequency.
+
+ See the documentation for :py:func:`~xarray.cftime_range` for a
+ complete listing of valid frequency strings.
+
+ Parameters
+ ----------
+ n : int
+ Periods to shift by
+ freq : str or datetime.timedelta
+ A frequency string or datetime.timedelta object to shift by
+
+ Returns
+ -------
+ CFTimeIndex
+
+ See also
+ --------
+ pandas.DatetimeIndex.shift
+
+ Examples
+ --------
+ >>> index = xr.cftime_range('2000', periods=1, freq='M')
+ >>> index
+ CFTimeIndex([2000-01-31 00:00:00], dtype='object')
+ >>> index.shift(1, 'M')
+ CFTimeIndex([2000-02-29 00:00:00], dtype='object')
+ """
+ from .cftime_offsets import to_offset
+
+ if not isinstance(n, int):
+ raise TypeError("'n' must be an int, got {}.".format(n))
+ if isinstance(freq, timedelta):
+ return self + n * freq
+ elif isinstance(freq, pycompat.basestring):
+ return self + n * to_offset(freq)
+ else:
+ raise TypeError(
+ "'freq' must be of type "
+ "str or datetime.timedelta, got {}.".format(freq))
+
+ def __add__(self, other):
+ return CFTimeIndex(np.array(self) + other)
+
+ def __radd__(self, other):
+ return CFTimeIndex(other + np.array(self))
+
+ def __sub__(self, other):
+ return CFTimeIndex(np.array(self) - other)
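A sketch of the new index arithmetic (not part of the diff; assumes cftime is installed):

from datetime import timedelta

import xarray as xr

index = xr.cftime_range('2000', periods=3, freq='D', calendar='noleap')
print(index.shift(2, timedelta(hours=12)))  # each date moves forward by one day
print(index + timedelta(days=1))            # plain timedelta arithmetic, same dates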
+
+
+def _parse_iso8601_without_reso(date_type, datetime_str):
+ date, _ = _parse_iso8601_with_reso(date_type, datetime_str)
+ return date
+
+
+def _parse_array_of_cftime_strings(strings, date_type):
+ """Create a numpy array from an array of strings.
+
+ For use in generating dates from strings for use with interp. Assumes the
+ array is either 0-dimensional or 1-dimensional.
+
+ Parameters
+ ----------
+ strings : array of strings
+ Strings to convert to dates
+ date_type : cftime.datetime type
+ Calendar type to use for dates
+
+ Returns
+ -------
+ np.array
+ """
+ return np.array([_parse_iso8601_without_reso(date_type, s)
+ for s in strings.ravel()]).reshape(strings.shape)
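A short sketch of the private string-parsing helper above (not part of the diff; assumes cftime is installed):

import cftime
import numpy as np

from xarray.coding.cftimeindex import _parse_array_of_cftime_strings

strings = np.array(['2000-01-01', '2000-06-15'])
dates = _parse_array_of_cftime_strings(strings, cftime.DatetimeNoLeap)
# object array of DatetimeNoLeap dates at midnight, same shape as `strings`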
diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py
index 87b17d9175e..3502fd773d7 100644
--- a/xarray/coding/strings.py
+++ b/xarray/coding/strings.py
@@ -9,8 +9,8 @@
from ..core.pycompat import bytes_type, dask_array_type, unicode_type
from ..core.variable import Variable
from .variables import (
- VariableCoder, lazy_elemwise_func, pop_to,
- safe_setitem, unpack_for_decoding, unpack_for_encoding)
+ VariableCoder, lazy_elemwise_func, pop_to, safe_setitem,
+ unpack_for_decoding, unpack_for_encoding)
def create_vlen_dtype(element_type):
diff --git a/xarray/coding/times.py b/xarray/coding/times.py
index d946e2ed378..dff7e75bdcf 100644
--- a/xarray/coding/times.py
+++ b/xarray/coding/times.py
@@ -9,8 +9,8 @@
import numpy as np
import pandas as pd
-from ..core.common import contains_cftime_datetimes
from ..core import indexing
+from ..core.common import contains_cftime_datetimes
from ..core.formatting import first_n_items, format_timestamp, last_item
from ..core.options import OPTIONS
from ..core.pycompat import PY3
@@ -183,8 +183,11 @@ def decode_cf_datetime(num_dates, units, calendar=None,
# fixes: https://github.com/pydata/pandas/issues/14068
# these lines check if the the lowest or the highest value in dates
# cause an OutOfBoundsDatetime (Overflow) error
- pd.to_timedelta(flat_num_dates.min(), delta) + ref_date
- pd.to_timedelta(flat_num_dates.max(), delta) + ref_date
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'invalid value encountered',
+ RuntimeWarning)
+ pd.to_timedelta(flat_num_dates.min(), delta) + ref_date
+ pd.to_timedelta(flat_num_dates.max(), delta) + ref_date
# Cast input dates to integers of nanoseconds because `pd.to_datetime`
# works much faster when dealing with integers
diff --git a/xarray/conventions.py b/xarray/conventions.py
index 67dcb8d6d4e..f60ee6b2c15 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -6,11 +6,11 @@
import numpy as np
import pandas as pd
-from .coding import times, strings, variables
+from .coding import strings, times, variables
from .coding.variables import SerializationWarning
from .core import duck_array_ops, indexing
from .core.pycompat import (
- OrderedDict, basestring, bytes_type, iteritems, dask_array_type,
+ OrderedDict, basestring, bytes_type, dask_array_type, iteritems,
unicode_type)
from .core.variable import IndexVariable, Variable, as_variable
diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py
index 81af0532d93..72791ed73ec 100644
--- a/xarray/core/accessors.py
+++ b/xarray/core/accessors.py
@@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
-from .common import is_np_datetime_like, _contains_datetime_like_objects
+from .common import _contains_datetime_like_objects, is_np_datetime_like
from .pycompat import dask_array_type
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index b0d2a49c29f..f82ddef25ba 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -174,11 +174,14 @@ def deep_align(objects, join='inner', copy=True, indexes=None,
This function is not public API.
"""
+ from .dataarray import DataArray
+ from .dataset import Dataset
+
if indexes is None:
indexes = {}
def is_alignable(obj):
- return hasattr(obj, 'indexes') and hasattr(obj, 'reindex')
+ return isinstance(obj, (DataArray, Dataset))
positions = []
keys = []
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 430f0e564d6..6853939c02d 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -8,8 +8,8 @@
from .alignment import align
from .merge import merge
from .pycompat import OrderedDict, basestring, iteritems
-from .variable import concat as concat_vars
from .variable import IndexVariable, Variable, as_variable
+from .variable import concat as concat_vars
def concat(objs, dim=None, data_vars='all', coords='different',
@@ -125,16 +125,17 @@ def _calc_concat_dim_coord(dim):
Infer the dimension name and 1d coordinate variable (if appropriate)
for concatenating along the new dimension.
"""
+ from .dataarray import DataArray
+
if isinstance(dim, basestring):
coord = None
- elif not hasattr(dim, 'dims'):
- # dim is not a DataArray or IndexVariable
+ elif not isinstance(dim, (DataArray, Variable)):
dim_name = getattr(dim, 'name', None)
if dim_name is None:
dim_name = 'concat_dim'
coord = IndexVariable(dim_name, dim)
dim = dim_name
- elif not hasattr(dim, 'name'):
+ elif not isinstance(dim, DataArray):
coord = as_variable(dim).to_index_variable()
dim, = coord.dims
else:
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 3f934fcc769..c74b1fa080b 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -2,14 +2,18 @@
import warnings
from distutils.version import LooseVersion
+from textwrap import dedent
import numpy as np
import pandas as pd
-from . import duck_array_ops, dtypes, formatting, ops
+from . import dtypes, duck_array_ops, formatting, ops
from .arithmetic import SupportsArithmetic
from .pycompat import OrderedDict, basestring, dask_array_type, suppress
-from .utils import Frozen, SortedKeysDict
+from .utils import Frozen, ReprObject, SortedKeysDict, either_dict_or_kwargs
+
+# Used as a sentinel value to indicate all dimensions
+ALL_DIMS = ReprObject('<all-dims>')
class ImplementsArrayReduce(object):
@@ -27,20 +31,20 @@ def wrapped_func(self, dim=None, axis=None, keep_attrs=False,
allow_lazy=True, **kwargs)
return wrapped_func
- _reduce_extra_args_docstring = \
- """dim : str or sequence of str, optional
+ _reduce_extra_args_docstring = dedent("""\
+ dim : str or sequence of str, optional
Dimension(s) over which to apply `{name}`.
axis : int or sequence of int, optional
Axis(es) over which to apply `{name}`. Only one of the 'dim'
and 'axis' arguments can be supplied. If neither are supplied, then
- `{name}` is calculated over axes."""
+ `{name}` is calculated over axes.""")
- _cum_extra_args_docstring = \
- """dim : str or sequence of str, optional
+ _cum_extra_args_docstring = dedent("""\
+ dim : str or sequence of str, optional
Dimension over which to apply `{name}`.
axis : int or sequence of int, optional
Axis over which to apply `{name}`. Only one of the 'dim'
- and 'axis' arguments can be supplied."""
+ and 'axis' arguments can be supplied.""")
class ImplementsDatasetReduce(object):
@@ -308,12 +312,12 @@ def assign_coords(self, **kwargs):
assigned : same type as caller
A new object with the new coordinates in addition to the existing
data.
-
+
Examples
--------
-
+
Convert longitude coordinates from 0-359 to -180-179:
-
+
>>> da = xr.DataArray(np.random.rand(4),
... coords=[np.array([358, 359, 0, 1])],
... dims='lon')
@@ -445,11 +449,11 @@ def groupby(self, group, squeeze=True):
grouped : GroupBy
A `GroupBy` object patterned after `pandas.GroupBy` that can be
iterated over in the form of `(unique_value, grouped_array)` pairs.
-
+
Examples
--------
Calculate daily anomalies for daily data:
-
+
>>> da = xr.DataArray(np.linspace(0, 1826, num=1827),
... coords=[pd.date_range('1/1/2000', '31/12/2004',
... freq='D')],
@@ -465,7 +469,7 @@ def groupby(self, group, squeeze=True):
Coordinates:
* time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ...
dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ...
-
+
See Also
--------
core.groupby.DataArrayGroupBy
@@ -525,24 +529,24 @@ def groupby_bins(self, group, bins, right=True, labels=None, precision=3,
'precision': precision,
'include_lowest': include_lowest})
- def rolling(self, min_periods=None, center=False, **windows):
+ def rolling(self, dim=None, min_periods=None, center=False, **dim_kwargs):
"""
Rolling window object.
Parameters
----------
+ dim: dict, optional
+ Mapping from the dimension name to create the rolling iterator
+ along (e.g. `time`) to its moving window size.
min_periods : int, default None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
- **windows : dim=window
- dim : str
- Name of the dimension to create the rolling iterator
- along (e.g., `time`).
- window : int
- Size of the moving window.
+ **dim_kwargs : optional
+ The keyword arguments form of ``dim``.
+ One of dim or dim_kwargs must be provided.
Returns
-------
@@ -581,15 +585,15 @@ def rolling(self, min_periods=None, center=False, **windows):
core.rolling.DataArrayRolling
core.rolling.DatasetRolling
"""
-
- return self._rolling_cls(self, min_periods=min_periods,
- center=center, **windows)
+ dim = either_dict_or_kwargs(dim, dim_kwargs, 'rolling')
+ return self._rolling_cls(self, dim, min_periods=min_periods,
+ center=center)
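A sketch of the two equivalent call forms after this change (not part of the diff):

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims='time',
                  coords={'time': pd.date_range('2000-01-01', periods=10)})
# The window can be given either as a keyword or as an explicit mapping.
assert da.rolling(time=3).mean().identical(da.rolling(dim={'time': 3}).mean())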
def resample(self, freq=None, dim=None, how=None, skipna=None,
closed=None, label=None, base=0, keep_attrs=False, **indexer):
"""Returns a Resample object for performing resampling operations.
- Handles both downsampling and upsampling. If any intervals contain no
+ Handles both downsampling and upsampling. If any intervals contain no
values from the original object, they will be given the value ``NaN``.
Parameters
@@ -616,11 +620,11 @@ def resample(self, freq=None, dim=None, how=None, skipna=None,
-------
resampled : same type as caller
This object resampled.
-
+
Examples
--------
Downsample monthly time-series data to seasonal data:
-
+
>>> da = xr.DataArray(np.linspace(0, 11, num=12),
... coords=[pd.date_range('15/12/1999',
... periods=12, freq=pd.DateOffset(months=1))],
@@ -637,18 +641,20 @@ def resample(self, freq=None, dim=None, how=None, skipna=None,
* time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01
Upsample monthly time-series data to daily data:
-
+
>>> da.resample(time='1D').interpolate('linear')
array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ])
Coordinates:
* time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ...
-
+
References
----------
.. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
"""
+ # TODO support non-string indexer after removing the old API.
+
from .dataarray import DataArray
from .resample import RESAMPLE_DIM
@@ -957,8 +963,8 @@ def contains_cftime_datetimes(var):
sample = sample.item()
return isinstance(sample, cftime_datetime)
else:
- return False
-
+ return False
+
def _contains_datetime_like_objects(var):
"""Check if a variable contains datetime like objects (either
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index bdba72cb48a..7998cc4f72f 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -2,19 +2,19 @@
Functions for applying functions that act on arrays to xarray's labeled data.
"""
from __future__ import absolute_import, division, print_function
-from distutils.version import LooseVersion
+
import functools
import itertools
import operator
from collections import Counter
+from distutils.version import LooseVersion
import numpy as np
-from . import duck_array_ops
-from . import utils
+from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import expand_and_merge_variables
-from .pycompat import OrderedDict, dask_array_type, basestring
+from .pycompat import OrderedDict, basestring, dask_array_type
from .utils import is_dict_like
_DEFAULT_FROZEN_SET = frozenset()
diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py
index c2417345f55..6b53dcffe6e 100644
--- a/xarray/core/dask_array_compat.py
+++ b/xarray/core/dask_array_compat.py
@@ -1,7 +1,10 @@
from __future__ import absolute_import, division, print_function
-import numpy as np
+from distutils.version import LooseVersion
+
import dask.array as da
+import numpy as np
+from dask import __version__ as dask_version
try:
from dask.array import isin
@@ -30,3 +33,130 @@ def isin(element, test_elements, assume_unique=False, invert=False):
if invert:
result = ~result
return result
+
+
+if LooseVersion(dask_version) > LooseVersion('1.19.2'):
+ gradient = da.gradient
+
+else: # pragma: no cover
+ # Copied from dask v0.19.2
+ # Used under the terms of Dask's license, see licenses/DASK_LICENSE.
+ import math
+ from numbers import Integral, Real
+
+ try:
+ AxisError = np.AxisError
+ except AttributeError:
+ try:
+ np.array([0]).sum(axis=5)
+ except Exception as e:
+ AxisError = type(e)
+
+ def validate_axis(axis, ndim):
+ """ Validate an input to axis= keywords """
+ if isinstance(axis, (tuple, list)):
+ return tuple(validate_axis(ax, ndim) for ax in axis)
+ if not isinstance(axis, Integral):
+ raise TypeError("Axis value must be an integer, got %s" % axis)
+ if axis < -ndim or axis >= ndim:
+ raise AxisError("Axis %d is out of bounds for array of dimension "
+ "%d" % (axis, ndim))
+ if axis < 0:
+ axis += ndim
+ return axis
+
+ def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs):
+ """
+ x: nd-array
+ array of one block
+ coord: 1d-array or scalar
+ coordinate along which the gradient is computed.
+ axis: int
+ axis along which the gradient is computed
+ array_locs:
+ actual location along axis. None if coordinate is scalar
+ grad_kwargs:
+ keyword to be passed to np.gradient
+ """
+ block_loc = block_id[axis]
+ if array_locs is not None:
+ coord = coord[array_locs[0][block_loc]:array_locs[1][block_loc]]
+ grad = np.gradient(x, coord, axis=axis, **grad_kwargs)
+ return grad
+
+ def gradient(f, *varargs, **kwargs):
+ f = da.asarray(f)
+
+ kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1))
+ if kwargs["edge_order"] > 2:
+ raise ValueError("edge_order must be less than or equal to 2.")
+
+ drop_result_list = False
+ axis = kwargs.pop("axis", None)
+ if axis is None:
+ axis = tuple(range(f.ndim))
+ elif isinstance(axis, Integral):
+ drop_result_list = True
+ axis = (axis,)
+
+ axis = validate_axis(axis, f.ndim)
+
+ if len(axis) != len(set(axis)):
+ raise ValueError("duplicate axes not allowed")
+
+ axis = tuple(ax % f.ndim for ax in axis)
+
+ if varargs == ():
+ varargs = (1,)
+ if len(varargs) == 1:
+ varargs = len(axis) * varargs
+ if len(varargs) != len(axis):
+ raise TypeError(
+ "Spacing must either be a single scalar, or a scalar / "
+ "1d-array per axis"
+ )
+
+ if issubclass(f.dtype.type, (np.bool8, Integral)):
+ f = f.astype(float)
+ elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4:
+ f = f.astype(float)
+
+ results = []
+ for i, ax in enumerate(axis):
+ for c in f.chunks[ax]:
+ if np.min(c) < kwargs["edge_order"] + 1:
+ raise ValueError(
+ 'Chunk size must be larger than edge_order + 1. '
+ 'Minimum chunk for axis {} is {}. Rechunk to '
+ 'proceed.'.format(ax, np.min(c)))
+
+ if np.isscalar(varargs[i]):
+ array_locs = None
+ else:
+ if isinstance(varargs[i], da.Array):
+ raise NotImplementedError(
+ 'dask array coordinate is not supported.')
+ # coordinate position for each block taking overlap into
+ # account
+ chunk = np.array(f.chunks[ax])
+ array_loc_stop = np.cumsum(chunk) + 1
+ array_loc_start = array_loc_stop - chunk - 2
+ array_loc_stop[-1] -= 1
+ array_loc_start[0] = 0
+ array_locs = (array_loc_start, array_loc_stop)
+
+ results.append(f.map_overlap(
+ _gradient_kernel,
+ dtype=f.dtype,
+ depth={j: 1 if j == ax else 0 for j in range(f.ndim)},
+ boundary="none",
+ coord=varargs[i],
+ axis=ax,
+ array_locs=array_locs,
+ grad_kwargs=kwargs,
+ ))
+
+ if drop_result_list:
+ results = results[0]
+
+ return results
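A rough usage sketch of the vendored fallback (not part of the diff); on newer dask the module simply re-exports da.gradient:

import dask.array as da
import numpy as np

from xarray.core.dask_array_compat import gradient

x = da.from_array(np.arange(12.0), chunks=4)
# Uniform spacing of 2.0 gives a constant slope of 0.5; each chunk must be
# larger than edge_order + 1.
print(gradient(x, 2.0, axis=0).compute())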
diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index 423a65aa3c2..25c572edd54 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -1,10 +1,10 @@
from __future__ import absolute_import, division, print_function
+
from distutils.version import LooseVersion
import numpy as np
-from . import nputils
-from . import dtypes
+from . import dtypes, nputils
try:
import dask
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index f215bc47df8..f131b003a69 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -677,14 +677,77 @@ def persist(self, **kwargs):
ds = self._to_temp_dataset().persist(**kwargs)
return self._from_temp_dataset(ds)
- def copy(self, deep=True):
+ def copy(self, deep=True, data=None):
"""Returns a copy of this array.
- If `deep=True`, a deep copy is made of all variables in the underlying
- dataset. Otherwise, a shallow copy is made, so each variable in the new
+ If `deep=True`, a deep copy is made of the data array.
+ Otherwise, a shallow copy is made, so each variable in the new
array's dataset is also a variable in this array's dataset.
+
+ Use `data` to create a new object with the same structure as
+ original but entirely new data.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ Whether the data array and its coordinates are loaded into memory
+ and copied onto the new object. Default is True.
+ data : array_like, optional
+ Data to use in the new object. Must have same shape as original.
+ When `data` is used, `deep` is ignored for all data variables,
+ and only used for coords.
+
+ Returns
+ -------
+ object : DataArray
+ New object with dimensions, attributes, coordinates, name,
+ encoding, and optionally data copied from original.
+
+ Examples
+ --------
+
+ Shallow versus deep copy
+
+ >>> array = xr.DataArray([1, 2, 3], dims='x',
+ ... coords={'x': ['a', 'b', 'c']})
+ >>> array.copy()
+ <xarray.DataArray (x: 3)>
+ array([1, 2, 3])
+ Coordinates:
+ * x (x) <U1 'a' 'b' 'c'
+ >>> array_0 = array.copy(deep=False)
+ >>> array_0[0] = 7
+ >>> array_0
+ <xarray.DataArray (x: 3)>
+ array([7, 2, 3])
+ Coordinates:
+ * x (x) <U1 'a' 'b' 'c'
+ >>> array
+ <xarray.DataArray (x: 3)>
+ array([7, 2, 3])
+ Coordinates:
+ * x (x) <U1 'a' 'b' 'c'
+ >>> array.copy(data=[0.1, 0.2, 0.3])
+ <xarray.DataArray (x: 3)>
+ array([ 0.1, 0.2, 0.3])
+ Coordinates:
+ * x (x) <U1 'a' 'b' 'c'
+ >>> array
+ <xarray.DataArray (x: 3)>
+ array([1, 2, 3])
+ Coordinates:
+ * x (x) <U1 'a' 'b' 'c'
+ >>> arr = DataArray(np.arange(6).reshape(2, 3),
+ ... coords=[('x', ['a', 'b']), ('y', [0, 1, 2])])
+ >>> arr
+ <xarray.DataArray (x: 2, y: 3)>
+ array([[0, 1, 2],
+ [3, 4, 5]])
+ Coordinates:
+ * x (x) |S1 'a' 'b'
+ * y (y) int64 0 1 2
+ >>> stacked = arr.stack(z=('x', 'y'))
+ >>> stacked.indexes['z']
+ MultiIndex(levels=[[u'a', u'b'], [0, 1, 2]],
+ labels=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]],
+ names=[u'x', u'y'])
+ >>> roundtripped = stacked.unstack()
+ >>> arr.identical(roundtripped)
+ True
+
See also
--------
DataArray.stack
@@ -1856,7 +1955,7 @@ def _binary_op(f, reflexive=False, join=None, **ignored_kwargs):
def func(self, other):
if isinstance(other, (Dataset, groupby.GroupBy)):
return NotImplemented
- if hasattr(other, 'indexes'):
+ if isinstance(other, DataArray):
align_type = (OPTIONS['arithmetic_join']
if join is None else join)
self, other = align(self, other, join=align_type, copy=False)
@@ -1974,11 +2073,14 @@ def diff(self, dim, n=1, label='upper'):
Coordinates:
* x (x) int64 3 4
+ See Also
+ --------
+ DataArray.differentiate
"""
ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label)
return self._from_temp_dataset(ds)
- def shift(self, **shifts):
+ def shift(self, shifts=None, **shifts_kwargs):
"""Shift this array by an offset along one or more dimensions.
Only the data is moved; coordinates stay in place. Values shifted from
@@ -1987,10 +2089,13 @@ def shift(self, **shifts):
Parameters
----------
- **shifts : keyword arguments of the form {dim: offset}
+ shifts : Mapping with the form of {dim: offset}
Integer offset to shift along each of the given dimensions.
Positive offsets shift to the right; negative offsets shift to the
left.
+ **shifts_kwargs:
+ The keyword arguments form of ``shifts``.
+ One of shifts or shifts_kwargs must be provided.
Returns
-------
@@ -2012,17 +2117,23 @@ def shift(self, **shifts):
Coordinates:
* x (x) int64 0 1 2
"""
- variable = self.variable.shift(**shifts)
- return self._replace(variable)
+ ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs)
+ return self._from_temp_dataset(ds)
- def roll(self, **shifts):
+ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
"""Roll this array by an offset along one or more dimensions.
- Unlike shift, roll rotates all variables, including coordinates. The
- direction of rotation is consistent with :py:func:`numpy.roll`.
+ Unlike shift, roll may rotate all variables, including coordinates
+ if specified. The direction of rotation is consistent with
+ :py:func:`numpy.roll`.
Parameters
----------
+ roll_coords : bool
+ Indicates whether to roll the coordinates by the offset
+ The current default of roll_coords (None, equivalent to True) is
+ deprecated and will change to False in a future version.
+ Explicitly pass roll_coords to silence the warning.
**shifts : keyword arguments of the form {dim: offset}
Integer offset to rotate each of the given dimensions. Positive
offsets roll to the right; negative offsets roll to the left.
@@ -2046,7 +2157,8 @@ def roll(self, **shifts):
Coordinates:
* x (x) int64 2 0 1
"""
- ds = self._to_temp_dataset().roll(**shifts)
+ ds = self._to_temp_dataset().roll(
+ shifts=shifts, roll_coords=roll_coords, **shifts_kwargs)
return self._from_temp_dataset(ds)
@property
@@ -2243,6 +2355,61 @@ def rank(self, dim, pct=False, keep_attrs=False):
ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs)
return self._from_temp_dataset(ds)
+ def differentiate(self, coord, edge_order=1, datetime_unit=None):
+ """ Differentiate the array with the second order accurate central
+ differences.
+
+ .. note::
+ This feature is limited to simple cartesian geometry, i.e. coord
+ must be one dimensional.
+
+ Parameters
+ ----------
+ coord: str
+ The coordinate to be used to compute the gradient.
+ edge_order: 1 or 2. Default 1
+ N-th order accurate differences at the boundaries.
+ datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms',
+ 'us', 'ns', 'ps', 'fs', 'as'}
+ Unit to compute gradient. Only valid for datetime coordinate.
+
+ Returns
+ -------
+ differentiated: DataArray
+
+ See also
+ --------
+ numpy.gradient: corresponding numpy function
+
+ Examples
+ --------
+
+ >>> da = xr.DataArray(np.arange(12).reshape(4, 3), dims=['x', 'y'],
+ ... coords={'x': [0, 0.1, 1.1, 1.2]})
+ >>> da
+ <xarray.DataArray (x: 4, y: 3)>
+ array([[ 0, 1, 2],
+ [ 3, 4, 5],
+ [ 6, 7, 8],
+ [ 9, 10, 11]])
+ Coordinates:
+ * x (x) float64 0.0 0.1 1.1 1.2
+ Dimensions without coordinates: y
+ >>>
+ >>> da.differentiate('x')
+ <xarray.DataArray (x: 4, y: 3)>
+ array([[30. , 30. , 30. ],
+ [27.545455, 27.545455, 27.545455],
+ [27.545455, 27.545455, 27.545455],
+ [30. , 30. , 30. ]])
+ Coordinates:
+ * x (x) float64 0.0 0.1 1.1 1.2
+ Dimensions without coordinates: y
+ """
+ ds = self._to_temp_dataset().differentiate(
+ coord, edge_order, datetime_unit)
+ return self._from_temp_dataset(ds)
+
# priority most be higher than Variable to properly work with binary ufuncs
ops.inject_all_ops_and_reduce_methods(DataArray, priority=60)
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 4b52178ad0e..c8586d1d408 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -13,12 +13,14 @@
import xarray as xr
from . import (
- alignment, duck_array_ops, formatting, groupby, indexing, ops, resample,
- rolling, utils)
+ alignment, computation, duck_array_ops, formatting, groupby, indexing, ops,
+ resample, rolling, utils)
from .. import conventions
+from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .alignment import align
from .common import (
- DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects)
+ ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
+ _contains_datetime_like_objects)
from .coordinates import (
DatasetCoordinates, Indexes, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
@@ -30,8 +32,9 @@
from .pycompat import (
OrderedDict, basestring, dask_array_type, integer_types, iteritems, range)
from .utils import (
- Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values,
- ensure_us_time_resolution, hashable, maybe_wrap_array)
+ Frozen, SortedKeysDict, datetime_to_numeric, decode_numpy_dict_values,
+ either_dict_or_kwargs, ensure_us_time_resolution, hashable,
+ maybe_wrap_array)
from .variable import IndexVariable, Variable, as_variable, broadcast_variables
# list of attributes of pd.DatetimeIndex that are ndarrays of time info
@@ -709,16 +712,120 @@ def _replace_indexes(self, indexes):
obj = obj.rename(dim_names)
return obj
- def copy(self, deep=False):
+ def copy(self, deep=False, data=None):
"""Returns a copy of this dataset.
If `deep=True`, a deep copy is made of each of the component variables.
Otherwise, a shallow copy of each of the component variable is made, so
that the underlying memory region of the new dataset is the same as in
the original dataset.
+
+ Use `data` to create a new object with the same structure as
+ original but entirely new data.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ Whether each component variable is loaded into memory and copied onto
+ the new object. Default is False.
+ data : dict-like, optional
+ Data to use in the new object. Each item in `data` must have same
+ shape as corresponding data variable in original. When `data` is
+ used, `deep` is ignored for the data variables and only used for
+ coords.
+
+ Returns
+ -------
+ object : Dataset
+ New object with dimensions, attributes, coordinates, name, encoding,
+ and optionally data copied from original.
+
+ Examples
+ --------
+
+ Shallow copy versus deep copy
+
+ >>> da = xr.DataArray(np.random.randn(2, 3))
+ >>> ds = xr.Dataset({'foo': da, 'bar': ('x', [-1, 2])},
+ coords={'x': ['one', 'two']})
+ >>> ds.copy()
+
+ Dimensions: (dim_0: 2, dim_1: 3, x: 2)
+ Coordinates:
+ * x (x) >> ds_0 = ds.copy(deep=False)
+ >>> ds_0['foo'][0, 0] = 7
+ >>> ds_0
+
+ Dimensions: (dim_0: 2, dim_1: 3, x: 2)
+ Coordinates:
+ * x (x) >> ds
+
+ Dimensions: (dim_0: 2, dim_1: 3, x: 2)
+ Coordinates:
+ * x (x) >> ds.copy(data={'foo': np.arange(6).reshape(2, 3), 'bar': ['a', 'b']})
+
+ Dimensions: (dim_0: 2, dim_1: 3, x: 2)
+ Coordinates:
+ * x (x) >> ds
+
+ Dimensions: (dim_0: 2, dim_1: 3, x: 2)
+ Coordinates:
+ * x (x) 1:
+ warnings.warn(
+ "Default reduction dimension will be changed to the "
+ "grouped dimension after xarray 0.12. To silence this "
+ "warning, pass dim=xarray.ALL_DIMS explicitly.",
+ FutureWarning, stacklevel=2)
+
def reduce_array(ar):
return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)
return self.apply(reduce_array, shortcut=shortcut)
+ # TODO remove the following class method and DEFAULT_DIMS after the
+ # deprecation cycle
+ @classmethod
+ def _reduce_method(cls, func, include_skipna, numeric_only):
+ if include_skipna:
+ def wrapped_func(self, dim=DEFAULT_DIMS, axis=None, skipna=None,
+ keep_attrs=False, **kwargs):
+ return self.reduce(func, dim, axis, keep_attrs=keep_attrs,
+ skipna=skipna, allow_lazy=True, **kwargs)
+ else:
+ def wrapped_func(self, dim=DEFAULT_DIMS, axis=None,
+ keep_attrs=False, **kwargs):
+ return self.reduce(func, dim, axis, keep_attrs=keep_attrs,
+ allow_lazy=True, **kwargs)
+ return wrapped_func
+
+
+DEFAULT_DIMS = utils.ReprObject('<default-dims>')
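A sketch of silencing the new FutureWarning (not part of the diff), assuming ALL_DIMS is exported at the top level as the warning message suggests:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6).reshape(2, 3), dims=['x', 'y'],
                  coords={'x': ['a', 'b']})
# Explicitly reduce over every dimension of each group, as before.
means = da.groupby('x').mean(dim=xr.ALL_DIMS)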
ops.inject_reduce_methods(DataArrayGroupBy)
ops.inject_binary_ops(DataArrayGroupBy)
@@ -649,10 +679,40 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs):
Array with summarized data and the indicated dimension(s)
removed.
"""
+ if dim == DEFAULT_DIMS:
+ dim = ALL_DIMS
+ # TODO change this to dim = self._group_dim after
+ # the deprecation process. Do not forget to remove _reduce_method
+ warnings.warn(
+ "Default reduction dimension will be changed to the "
+ "grouped dimension after xarray 0.12. To silence this "
+ "warning, pass dim=xarray.ALL_DIMS explicitly.",
+ FutureWarning, stacklevel=2)
+ elif dim is None:
+ dim = self._group_dim
+
def reduce_dataset(ds):
return ds.reduce(func, dim, keep_attrs, **kwargs)
return self.apply(reduce_dataset)
+ # TODO remove the following class method and DEFAULT_DIMS after the
+ # deprecation cycle
+ @classmethod
+ def _reduce_method(cls, func, include_skipna, numeric_only):
+ if include_skipna:
+ def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False,
+ skipna=None, **kwargs):
+ return self.reduce(func, dim, keep_attrs, skipna=skipna,
+ numeric_only=numeric_only, allow_lazy=True,
+ **kwargs)
+ else:
+ def wrapped_func(self, dim=DEFAULT_DIMS, keep_attrs=False,
+ **kwargs):
+ return self.reduce(func, dim, keep_attrs,
+ numeric_only=numeric_only, allow_lazy=True,
+ **kwargs)
+ return wrapped_func
+
def assign(self, **kwargs):
"""Assign data variables by group.
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index f823717a8af..984dd2fa204 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -190,10 +190,13 @@ def expand_variable_dicts(list_of_variable_dicts):
an input's values. The values of each ordered dictionary are all
xarray.Variable objects.
"""
+ from .dataarray import DataArray
+ from .dataset import Dataset
+
var_dicts = []
for variables in list_of_variable_dicts:
- if hasattr(variables, 'variables'): # duck-type Dataset
+ if isinstance(variables, Dataset):
sanitized_vars = variables.variables
else:
# append coords to var_dicts before appending sanitized_vars,
@@ -201,7 +204,7 @@ def expand_variable_dicts(list_of_variable_dicts):
sanitized_vars = OrderedDict()
for name, var in variables.items():
- if hasattr(var, '_coords'): # duck-type DataArray
+ if isinstance(var, DataArray):
# use private API for speed
coords = var._coords.copy()
# explicitly overwritten variables should take precedence
@@ -232,17 +235,19 @@ def determine_coords(list_of_variable_dicts):
All variable found in the input should appear in either the set of
coordinate or non-coordinate names.
"""
+ from .dataarray import DataArray
+ from .dataset import Dataset
+
coord_names = set()
noncoord_names = set()
for variables in list_of_variable_dicts:
- if hasattr(variables, 'coords') and hasattr(variables, 'data_vars'):
- # duck-type Dataset
+ if isinstance(variables, Dataset):
coord_names.update(variables.coords)
noncoord_names.update(variables.data_vars)
else:
for name, var in variables.items():
- if hasattr(var, '_coords'): # duck-type DataArray
+ if isinstance(var, DataArray):
coords = set(var._coords) # use private API for speed
# explicitly overwritten variables should take precedence
coords.discard(name)
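A small sketch of the inputs these helpers receive via ``xr.merge``; behaviour
is unchanged, only the duck-type checks became explicit ``isinstance`` checks
(names and values are illustrative):

    import xarray as xr

    ds = xr.Dataset({'a': ('x', [1, 2])})
    da = xr.DataArray([3, 4], dims='x', name='b')

    # A real Dataset hits the isinstance(variables, Dataset) branch, while a
    # plain dict of DataArrays goes through the per-item
    # isinstance(var, DataArray) branch that also collects coordinates.
    merged = xr.merge([ds, {'b': da}])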
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index bec9e2e1931..3f4e0fc3ac9 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function
+import warnings
from collections import Iterable
from functools import partial
@@ -7,11 +8,12 @@
import pandas as pd
from . import rolling
+from .common import _contains_datetime_like_objects
from .computation import apply_ufunc
+from .duck_array_ops import dask_array_type
from .pycompat import iteritems
-from .utils import is_scalar, OrderedSet
+from .utils import OrderedSet, datetime_to_numeric, is_scalar
from .variable import Variable, broadcast_variables
-from .duck_array_ops import dask_array_type
class BaseInterpolator(object):
@@ -57,7 +59,7 @@ def __init__(self, xi, yi, method='linear', fill_value=None, **kwargs):
if self.cons_kwargs:
raise ValueError(
- 'recieved invalid kwargs: %r' % self.cons_kwargs.keys())
+ 'received invalid kwargs: %r' % self.cons_kwargs.keys())
if fill_value is None:
self._left = np.nan
@@ -207,13 +209,16 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None,
interp_class, kwargs = _get_interpolator(method, **kwargs)
interpolator = partial(func_interpolate_na, interp_class, **kwargs)
- arr = apply_ufunc(interpolator, index, self,
- input_core_dims=[[dim], [dim]],
- output_core_dims=[[dim]],
- output_dtypes=[self.dtype],
- dask='parallelized',
- vectorize=True,
- keep_attrs=True).transpose(*self.dims)
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'overflow', RuntimeWarning)
+ warnings.filterwarnings('ignore', 'invalid value', RuntimeWarning)
+ arr = apply_ufunc(interpolator, index, self,
+ input_core_dims=[[dim], [dim]],
+ output_core_dims=[[dim]],
+ output_dtypes=[self.dtype],
+ dask='parallelized',
+ vectorize=True,
+ keep_attrs=True).transpose(*self.dims)
if limit is not None:
arr = arr.where(valids)
@@ -402,15 +407,16 @@ def _floatize_x(x, new_x):
x = list(x)
new_x = list(new_x)
for i in range(len(x)):
- if x[i].dtype.kind in 'Mm':
+ if _contains_datetime_like_objects(x[i]):
# Scipy casts coordinates to np.float64, which is not accurate
# enough for datetime64 (uses 64bit integer).
# We assume that the most of the bits are used to represent the
# offset (min(x)) and the variation (x - min(x)) can be
# represented by float.
- xmin = np.min(x[i])
- x[i] = (x[i] - xmin).astype(np.float64)
- new_x[i] = (new_x[i] - xmin).astype(np.float64)
+ xmin = x[i].min()
+ x[i] = datetime_to_numeric(x[i], offset=xmin, dtype=np.float64)
+ new_x[i] = datetime_to_numeric(
+ new_x[i], offset=xmin, dtype=np.float64)
return x, new_x
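A usage sketch for the datetime handling above; interpolation over a
datetime-like coordinate goes through ``_floatize_x``, which now delegates to
``datetime_to_numeric`` (values are illustrative, scipy is required):

    import numpy as np
    import pandas as pd
    import xarray as xr

    da = xr.DataArray([0.0, 1.0, 3.0], dims='time',
                      coords={'time': pd.date_range('2000-01-01', periods=3)})

    # Both the original and the target coordinates are converted to float
    # offsets from the minimum before being handed to scipy.
    da.interp(time=pd.date_range('2000-01-01 12:00', periods=2))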
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
new file mode 100644
index 00000000000..4d3f03c899e
--- /dev/null
+++ b/xarray/core/nanops.py
@@ -0,0 +1,207 @@
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+
+from . import dtypes, nputils
+from .duck_array_ops import (
+ _dask_or_eager_func, count, fillna, isnull, where_method)
+from .pycompat import dask_array_type
+
+try:
+ import dask.array as dask_array
+except ImportError:
+ dask_array = None
+
+
+def _replace_nan(a, val):
+ """
+ replace nan in a by val, and returns the replaced array and the nan
+ position
+ """
+ mask = isnull(a)
+ return where_method(val, mask, a), mask
+
+
+def _maybe_null_out(result, axis, mask, min_count=1):
+ """
+ xarray version of pandas.core.nanops._maybe_null_out
+ """
+ if hasattr(axis, '__len__'): # if tuple or list
+ raise ValueError('min_count is not available for reduction '
+                         'with more than one dimension.')
+
+ if axis is not None and getattr(result, 'ndim', False):
+ null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
+ if null_mask.any():
+ dtype, fill_value = dtypes.maybe_promote(result.dtype)
+ result = result.astype(dtype)
+ result[null_mask] = fill_value
+
+ elif getattr(result, 'dtype', None) not in dtypes.NAT_TYPES:
+ null_mask = mask.size - mask.sum()
+ if null_mask < min_count:
+ result = np.nan
+
+ return result
+
+
+def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
+ """ In house nanargmin, nanargmax for object arrays. Always return integer
+ type
+ """
+ valid_count = count(value, axis=axis)
+ value = fillna(value, fill_value)
+ data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
+
+ # TODO This will evaluate dask arrays and might be costly.
+ if (valid_count == 0).any():
+ raise ValueError('All-NaN slice encountered')
+
+ return data
+
+
+def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
+ """ In house nanmin and nanmax for object array """
+ valid_count = count(value, axis=axis)
+ filled_value = fillna(value, fill_value)
+ data = getattr(np, func)(filled_value, axis=axis, **kwargs)
+ if not hasattr(data, 'dtype'): # scalar case
+ data = dtypes.fill_value(value.dtype) if valid_count == 0 else data
+ return np.array(data, dtype=value.dtype)
+ return where_method(data, valid_count != 0)
+
+
+def nanmin(a, axis=None, out=None):
+ if a.dtype.kind == 'O':
+ return _nan_minmax_object(
+ 'min', dtypes.get_pos_infinity(a.dtype), a, axis)
+
+ module = dask_array if isinstance(a, dask_array_type) else nputils
+ return module.nanmin(a, axis=axis)
+
+
+def nanmax(a, axis=None, out=None):
+ if a.dtype.kind == 'O':
+ return _nan_minmax_object(
+ 'max', dtypes.get_neg_infinity(a.dtype), a, axis)
+
+ module = dask_array if isinstance(a, dask_array_type) else nputils
+ return module.nanmax(a, axis=axis)
+
+
+def nanargmin(a, axis=None):
+ fill_value = dtypes.get_pos_infinity(a.dtype)
+ if a.dtype.kind == 'O':
+ return _nan_argminmax_object('argmin', fill_value, a, axis=axis)
+ a, mask = _replace_nan(a, fill_value)
+ if isinstance(a, dask_array_type):
+ res = dask_array.argmin(a, axis=axis)
+ else:
+ res = np.argmin(a, axis=axis)
+
+ if mask is not None:
+ mask = mask.all(axis=axis)
+ if mask.any():
+ raise ValueError("All-NaN slice encountered")
+ return res
+
+
+def nanargmax(a, axis=None):
+ fill_value = dtypes.get_neg_infinity(a.dtype)
+ if a.dtype.kind == 'O':
+ return _nan_argminmax_object('argmax', fill_value, a, axis=axis)
+
+ a, mask = _replace_nan(a, fill_value)
+ if isinstance(a, dask_array_type):
+ res = dask_array.argmax(a, axis=axis)
+ else:
+ res = np.argmax(a, axis=axis)
+
+ if mask is not None:
+ mask = mask.all(axis=axis)
+ if mask.any():
+ raise ValueError("All-NaN slice encountered")
+ return res
+
+
+def nansum(a, axis=None, dtype=None, out=None, min_count=None):
+ a, mask = _replace_nan(a, 0)
+ result = _dask_or_eager_func('sum')(a, axis=axis, dtype=dtype)
+ if min_count is not None:
+ return _maybe_null_out(result, axis, mask, min_count)
+ else:
+ return result
+
+
+def _nanmean_ddof_object(ddof, value, axis=None, **kwargs):
+ """ In house nanmean. ddof argument will be used in _nanvar method """
+ from .duck_array_ops import (count, fillna, _dask_or_eager_func,
+ where_method)
+
+ valid_count = count(value, axis=axis)
+ value = fillna(value, 0)
+ # As dtype inference is impossible for object dtype, we assume float
+ # https://github.com/dask/dask/issues/3162
+ dtype = kwargs.pop('dtype', None)
+ if dtype is None and value.dtype.kind == 'O':
+ dtype = value.dtype if value.dtype.kind in ['cf'] else float
+
+ data = _dask_or_eager_func('sum')(value, axis=axis, dtype=dtype, **kwargs)
+ data = data / (valid_count - ddof)
+ return where_method(data, valid_count != 0)
+
+
+def nanmean(a, axis=None, dtype=None, out=None):
+ if a.dtype.kind == 'O':
+ return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype)
+
+ if isinstance(a, dask_array_type):
+ return dask_array.nanmean(a, axis=axis, dtype=dtype)
+
+ return np.nanmean(a, axis=axis, dtype=dtype)
+
+
+def nanmedian(a, axis=None, out=None):
+ return _dask_or_eager_func('nanmedian', eager_module=nputils)(a, axis=axis)
+
+
+def _nanvar_object(value, axis=None, **kwargs):
+ ddof = kwargs.pop('ddof', 0)
+ kwargs_mean = kwargs.copy()
+ kwargs_mean.pop('keepdims', None)
+ value_mean = _nanmean_ddof_object(ddof=0, value=value, axis=axis,
+ keepdims=True, **kwargs_mean)
+ squared = (value.astype(value_mean.dtype) - value_mean)**2
+ return _nanmean_ddof_object(ddof, squared, axis=axis, **kwargs)
+
+
+def nanvar(a, axis=None, dtype=None, out=None, ddof=0):
+ if a.dtype.kind == 'O':
+ return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof)
+
+ return _dask_or_eager_func('nanvar', eager_module=nputils)(
+ a, axis=axis, dtype=dtype, ddof=ddof)
+
+
+def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
+ return _dask_or_eager_func('nanstd', eager_module=nputils)(
+ a, axis=axis, dtype=dtype, ddof=ddof)
+
+
+def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
+ a, mask = _replace_nan(a, 1)
+ result = _dask_or_eager_func('nanprod')(a, axis=axis, dtype=dtype, out=out)
+ if min_count is not None:
+ return _maybe_null_out(result, axis, mask, min_count)
+ else:
+ return result
+
+
+def nancumsum(a, axis=None, dtype=None, out=None):
+ return _dask_or_eager_func('nancumsum', eager_module=nputils)(
+ a, axis=axis, dtype=dtype)
+
+
+def nancumprod(a, axis=None, dtype=None, out=None):
+ return _dask_or_eager_func('nancumprod', eager_module=nputils)(
+ a, axis=axis, dtype=dtype)
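A sketch of the ``min_count`` behaviour implemented by ``_maybe_null_out``;
the module path is the one added here and the array values are illustrative:

    import numpy as np
    from xarray.core import nanops

    a = np.array([[1.0, np.nan, np.nan],
                  [1.0, 2.0, np.nan]])

    nanops.nansum(a, axis=1)                # array([1., 3.])
    # Requiring at least two valid values nulls out the first row.
    nanops.nansum(a, axis=1, min_count=2)   # array([nan, 3.])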
diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py
index 6d4db063b98..efa68c8bad5 100644
--- a/xarray/core/npcompat.py
+++ b/xarray/core/npcompat.py
@@ -1,5 +1,7 @@
from __future__ import absolute_import, division, print_function
+from distutils.version import LooseVersion
+
import numpy as np
try:
@@ -97,3 +99,187 @@ def isin(element, test_elements, assume_unique=False, invert=False):
element = np.asarray(element)
return np.in1d(element, test_elements, assume_unique=assume_unique,
invert=invert).reshape(element.shape)
+
+
+if LooseVersion(np.__version__) >= LooseVersion('1.13'):
+ gradient = np.gradient
+else:
+ def normalize_axis_tuple(axes, N):
+ if isinstance(axes, int):
+ axes = (axes, )
+ return tuple([N + a if a < 0 else a for a in axes])
+
+ def gradient(f, *varargs, **kwargs):
+ f = np.asanyarray(f)
+ N = f.ndim # number of dimensions
+
+ axes = kwargs.pop('axis', None)
+ if axes is None:
+ axes = tuple(range(N))
+ else:
+ axes = normalize_axis_tuple(axes, N)
+
+ len_axes = len(axes)
+ n = len(varargs)
+ if n == 0:
+ # no spacing argument - use 1 in all axes
+ dx = [1.0] * len_axes
+ elif n == 1 and np.ndim(varargs[0]) == 0:
+ # single scalar for all axes
+ dx = varargs * len_axes
+ elif n == len_axes:
+ # scalar or 1d array for each axis
+ dx = list(varargs)
+ for i, distances in enumerate(dx):
+ if np.ndim(distances) == 0:
+ continue
+ elif np.ndim(distances) != 1:
+ raise ValueError("distances must be either scalars or 1d")
+ if len(distances) != f.shape[axes[i]]:
+ raise ValueError("when 1d, distances must match the "
+ "length of the corresponding dimension")
+ diffx = np.diff(distances)
+ # if distances are constant reduce to the scalar case
+ # since it brings a consistent speedup
+ if (diffx == diffx[0]).all():
+ diffx = diffx[0]
+ dx[i] = diffx
+ else:
+ raise TypeError("invalid number of arguments")
+
+ edge_order = kwargs.pop('edge_order', 1)
+ if kwargs:
+ raise TypeError('"{}" are not valid keyword arguments.'.format(
+ '", "'.join(kwargs.keys())))
+ if edge_order > 2:
+ raise ValueError("'edge_order' greater than 2 not supported")
+
+ # use central differences on interior and one-sided differences on the
+ # endpoints. This preserves second order-accuracy over the full domain.
+
+ outvals = []
+
+ # create slice objects --- initially all are [:, :, ..., :]
+ slice1 = [slice(None)] * N
+ slice2 = [slice(None)] * N
+ slice3 = [slice(None)] * N
+ slice4 = [slice(None)] * N
+
+ otype = f.dtype.char
+ if otype not in ['f', 'd', 'F', 'D', 'm', 'M']:
+ otype = 'd'
+
+ # Difference of datetime64 elements results in timedelta64
+ if otype == 'M':
+ # Need to use the full dtype name because it contains unit
+ # information
+ otype = f.dtype.name.replace('datetime', 'timedelta')
+ elif otype == 'm':
+ # Needs to keep the specific units, can't be a general unit
+ otype = f.dtype
+
+ # Convert datetime64 data into ints. Make dummy variable `y`
+ # that is a view of ints if the data is datetime64, otherwise
+ # just set y equal to the array `f`.
+ if f.dtype.char in ["M", "m"]:
+ y = f.view('int64')
+ else:
+ y = f
+
+ for i, axis in enumerate(axes):
+ if y.shape[axis] < edge_order + 1:
+ raise ValueError(
+ "Shape of array too small to calculate a numerical "
+ "gradient, at least (edge_order + 1) elements are "
+ "required.")
+ # result allocation
+ out = np.empty_like(y, dtype=otype)
+
+ uniform_spacing = np.ndim(dx[i]) == 0
+
+ # Numerical differentiation: 2nd order interior
+ slice1[axis] = slice(1, -1)
+ slice2[axis] = slice(None, -2)
+ slice3[axis] = slice(1, -1)
+ slice4[axis] = slice(2, None)
+
+ if uniform_spacing:
+ out[slice1] = (f[slice4] - f[slice2]) / (2. * dx[i])
+ else:
+ dx1 = dx[i][0:-1]
+ dx2 = dx[i][1:]
+ a = -(dx2) / (dx1 * (dx1 + dx2))
+ b = (dx2 - dx1) / (dx1 * dx2)
+ c = dx1 / (dx2 * (dx1 + dx2))
+ # fix the shape for broadcasting
+ shape = np.ones(N, dtype=int)
+ shape[axis] = -1
+ a.shape = b.shape = c.shape = shape
+ # 1D equivalent --
+ # out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
+ out[slice1] = a * f[slice2] + b * f[slice3] + c * f[slice4]
+
+ # Numerical differentiation: 1st order edges
+ if edge_order == 1:
+ slice1[axis] = 0
+ slice2[axis] = 1
+ slice3[axis] = 0
+ dx_0 = dx[i] if uniform_spacing else dx[i][0]
+ # 1D equivalent -- out[0] = (y[1] - y[0]) / (x[1] - x[0])
+ out[slice1] = (y[slice2] - y[slice3]) / dx_0
+
+ slice1[axis] = -1
+ slice2[axis] = -1
+ slice3[axis] = -2
+ dx_n = dx[i] if uniform_spacing else dx[i][-1]
+ # 1D equivalent -- out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])
+ out[slice1] = (y[slice2] - y[slice3]) / dx_n
+
+ # Numerical differentiation: 2nd order edges
+ else:
+ slice1[axis] = 0
+ slice2[axis] = 0
+ slice3[axis] = 1
+ slice4[axis] = 2
+ if uniform_spacing:
+ a = -1.5 / dx[i]
+ b = 2. / dx[i]
+ c = -0.5 / dx[i]
+ else:
+ dx1 = dx[i][0]
+ dx2 = dx[i][1]
+ a = -(2. * dx1 + dx2) / (dx1 * (dx1 + dx2))
+ b = (dx1 + dx2) / (dx1 * dx2)
+ c = - dx1 / (dx2 * (dx1 + dx2))
+ # 1D equivalent -- out[0] = a * y[0] + b * y[1] + c * y[2]
+ out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4]
+
+ slice1[axis] = -1
+ slice2[axis] = -3
+ slice3[axis] = -2
+ slice4[axis] = -1
+ if uniform_spacing:
+ a = 0.5 / dx[i]
+ b = -2. / dx[i]
+ c = 1.5 / dx[i]
+ else:
+ dx1 = dx[i][-2]
+ dx2 = dx[i][-1]
+ a = (dx2) / (dx1 * (dx1 + dx2))
+ b = - (dx2 + dx1) / (dx1 * dx2)
+ c = (2. * dx2 + dx1) / (dx2 * (dx1 + dx2))
+ # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
+ out[slice1] = a * y[slice2] + b * y[slice3] + c * y[slice4]
+
+ outvals.append(out)
+
+ # reset the slice object in this dimension to ":"
+ slice1[axis] = slice(None)
+ slice2[axis] = slice(None)
+ slice3[axis] = slice(None)
+ slice4[axis] = slice(None)
+
+ if len_axes == 1:
+ return outvals[0]
+ else:
+ return outvals
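A quick sketch of the fallback in use; on numpy >= 1.13 ``npcompat.gradient``
is simply ``np.gradient``, while older numpy uses the vendored copy above,
including support for non-uniform spacing (values are illustrative):

    import numpy as np
    from xarray.core import npcompat

    x = np.array([0.0, 1.0, 1.5, 3.5])
    y = x ** 2

    # Second-order central differences in the interior, one-sided at the edges.
    npcompat.gradient(y, x)
    # Scalar spacing and an explicit axis behave like np.gradient.
    npcompat.gradient(np.arange(12.0).reshape(3, 4), 2.0, axis=1)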
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index 6df2d34bfe3..a8d596abd86 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -5,6 +5,14 @@
import numpy as np
import pandas as pd
+try:
+ import bottleneck as bn
+ _USE_BOTTLENECK = True
+except ImportError:
+ # use numpy methods instead
+ bn = np
+ _USE_BOTTLENECK = False
+
def _validate_axis(data, axis):
ndim = data.ndim
@@ -195,3 +203,36 @@ def _rolling_window(a, window, axis=-1):
rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides,
writeable=False)
return np.swapaxes(rolling, -2, axis)
+
+
+def _create_bottleneck_method(name, npmodule=np):
+ def f(values, axis=None, **kwds):
+ dtype = kwds.get('dtype', None)
+ bn_func = getattr(bn, name, None)
+
+ if (_USE_BOTTLENECK and bn_func is not None and
+ not isinstance(axis, tuple) and
+ values.dtype.kind in 'uifc' and
+ values.dtype.isnative and
+ (dtype is None or np.dtype(dtype) == values.dtype)):
+            # bottleneck does not take care of dtype, min_count
+ kwds.pop('dtype', None)
+ result = bn_func(values, axis=axis, **kwds)
+ else:
+ result = getattr(npmodule, name)(values, axis=axis, **kwds)
+
+ return result
+
+ f.__name__ = name
+ return f
+
+
+nanmin = _create_bottleneck_method('nanmin')
+nanmax = _create_bottleneck_method('nanmax')
+nanmean = _create_bottleneck_method('nanmean')
+nanmedian = _create_bottleneck_method('nanmedian')
+nanvar = _create_bottleneck_method('nanvar')
+nanstd = _create_bottleneck_method('nanstd')
+nanprod = _create_bottleneck_method('nanprod')
+nancumsum = _create_bottleneck_method('nancumsum')
+nancumprod = _create_bottleneck_method('nancumprod')
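A sketch of how the dispatch behaves, assuming bottleneck may or may not be
installed (array values are illustrative):

    import numpy as np
    from xarray.core import nputils

    values = np.random.randn(4, 5)
    values[0, 0] = np.nan

    # Uses bottleneck.nanmean when bottleneck is importable, the axis is not a
    # tuple and the dtype is a native numeric one; otherwise numpy.nanmean.
    nputils.nanmean(values, axis=0)
    # Requesting a different dtype forces the numpy path.
    nputils.nanmean(values, axis=0, dtype=np.float32)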
diff --git a/xarray/core/ops.py b/xarray/core/ops.py
index d9e8ceb65d5..a0dd2212a8f 100644
--- a/xarray/core/ops.py
+++ b/xarray/core/ops.py
@@ -86,7 +86,7 @@
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
- implemented (object, datetime64 or timedelta64).
+ implemented (object, datetime64 or timedelta64).{min_count_docs}
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
@@ -102,6 +102,12 @@
indicated dimension(s) removed.
"""
+_MINCOUNT_DOCSTRING = """
+min_count : int, default None
+ The required number of valid values to perform the operation.
+ If fewer than min_count non-NA values are present the result will
+ be NA. New in version 0.10.8: Added with the default being None."""
+
_ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\
Reduce this {da_or_ds}'s data windows by applying `{name}` along its dimension.
@@ -236,11 +242,15 @@ def inject_reduce_methods(cls):
[('count', duck_array_ops.count, False)])
for name, f, include_skipna in methods:
numeric_only = getattr(f, 'numeric_only', False)
+ available_min_count = getattr(f, 'available_min_count', False)
+ min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else ''
+
func = cls._reduce_method(f, include_skipna, numeric_only)
func.__name__ = name
func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format(
name=name, cls=cls.__name__,
- extra_args=cls._reduce_extra_args_docstring.format(name=name))
+ extra_args=cls._reduce_extra_args_docstring.format(name=name),
+ min_count_docs=min_count_docs)
setattr(cls, name, func)
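The injected docstring corresponds to the new user-facing ``min_count``
keyword on ``sum`` and ``prod``; a minimal sketch (values are illustrative):

    import numpy as np
    import xarray as xr

    da = xr.DataArray([np.nan, np.nan, 1.0], dims='x')

    da.sum(dim='x')               # 1.0 -- NaNs are skipped by default
    da.sum(dim='x', min_count=2)  # nan -- fewer than two valid values present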
diff --git a/xarray/core/options.py b/xarray/core/options.py
index 48d4567fc99..04ea0be7172 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -1,9 +1,43 @@
from __future__ import absolute_import, division, print_function
+DISPLAY_WIDTH = 'display_width'
+ARITHMETIC_JOIN = 'arithmetic_join'
+ENABLE_CFTIMEINDEX = 'enable_cftimeindex'
+FILE_CACHE_MAXSIZE = 'file_cache_maxsize'
+CMAP_SEQUENTIAL = 'cmap_sequential'
+CMAP_DIVERGENT = 'cmap_divergent'
+
OPTIONS = {
- 'display_width': 80,
- 'arithmetic_join': 'inner',
- 'enable_cftimeindex': False
+ DISPLAY_WIDTH: 80,
+ ARITHMETIC_JOIN: 'inner',
+ ENABLE_CFTIMEINDEX: False,
+ FILE_CACHE_MAXSIZE: 128,
+ CMAP_SEQUENTIAL: 'viridis',
+ CMAP_DIVERGENT: 'RdBu_r',
+}
+
+_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact'])
+
+
+def _positive_integer(value):
+ return isinstance(value, int) and value > 0
+
+
+_VALIDATORS = {
+ DISPLAY_WIDTH: _positive_integer,
+ ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__,
+ ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool),
+ FILE_CACHE_MAXSIZE: _positive_integer,
+}
+
+
+def _set_file_cache_maxsize(value):
+ from ..backends.file_manager import FILE_CACHE
+ FILE_CACHE.maxsize = value
+
+
+_SETTERS = {
+ FILE_CACHE_MAXSIZE: _set_file_cache_maxsize,
}
@@ -19,8 +53,18 @@ class set_options(object):
- ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex``
for time indexes with non-standard calendars or dates outside the
Timestamp-valid range. Default: ``False``.
+ - ``file_cache_maxsize``: maximum number of open files to hold in xarray's
+ global least-recently-used cache. This should be smaller than your
+ system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux.
+ Default: 128.
+ - ``cmap_sequential``: colormap to use for nondivergent data plots.
+ Default: ``viridis``. If string, must be matplotlib built-in colormap.
+ Can also be a Colormap object (e.g. mpl.cm.magma)
+ - ``cmap_divergent``: colormap to use for divergent data plots.
+ Default: ``RdBu_r``. If string, must be matplotlib built-in colormap.
+ Can also be a Colormap object (e.g. mpl.cm.magma)
- You can use ``set_options`` either as a context manager:
+ You can use ``set_options`` either as a context manager:
>>> ds = xr.Dataset({'x': np.arange(1000)})
>>> with xr.set_options(display_width=40):
@@ -38,16 +82,26 @@ class set_options(object):
"""
def __init__(self, **kwargs):
- invalid_options = {k for k in kwargs if k not in OPTIONS}
- if invalid_options:
- raise ValueError('argument names %r are not in the set of valid '
- 'options %r' % (invalid_options, set(OPTIONS)))
self.old = OPTIONS.copy()
- OPTIONS.update(kwargs)
+ for k, v in kwargs.items():
+ if k not in OPTIONS:
+ raise ValueError(
+ 'argument name %r is not in the set of valid options %r'
+ % (k, set(OPTIONS)))
+ if k in _VALIDATORS and not _VALIDATORS[k](v):
+ raise ValueError(
+ 'option %r given an invalid value: %r' % (k, v))
+ self._apply_update(kwargs)
+
+ def _apply_update(self, options_dict):
+ for k, v in options_dict.items():
+ if k in _SETTERS:
+ _SETTERS[k](v)
+ OPTIONS.update(options_dict)
def __enter__(self):
return
def __exit__(self, type, value, traceback):
OPTIONS.clear()
- OPTIONS.update(self.old)
+ self._apply_update(self.old)
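A sketch of the new options in use; values are now validated up front and
``file_cache_maxsize`` is applied immediately through its setter:

    import xarray as xr

    with xr.set_options(file_cache_maxsize=256, cmap_sequential='magma'):
        pass  # open files / make plots with the temporary settings in effect

    # Invalid values are rejected before anything is changed, e.g. this raises
    # ValueError:
    # xr.set_options(file_cache_maxsize=0)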
diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py
index 78c26f1e92f..b980bc279b0 100644
--- a/xarray/core/pycompat.py
+++ b/xarray/core/pycompat.py
@@ -28,6 +28,9 @@ def itervalues(d):
import builtins
from urllib.request import urlretrieve
from inspect import getfullargspec as getargspec
+
+ def move_to_end(ordered_dict, key):
+ ordered_dict.move_to_end(key)
else: # pragma: no cover
# Python 2
basestring = basestring # noqa
@@ -50,6 +53,11 @@ def itervalues(d):
from urllib import urlretrieve
from inspect import getargspec
+ def move_to_end(ordered_dict, key):
+ value = ordered_dict[key]
+ del ordered_dict[key]
+ ordered_dict[key] = value
+
integer_types = native_int_types + (np.integer,)
try:
@@ -76,7 +84,6 @@ def itervalues(d):
except ImportError as e:
path_type = ()
-
try:
from contextlib import suppress
except ImportError:
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index 4933a09b257..bd84e04487e 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function
from . import ops
-from .groupby import DataArrayGroupBy, DatasetGroupBy
+from .groupby import DEFAULT_DIMS, DataArrayGroupBy, DatasetGroupBy
from .pycompat import OrderedDict, dask_array_type
RESAMPLE_DIM = '__resample_dim__'
@@ -277,15 +277,14 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs):
"""Reduce the items in this group by applying `func` along the
pre-defined resampling dimension.
- Note that `dim` is by default here and ignored if passed by the user;
- this ensures compatibility with the existing reduce interface.
-
Parameters
----------
func : function
Function which can be called in the form
`func(x, axis=axis, **kwargs)` to return the result of collapsing
an np.ndarray over an integer valued axis.
+ dim : str or sequence of str, optional
+ Dimension(s) over which to apply `func`.
keep_attrs : bool, optional
If True, the datasets's attributes (`attrs`) will be copied from
the original object to the new one. If False (default), the new
@@ -299,8 +298,11 @@ def reduce(self, func, dim=None, keep_attrs=False, **kwargs):
Array with summarized data and the indicated dimension(s)
removed.
"""
+ if dim == DEFAULT_DIMS:
+ dim = None
+
return super(DatasetResample, self).reduce(
- func, self._dim, keep_attrs, **kwargs)
+ func, dim, keep_attrs, **kwargs)
def _interpolate(self, kind='linear'):
"""Apply scipy.interpolate.interp1d along resampling dimension."""
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index 24ed280b19e..883dbb34dff 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -44,7 +44,7 @@ class Rolling(object):
_attributes = ['window', 'min_periods', 'center', 'dim']
- def __init__(self, obj, min_periods=None, center=False, **windows):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object.
@@ -52,18 +52,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
----------
obj : Dataset or DataArray
Object to window.
+ windows : A mapping from a dimension name to window size
+ dim : str
+ Name of the dimension to create the rolling iterator
+ along (e.g., `time`).
+ window : int
+ Size of the moving window.
min_periods : int, default None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
- **windows : dim=window
- dim : str
- Name of the dimension to create the rolling iterator
- along (e.g., `time`).
- window : int
- Size of the moving window.
Returns
-------
@@ -115,7 +115,7 @@ def __len__(self):
class DataArrayRolling(Rolling):
- def __init__(self, obj, min_periods=None, center=False, **windows):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
@@ -125,18 +125,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
----------
obj : DataArray
Object to window.
+ windows : A mapping from a dimension name to window size
+ dim : str
+ Name of the dimension to create the rolling iterator
+ along (e.g., `time`).
+ window : int
+ Size of the moving window.
min_periods : int, default None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
- **windows : dim=window
- dim : str
- Name of the dimension to create the rolling iterator
- along (e.g., `time`).
- window : int
- Size of the moving window.
Returns
-------
@@ -149,8 +149,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
Dataset.rolling
Dataset.groupby
"""
- super(DataArrayRolling, self).__init__(obj, min_periods=min_periods,
- center=center, **windows)
+ super(DataArrayRolling, self).__init__(
+ obj, windows, min_periods=min_periods, center=center)
self.window_labels = self.obj[self.dim]
@@ -321,7 +321,7 @@ def wrapped_func(self, **kwargs):
class DatasetRolling(Rolling):
- def __init__(self, obj, min_periods=None, center=False, **windows):
+ def __init__(self, obj, windows, min_periods=None, center=False):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
@@ -331,18 +331,18 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
----------
obj : Dataset
Object to window.
+ windows : A mapping from a dimension name to window size
+ dim : str
+ Name of the dimension to create the rolling iterator
+ along (e.g., `time`).
+ window : int
+ Size of the moving window.
min_periods : int, default None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
- **windows : dim=window
- dim : str
- Name of the dimension to create the rolling iterator
- along (e.g., `time`).
- window : int
- Size of the moving window.
Returns
-------
@@ -355,8 +355,7 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
Dataset.groupby
DataArray.groupby
"""
- super(DatasetRolling, self).__init__(obj,
- min_periods, center, **windows)
+ super(DatasetRolling, self).__init__(obj, windows, min_periods, center)
if self.dim not in self.obj.dims:
raise KeyError(self.dim)
# Keep each Rolling object as an OrderedDict
@@ -364,8 +363,8 @@ def __init__(self, obj, min_periods=None, center=False, **windows):
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on self.dim
if self.dim in da.dims:
- self.rollings[key] = DataArrayRolling(da, min_periods,
- center, **windows)
+ self.rollings[key] = DataArrayRolling(
+ da, windows, min_periods, center)
def reduce(self, func, **kwargs):
"""Reduce the items in this group by applying `func` along some
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index c3bb747fac5..c39a07e1b5a 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -591,3 +591,29 @@ def __iter__(self):
def __len__(self):
num_hidden = sum([k in self._hidden_keys for k in self._data])
return len(self._data) - num_hidden
+
+
+def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
+ """Convert an array containing datetime-like data to an array of floats.
+
+ Parameters
+ ----------
+ array : array-like
+ Input data
+ offset : scalar of the same type as array, optional
+ If None, the minimum of the array is subtracted to reduce round-off error.
+ datetime_unit : None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms',
+ 'us', 'ns', 'ps', 'fs', 'as'}
+ If given, express the result in these units.
+ dtype : dtype, optional
+ Target dtype for the output array.
+
+ Returns
+ -------
+ array
+ """
+ if offset is None:
+ offset = array.min()
+ array = array - offset
+
+ if datetime_unit:
+ return (array / np.timedelta64(1, datetime_unit)).astype(dtype)
+ return array.astype(dtype)
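A sketch of the new helper (values are illustrative):

    import numpy as np
    from xarray.core.utils import datetime_to_numeric

    times = np.array(['2000-01-01', '2000-01-02', '2000-01-04'],
                     dtype='datetime64[ns]')

    # Offset defaults to the minimum; with a unit the result is in those units.
    datetime_to_numeric(times, datetime_unit='D')   # array([0., 1., 3.])
    # Without a unit the raw timedelta64[ns] offsets are cast to float.
    datetime_to_numeric(times)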
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index d9772407b82..c003d52aab2 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -64,22 +64,15 @@ def as_variable(obj, name=None):
The newly created variable.
"""
+ from .dataarray import DataArray
+
# TODO: consider extending this method to automatically handle Iris and
- # pandas objects.
- if hasattr(obj, 'variable'):
+ if isinstance(obj, DataArray):
# extract the primary Variable from DataArrays
obj = obj.variable
if isinstance(obj, Variable):
obj = obj.copy(deep=False)
- elif hasattr(obj, 'dims') and (hasattr(obj, 'data') or
- hasattr(obj, 'values')):
- obj_data = getattr(obj, 'data', None)
- if obj_data is None:
- obj_data = getattr(obj, 'values')
- obj = Variable(obj.dims, obj_data,
- getattr(obj, 'attrs', None),
- getattr(obj, 'encoding', None))
elif isinstance(obj, tuple):
try:
obj = Variable(*obj)
@@ -728,24 +721,81 @@ def encoding(self, value):
except ValueError:
raise ValueError('encoding must be castable to a dictionary')
- def copy(self, deep=True):
+ def copy(self, deep=True, data=None):
"""Returns a copy of this object.
If `deep=True`, the data array is loaded into memory and copied onto
the new object. Dimensions, attributes and encodings are always copied.
- """
- data = self._data
- if isinstance(data, indexing.MemoryCachedArray):
- # don't share caching between copies
- data = indexing.MemoryCachedArray(data.array)
+ Use `data` to create a new object with the same structure as
+ original but entirely new data.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ Whether the data array is loaded into memory and copied onto
+ the new object. Default is True.
+ data : array_like, optional
+ Data to use in the new object. Must have same shape as original.
+ When `data` is used, `deep` is ignored.
+
+ Returns
+ -------
+ object : Variable
+ New object with dimensions, attributes, encodings, and optionally
+ data copied from original.
- if deep:
- if isinstance(data, dask_array_type):
- data = data.copy()
- elif not isinstance(data, PandasIndexAdapter):
- # pandas.Index is immutable
- data = np.array(data)
+ Examples
+ --------
+
+ Shallow copy versus deep copy
+
+ >>> var = xr.Variable(data=[1, 2, 3], dims='x')
+ >>> var.copy()
+ <xarray.Variable (x: 3)>
+ array([1, 2, 3])
+ >>> var_0 = var.copy(deep=False)
+ >>> var_0[0] = 7
+ >>> var_0
+ <xarray.Variable (x: 3)>
+ array([7, 2, 3])
+ >>> var
+ <xarray.Variable (x: 3)>
+ array([7, 2, 3])
+
+ Changing the data using the ``data`` argument maintains the
+ structure of the original object, but with the new data. Original
+ object is unaffected.
+
+ >>> var.copy(data=[0.1, 0.2, 0.3])
+ <xarray.Variable (x: 3)>
+ array([ 0.1, 0.2, 0.3])
+ >>> var
+ <xarray.Variable (x: 3)>
+ array([7, 2, 3])
+
+ See Also
+ --------
+ pandas.DataFrame.copy
+ """
+ if data is None:
+ data = self._data
+
+ if isinstance(data, indexing.MemoryCachedArray):
+ # don't share caching between copies
+ data = indexing.MemoryCachedArray(data.array)
+
+ if deep:
+ if isinstance(data, dask_array_type):
+ data = data.copy()
+ elif not isinstance(data, PandasIndexAdapter):
+ # pandas.Index is immutable
+ data = np.array(data)
+ else:
+ data = as_compatible_data(data)
+ if self.shape != data.shape:
+ raise ValueError("Data shape {} must match shape of object {}"
+ .format(data.shape, self.shape))
# note:
# dims is already an immutable tuple
@@ -877,7 +927,7 @@ def squeeze(self, dim=None):
numpy.squeeze
"""
dims = common.get_squeeze_dims(self, dim)
- return self.isel(**{d: 0 for d in dims})
+ return self.isel({d: 0 for d in dims})
def _shift_one_dim(self, dim, count):
axis = self.get_axis_num(dim)
@@ -919,36 +969,46 @@ def _shift_one_dim(self, dim, count):
return type(self)(self.dims, data, self._attrs, fastpath=True)
- def shift(self, **shifts):
+ def shift(self, shifts=None, **shifts_kwargs):
"""
Return a new Variable with shifted data.
Parameters
----------
- **shifts : keyword arguments of the form {dim: offset}
+ shifts : mapping of the form {dim: offset}
Integer offset to shift along each of the given dimensions.
Positive offsets shift to the right; negative offsets shift to the
left.
+ **shifts_kwargs:
+ The keyword arguments form of ``shifts``.
+ One of shifts or shifts_kwargs must be provided.
Returns
-------
shifted : Variable
Variable with the same dimensions and attributes but shifted data.
"""
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift')
result = self
for dim, count in shifts.items():
result = result._shift_one_dim(dim, count)
return result
- def pad_with_fill_value(self, fill_value=dtypes.NA, **pad_widths):
+ def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA,
+ **pad_widths_kwargs):
"""
Return a new Variable with paddings.
Parameters
----------
- **pad_width: keyword arguments of the form {dim: (before, after)}
+ pad_widths : mapping of the form {dim: (before, after)}
Number of values padded to the edges of each dimension.
+ **pad_widths_kwargs:
+ The keyword arguments form of ``pad_widths``.
"""
+ pad_widths = either_dict_or_kwargs(pad_widths, pad_widths_kwargs,
+ 'pad')
+
if fill_value is dtypes.NA: # np.nan is passed
dtype, fill_value = dtypes.maybe_promote(self.dtype)
else:
@@ -1009,22 +1069,27 @@ def _roll_one_dim(self, dim, count):
return type(self)(self.dims, data, self._attrs, fastpath=True)
- def roll(self, **shifts):
+ def roll(self, shifts=None, **shifts_kwargs):
"""
Return a new Variable with rolled data.
Parameters
----------
- **shifts : keyword arguments of the form {dim: offset}
+ shifts : mapping of the form {dim: offset}
Integer offset to roll along each of the given dimensions.
Positive offsets roll to the right; negative offsets roll to the
left.
+ **shifts_kwargs:
+ The keyword arguments form of ``shifts``.
+ One of shifts or shifts_kwargs must be provided.
Returns
-------
shifted : Variable
Variable with the same dimensions and attributes but rolled data.
"""
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'roll')
+
result = self
for dim, count in shifts.items():
result = result._roll_one_dim(dim, count)
@@ -1142,7 +1207,7 @@ def _stack_once(self, dims, new_dim):
return Variable(new_dims, new_data, self._attrs, self._encoding,
fastpath=True)
- def stack(self, **dimensions):
+ def stack(self, dimensions=None, **dimensions_kwargs):
"""
Stack any number of existing dimensions into a single new dimension.
@@ -1151,9 +1216,12 @@ def stack(self, **dimensions):
Parameters
----------
- **dimensions : keyword arguments of the form new_name=(dim1, dim2, ...)
+ dimensions : Mapping of form new_name=(dim1, dim2, ...)
Names of new dimensions, and the existing dimensions that they
replace.
+ **dimensions_kwargs:
+ The keyword arguments form of ``dimensions``.
+ One of dimensions or dimensions_kwargs must be provided.
Returns
-------
@@ -1164,6 +1232,8 @@ def stack(self, **dimensions):
--------
Variable.unstack
"""
+ dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs,
+ 'stack')
result = self
for new_dim, dims in dimensions.items():
result = result._stack_once(dims, new_dim)
@@ -1195,7 +1265,7 @@ def _unstack_once(self, dims, old_dim):
return Variable(new_dims, new_data, self._attrs, self._encoding,
fastpath=True)
- def unstack(self, **dimensions):
+ def unstack(self, dimensions=None, **dimensions_kwargs):
"""
Unstack an existing dimension into multiple new dimensions.
@@ -1204,9 +1274,12 @@ def unstack(self, **dimensions):
Parameters
----------
- **dimensions : keyword arguments of the form old_dim={dim1: size1, ...}
+ dimensions : mapping of the form old_dim={dim1: size1, ...}
Names of existing dimensions, and the new dimensions and sizes
that they map to.
+ **dimensions_kwargs:
+ The keyword arguments form of ``dimensions``.
+ One of dimensions or dimensions_kwargs must be provided.
Returns
-------
@@ -1217,6 +1290,8 @@ def unstack(self, **dimensions):
--------
Variable.stack
"""
+ dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs,
+ 'unstack')
result = self
for old_dim, dims in dimensions.items():
result = result._unstack_once(dims, old_dim)
@@ -1258,6 +1333,8 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False,
Array with summarized data and the indicated dimension(s)
removed.
"""
+ if dim is common.ALL_DIMS:
+ dim = None
if dim is not None and axis is not None:
raise ValueError("cannot supply both 'axis' and 'dim' arguments")
@@ -1691,14 +1768,37 @@ def concat(cls, variables, dim='concat_dim', positions=None,
return cls(first_var.dims, data, attrs)
- def copy(self, deep=True):
+ def copy(self, deep=True, data=None):
"""Returns a copy of this object.
- `deep` is ignored since data is stored in the form of pandas.Index,
- which is already immutable. Dimensions, attributes and encodings are
- always copied.
+ `deep` is ignored since data is stored in the form of
+ pandas.Index, which is already immutable. Dimensions, attributes
+ and encodings are always copied.
+
+ Use `data` to create a new object with the same structure as
+ original but entirely new data.
+
+ Parameters
+ ----------
+ deep : bool, optional
+ Deep is always ignored.
+ data : array_like, optional
+ Data to use in the new object. Must have same shape as original.
+
+ Returns
+ -------
+ object : Variable
+ New object with dimensions, attributes, encodings, and optionally
+ data copied from original.
"""
- return type(self)(self.dims, self._data, self._attrs,
+ if data is None:
+ data = self._data
+ else:
+ data = as_compatible_data(data)
+ if self.shape != data.shape:
+ raise ValueError("Data shape {} must match shape of object {}"
+ .format(data.shape, self.shape))
+ return type(self)(self.dims, data, self._attrs,
self._encoding, fastpath=True)
def equals(self, other, equiv=None):
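A sketch of the dict-or-kwargs calling convention added to ``shift``,
``roll``, ``pad_with_fill_value``, ``stack`` and ``unstack``, plus the new
``data`` argument to ``copy`` (names and values are illustrative):

    import numpy as np
    import xarray as xr

    var = xr.Variable(('x', 'y'), np.arange(6).reshape(2, 3))

    var.shift({'x': 1})               # explicit mapping ...
    var.shift(x=1)                    # ... or keyword arguments, but not both
    var.stack(z=('x', 'y'))
    var.copy(data=np.zeros((2, 3)))   # same structure, entirely new data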
diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py
index a0d7c4dd5e2..32a954a3fcd 100644
--- a/xarray/plot/facetgrid.py
+++ b/xarray/plot/facetgrid.py
@@ -5,6 +5,7 @@
import warnings
import numpy as np
+
from ..core.formatting import format_item
from ..core.pycompat import getargspec
from .utils import (
@@ -188,6 +189,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None,
self._y_var = None
self._cmap_extend = None
self._mappables = []
+ self._finalized = False
@property
def _left_axes(self):
@@ -308,13 +310,16 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
def _finalize_grid(self, *axlabels):
"""Finalize the annotations and layout."""
- self.set_axis_labels(*axlabels)
- self.set_titles()
- self.fig.tight_layout()
+ if not self._finalized:
+ self.set_axis_labels(*axlabels)
+ self.set_titles()
+ self.fig.tight_layout()
- for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
- if namedict is None:
- ax.set_visible(False)
+ for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
+ if namedict is None:
+ ax.set_visible(False)
+
+ self._finalized = True
def add_legend(self, **kwargs):
figlegend = self.fig.legend(
@@ -502,9 +507,12 @@ def map(self, func, *args, **kwargs):
data = self.data.loc[namedict]
plt.sca(ax)
innerargs = [data[a].values for a in args]
- # TODO: is it possible to verify that an artist is mappable?
- mappable = func(*innerargs, **kwargs)
- self._mappables.append(mappable)
+ maybe_mappable = func(*innerargs, **kwargs)
+ # TODO: better way to verify that an artist is mappable?
+ # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
+ if (maybe_mappable and
+ hasattr(maybe_mappable, 'autoscale_None')):
+ self._mappables.append(maybe_mappable)
self._finalize_grid(*args[:2])
diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py
index 3e7e5909a70..52be985153a 100644
--- a/xarray/plot/plot.py
+++ b/xarray/plot/plot.py
@@ -600,9 +600,11 @@ def step(self, *args, **kwargs):
def _rescale_imshow_rgb(darray, vmin, vmax, robust):
assert robust or vmin is not None or vmax is not None
+ # TODO: remove when min numpy version is bumped to 1.13
# There's a cyclic dependency via DataArray, so we can't import from
# xarray.ufuncs in global scope.
from xarray.ufuncs import maximum, minimum
+
# Calculate vmin and vmax automatically for `robust=True`
if robust:
if vmax is None:
@@ -628,7 +630,10 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust):
# After scaling, downcast to 32-bit float. This substantially reduces
# memory usage after we hand `darray` off to matplotlib.
darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4')
- return minimum(maximum(darray, 0), 1)
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'xarray.ufuncs',
+ PendingDeprecationWarning)
+ return minimum(maximum(darray, 0), 1)
def _plot2d(plotfunc):
@@ -678,6 +683,9 @@ def _plot2d(plotfunc):
Adds colorbar to axis
add_labels : Boolean, optional
Use xarray metadata to label axes
+ norm : ``matplotlib.colors.Normalize`` instance, optional
+ If the ``norm`` has vmin or vmax specified, the corresponding kwarg
+ must be None.
vmin, vmax : floats, optional
Values to anchor the colormap, otherwise they are inferred from the
data and other keyword arguments. When a diverging dataset is inferred,
@@ -746,7 +754,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, colors=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
- xlim=None, ylim=None, **kwargs):
+ xlim=None, ylim=None, norm=None, **kwargs):
# All 2d plots in xarray share this function signature.
# Method signature below should be consistent.
@@ -847,6 +855,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
'extend': extend,
'levels': levels,
'filled': plotfunc.__name__ != 'contour',
+ 'norm': norm,
}
cmap_params = _determine_cmap_params(**cmap_kwargs)
@@ -857,13 +866,15 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
# pcolormesh
kwargs['extend'] = cmap_params['extend']
kwargs['levels'] = cmap_params['levels']
+ # if colors == a single color, matplotlib draws dashed negative
+ # contours. we lose this feature if we pass cmap and not colors
+ if isinstance(colors, basestring):
+ cmap_params['cmap'] = None
+ kwargs['colors'] = colors
if 'pcolormesh' == plotfunc.__name__:
kwargs['infer_intervals'] = infer_intervals
- # This allows the user to pass in a custom norm coming via kwargs
- kwargs.setdefault('norm', cmap_params['norm'])
-
if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring):
# forbid usage of mpl strings
raise ValueError("plt.imshow's `aspect` kwarg is not available "
@@ -873,6 +884,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'],
vmin=cmap_params['vmin'],
vmax=cmap_params['vmax'],
+ norm=cmap_params['norm'],
**kwargs)
# Label the plot with metadata
@@ -890,12 +902,16 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
cbar_kwargs.setdefault('cax', cbar_ax)
cbar = plt.colorbar(primitive, **cbar_kwargs)
if add_labels and 'label' not in cbar_kwargs:
- cbar.set_label(label_from_attrs(darray), rotation=90)
+ cbar.set_label(label_from_attrs(darray))
elif cbar_ax is not None or cbar_kwargs is not None:
# inform the user about keywords which aren't used
raise ValueError("cbar_ax and cbar_kwargs can't be used with "
"add_colorbar=False.")
+ # origin kwarg overrides yincrease
+ if 'origin' in kwargs:
+ yincrease = None
+
_update_axes(ax, xincrease, yincrease, xscale, yscale,
xticks, yticks, xlim, ylim)
@@ -920,7 +936,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
- xlim=None, ylim=None, **kwargs):
+ xlim=None, ylim=None, norm=None, **kwargs):
"""
The method should have the same signature as the function.
@@ -982,10 +998,8 @@ def imshow(x, y, z, ax, **kwargs):
left, right = x[0] - xstep, x[-1] + xstep
bottom, top = y[-1] + ystep, y[0] - ystep
- defaults = {'extent': [left, right, bottom, top],
- 'origin': 'upper',
- 'interpolation': 'nearest',
- }
+ defaults = {'origin': 'upper',
+ 'interpolation': 'nearest'}
if not hasattr(ax, 'projection'):
# not for cartopy geoaxes
@@ -994,6 +1008,11 @@ def imshow(x, y, z, ax, **kwargs):
# Allow user to override these defaults
defaults.update(kwargs)
+ if defaults['origin'] == 'upper':
+ defaults['extent'] = [left, right, bottom, top]
+ else:
+ defaults['extent'] = [left, right, top, bottom]
+
if z.ndim == 3:
# matplotlib imshow uses black for missing data, but Xarray makes
# missing data transparent. We therefore add an alpha channel if
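A sketch of the plotting changes: ``norm`` is now an explicit keyword that is
reconciled with ``vmin``/``vmax``, and ``origin`` overrides ``yincrease`` for
``imshow`` (data are illustrative; matplotlib is required):

    import matplotlib.colors as mcolors
    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.rand(10, 20) + 0.1, dims=('y', 'x'))

    # The norm's vmin/vmax are used; a conflicting explicit vmin/vmax raises.
    da.plot.pcolormesh(norm=mcolors.LogNorm(vmin=0.1, vmax=1.1))

    # origin='lower' now also flips the computed extent.
    da.plot.imshow(origin='lower')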
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index 4a09b66ca33..f39c989a514 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -1,10 +1,11 @@
from __future__ import absolute_import, division, print_function
+import textwrap
import warnings
import numpy as np
-import textwrap
+from ..core.options import OPTIONS
from ..core.pycompat import basestring
from ..core.utils import is_scalar
@@ -171,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
# vlim might be computed below
vlim = None
+ # save state; needed later
+ vmin_was_none = vmin is None
+ vmax_was_none = vmax is None
+
if vmin is None:
if robust:
vmin = np.percentile(calc_data, ROBUST_PERCENTILE)
@@ -203,18 +208,42 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
vmin += center
vmax += center
+ # now check norm and harmonize with vmin, vmax
+ if norm is not None:
+ if norm.vmin is None:
+ norm.vmin = vmin
+ else:
+ if not vmin_was_none and vmin != norm.vmin:
+ raise ValueError('Cannot supply vmin and a norm'
+ + ' with a different vmin.')
+ vmin = norm.vmin
+
+ if norm.vmax is None:
+ norm.vmax = vmax
+ else:
+ if not vmax_was_none and vmax != norm.vmax:
+ raise ValueError('Cannot supply vmax and a norm'
+ + ' with a different vmax.')
+ vmax = norm.vmax
+
+ # if BoundaryNorm, then set levels
+ if isinstance(norm, mpl.colors.BoundaryNorm):
+ levels = norm.boundaries
+
# Choose default colormaps if not provided
if cmap is None:
if divergent:
- cmap = "RdBu_r"
+ cmap = OPTIONS['cmap_divergent']
else:
- cmap = "viridis"
+ cmap = OPTIONS['cmap_sequential']
# Handle discrete levels
- if levels is not None:
+ if levels is not None and norm is None:
if is_scalar(levels):
- if user_minmax or levels == 1:
+ if user_minmax:
levels = np.linspace(vmin, vmax, levels)
+ elif levels == 1:
+ levels = np.asarray([(vmin + vmax) / 2])
else:
# N in MaxNLocator refers to bins, not ticks
ticker = mpl.ticker.MaxNLocator(levels - 1)
@@ -224,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)
- if levels is not None:
- cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled)
+ if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm):
+ cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled)
+ norm = newnorm if norm is None else norm
return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
levels=levels, norm=norm)
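A sketch of the two knobs added here: session-wide default colormaps and
discrete levels derived from a ``BoundaryNorm`` (data are illustrative):

    import matplotlib.colors as mcolors
    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(10, 20), dims=('y', 'x'))

    # Diverging data picks up the configurable default colormap.
    with xr.set_options(cmap_divergent='coolwarm'):
        da.plot()

    # A BoundaryNorm implies its boundaries as the discrete levels.
    da.plot(norm=mcolors.BoundaryNorm([-2, -1, 0, 1, 2], ncolors=256))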
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 33a8da6bbfb..285c1f03a26 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -9,11 +9,10 @@
import numpy as np
from numpy.testing import assert_array_equal # noqa: F401
-from xarray.core.duck_array_ops import allclose_or_equiv
+from xarray.core.duck_array_ops import allclose_or_equiv # noqa
import pytest
from xarray.core import utils
-from xarray.core.pycompat import PY3
from xarray.core.indexing import ExplicitlyIndexed
from xarray.testing import (assert_equal, assert_identical, # noqa: F401
assert_allclose)
@@ -25,10 +24,6 @@
# old location, for pandas < 0.20
from pandas.util.testing import assert_frame_equal # noqa: F401
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
try:
from unittest import mock
@@ -117,39 +112,6 @@ def _importorskip(modname, minversion=None):
"internet connection")
-class TestCase(unittest.TestCase):
- """
- These functions are all deprecated. Instead, use functions in xr.testing
- """
- if PY3:
- # Python 3 assertCountEqual is roughly equivalent to Python 2
- # assertItemsEqual
- def assertItemsEqual(self, first, second, msg=None):
- __tracebackhide__ = True # noqa: F841
- return self.assertCountEqual(first, second, msg)
-
- @contextmanager
- def assertWarns(self, message):
- __tracebackhide__ = True # noqa: F841
- with warnings.catch_warnings(record=True) as w:
- warnings.filterwarnings('always', message)
- yield
- assert len(w) > 0
- assert any(message in str(wi.message) for wi in w)
-
- def assertVariableNotEqual(self, v1, v2):
- __tracebackhide__ = True # noqa: F841
- assert not v1.equals(v2)
-
- def assertEqual(self, a1, a2):
- __tracebackhide__ = True # noqa: F841
- assert a1 == a2 or (a1 != a1 and a2 != a2)
-
- def assertAllClose(self, a1, a2, rtol=1e-05, atol=1e-8):
- __tracebackhide__ = True # noqa: F841
- assert allclose_or_equiv(a1, a2, rtol=rtol, atol=atol)
-
-
@contextmanager
def raises_regex(error, pattern):
__tracebackhide__ = True # noqa: F841
diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py
index e1b3a95b942..38038fc8f65 100644
--- a/xarray/tests/test_accessors.py
+++ b/xarray/tests/test_accessors.py
@@ -7,12 +7,13 @@
import xarray as xr
from . import (
- TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask,
- has_cftime, has_dask, has_cftime_or_netCDF4)
+ assert_array_equal, assert_equal, has_cftime, has_cftime_or_netCDF4,
+ has_dask, raises_regex, requires_dask)
-class TestDatetimeAccessor(TestCase):
- def setUp(self):
+class TestDatetimeAccessor(object):
+ @pytest.fixture(autouse=True)
+ def setup(self):
nt = 100
data = np.random.rand(10, 10, nt)
lons = np.linspace(0, 11, 10)
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index e6de50b9dd2..43811942d5f 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2,12 +2,12 @@
import contextlib
import itertools
+import math
import os.path
import pickle
import shutil
import sys
import tempfile
-import unittest
import warnings
from io import BytesIO
@@ -19,22 +19,21 @@
from xarray import (
DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset,
save_mfdataset)
-from xarray.backends.common import (robust_getitem,
- PickleByReconstructionWrapper)
+from xarray.backends.common import robust_getitem
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.backends.pydap_ import PydapDataStore
from xarray.core import indexing
from xarray.core.pycompat import (
- PY2, ExitStack, basestring, dask_array_type, iteritems)
+ ExitStack, basestring, dask_array_type, iteritems)
+from xarray.core.options import set_options
from xarray.tests import mock
from . import (
- TestCase, assert_allclose, assert_array_equal, assert_equal,
- assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex,
+ assert_allclose, assert_array_equal, assert_equal, assert_identical,
+ has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_cftime,
requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib,
- requires_pydap, requires_pynio, requires_rasterio, requires_scipy,
- requires_scipy_or_netCDF4, requires_zarr, requires_pseudonetcdf,
- requires_cftime)
+ requires_pseudonetcdf, requires_pydap, requires_pynio, requires_rasterio,
+ requires_scipy, requires_scipy_or_netCDF4, requires_zarr)
from .test_dataset import create_test_data
try:
@@ -106,7 +105,7 @@ def create_boolean_data():
return Dataset({'x': ('t', [True, False, False, True], attributes)})
-class TestCommon(TestCase):
+class TestCommon(object):
def test_robust_getitem(self):
class UnreliableArrayFailure(Exception):
@@ -126,11 +125,11 @@ def __getitem__(self, key):
array = UnreliableArray([0])
with pytest.raises(UnreliableArrayFailure):
array[0]
- self.assertEqual(array[0], 0)
+ assert array[0] == 0
actual = robust_getitem(array, 0, catch=UnreliableArrayFailure,
initial_delay=0)
- self.assertEqual(actual, 0)
+ assert actual == 0
class NetCDF3Only(object):
@@ -138,7 +137,6 @@ class NetCDF3Only(object):
class DatasetIOTestCases(object):
- autoclose = False
engine = None
file_format = None
@@ -172,8 +170,7 @@ def save(self, dataset, path, **kwargs):
@contextlib.contextmanager
def open(self, path, **kwargs):
- with open_dataset(path, engine=self.engine, autoclose=self.autoclose,
- **kwargs) as ds:
+ with open_dataset(path, engine=self.engine, **kwargs) as ds:
yield ds
def test_zero_dimensional_variable(self):
@@ -222,11 +219,11 @@ def assert_loads(vars=None):
with self.roundtrip(expected) as actual:
for k, v in actual.variables.items():
# IndexVariables are eagerly loaded into memory
- self.assertEqual(v._in_memory, k in actual.dims)
+ assert v._in_memory == (k in actual.dims)
yield actual
for k, v in actual.variables.items():
if k in vars:
- self.assertTrue(v._in_memory)
+ assert v._in_memory
assert_identical(expected, actual)
with pytest.raises(AssertionError):
@@ -252,14 +249,14 @@ def test_dataset_compute(self):
# Test Dataset.compute()
for k, v in actual.variables.items():
# IndexVariables are eagerly cached
- self.assertEqual(v._in_memory, k in actual.dims)
+ assert v._in_memory == (k in actual.dims)
computed = actual.compute()
for k, v in actual.variables.items():
- self.assertEqual(v._in_memory, k in actual.dims)
+ assert v._in_memory == (k in actual.dims)
for v in computed.variables.values():
- self.assertTrue(v._in_memory)
+ assert v._in_memory
assert_identical(expected, actual)
assert_identical(expected, computed)
@@ -343,12 +340,12 @@ def test_roundtrip_string_encoded_characters(self):
expected['x'].encoding['dtype'] = 'S1'
with self.roundtrip(expected) as actual:
assert_identical(expected, actual)
- self.assertEqual(actual['x'].encoding['_Encoding'], 'utf-8')
+ assert actual['x'].encoding['_Encoding'] == 'utf-8'
expected['x'].encoding['_Encoding'] = 'ascii'
with self.roundtrip(expected) as actual:
assert_identical(expected, actual)
- self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii')
+ assert actual['x'].encoding['_Encoding'] == 'ascii'
def test_roundtrip_numpy_datetime_data(self):
times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
@@ -434,10 +431,10 @@ def test_roundtrip_coordinates_with_space(self):
def test_roundtrip_boolean_dtype(self):
original = create_boolean_data()
- self.assertEqual(original['x'].dtype, 'bool')
+ assert original['x'].dtype == 'bool'
with self.roundtrip(original) as actual:
assert_identical(original, actual)
- self.assertEqual(actual['x'].dtype, 'bool')
+ assert actual['x'].dtype == 'bool'
def test_orthogonal_indexing(self):
in_memory = create_test_data()
@@ -626,20 +623,20 @@ def test_unsigned_roundtrip_mask_and_scale(self):
encoded = create_encoded_unsigned_masked_scaled_data()
with self.roundtrip(decoded) as actual:
for k in decoded.variables:
- self.assertEqual(decoded.variables[k].dtype,
- actual.variables[k].dtype)
+ assert (decoded.variables[k].dtype ==
+ actual.variables[k].dtype)
assert_allclose(decoded, actual, decode_bytes=False)
with self.roundtrip(decoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
- self.assertEqual(encoded.variables[k].dtype,
- actual.variables[k].dtype)
+ assert (encoded.variables[k].dtype ==
+ actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)
with self.roundtrip(encoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
- self.assertEqual(encoded.variables[k].dtype,
- actual.variables[k].dtype)
+ assert (encoded.variables[k].dtype ==
+ actual.variables[k].dtype)
assert_allclose(encoded, actual, decode_bytes=False)
# make sure roundtrip encoding didn't change the
# original dataset.
@@ -647,14 +644,14 @@ def test_unsigned_roundtrip_mask_and_scale(self):
encoded, create_encoded_unsigned_masked_scaled_data())
with self.roundtrip(encoded) as actual:
for k in decoded.variables:
- self.assertEqual(decoded.variables[k].dtype,
- actual.variables[k].dtype)
+ assert decoded.variables[k].dtype == \
+ actual.variables[k].dtype
assert_allclose(decoded, actual, decode_bytes=False)
with self.roundtrip(encoded,
open_kwargs=dict(decode_cf=False)) as actual:
for k in encoded.variables:
- self.assertEqual(encoded.variables[k].dtype,
- actual.variables[k].dtype)
+ assert encoded.variables[k].dtype == \
+ actual.variables[k].dtype
assert_allclose(encoded, actual, decode_bytes=False)
def test_roundtrip_mask_and_scale(self):
@@ -692,12 +689,11 @@ def equals_latlon(obj):
with create_tmp_file() as tmp_file:
original.to_netcdf(tmp_file)
with open_dataset(tmp_file, decode_coords=False) as ds:
- self.assertTrue(equals_latlon(ds['temp'].attrs['coordinates']))
- self.assertTrue(
- equals_latlon(ds['precip'].attrs['coordinates']))
- self.assertNotIn('coordinates', ds.attrs)
- self.assertNotIn('coordinates', ds['lat'].attrs)
- self.assertNotIn('coordinates', ds['lon'].attrs)
+ assert equals_latlon(ds['temp'].attrs['coordinates'])
+ assert equals_latlon(ds['precip'].attrs['coordinates'])
+ assert 'coordinates' not in ds.attrs
+ assert 'coordinates' not in ds['lat'].attrs
+ assert 'coordinates' not in ds['lon'].attrs
modified = original.drop(['temp', 'precip'])
with self.roundtrip(modified) as actual:
@@ -705,9 +701,9 @@ def equals_latlon(obj):
with create_tmp_file() as tmp_file:
modified.to_netcdf(tmp_file)
with open_dataset(tmp_file, decode_coords=False) as ds:
- self.assertTrue(equals_latlon(ds.attrs['coordinates']))
- self.assertNotIn('coordinates', ds['lat'].attrs)
- self.assertNotIn('coordinates', ds['lon'].attrs)
+ assert equals_latlon(ds.attrs['coordinates'])
+ assert 'coordinates' not in ds['lat'].attrs
+ assert 'coordinates' not in ds['lon'].attrs
def test_roundtrip_endian(self):
ds = Dataset({'x': np.arange(3, 10, dtype='>i2'),
@@ -743,8 +739,8 @@ def test_encoding_kwarg(self):
ds = Dataset({'x': ('y', np.arange(10.0))})
kwargs = dict(encoding={'x': {'dtype': 'f4'}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.x.encoding['dtype'], 'f4')
- self.assertEqual(ds.x.encoding, {})
+ assert actual.x.encoding['dtype'] == 'f4'
+ assert ds.x.encoding == {}
kwargs = dict(encoding={'x': {'foo': 'bar'}})
with raises_regex(ValueError, 'unexpected encoding'):
@@ -766,7 +762,7 @@ def test_encoding_kwarg_dates(self):
units = 'days since 1900-01-01'
kwargs = dict(encoding={'t': {'units': units}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.t.encoding['units'], units)
+ assert actual.t.encoding['units'] == units
assert_identical(actual, ds)
def test_encoding_kwarg_fixed_width_string(self):
@@ -778,7 +774,7 @@ def test_encoding_kwarg_fixed_width_string(self):
ds = Dataset({'x': strings})
kwargs = dict(encoding={'x': {'dtype': 'S1'}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual['x'].encoding['dtype'], 'S1')
+ assert actual['x'].encoding['dtype'] == 'S1'
assert_identical(actual, ds)
def test_default_fill_value(self):
@@ -786,9 +782,8 @@ def test_default_fill_value(self):
ds = Dataset({'x': ('y', np.arange(10.0))})
kwargs = dict(encoding={'x': {'dtype': 'f4'}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.x.encoding['_FillValue'],
- np.nan)
- self.assertEqual(ds.x.encoding, {})
+ assert math.isnan(actual.x.encoding['_FillValue'])
+ assert ds.x.encoding == {}
# Test default encoding for int:
ds = Dataset({'x': ('y', np.arange(10.0))})
@@ -797,14 +792,14 @@ def test_default_fill_value(self):
warnings.filterwarnings(
'ignore', '.*floating point data as an integer')
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertTrue('_FillValue' not in actual.x.encoding)
- self.assertEqual(ds.x.encoding, {})
+ assert '_FillValue' not in actual.x.encoding
+ assert ds.x.encoding == {}
# Test default encoding for implicit int:
ds = Dataset({'x': ('y', np.arange(10, dtype='int16'))})
with self.roundtrip(ds) as actual:
- self.assertTrue('_FillValue' not in actual.x.encoding)
- self.assertEqual(ds.x.encoding, {})
+ assert '_FillValue' not in actual.x.encoding
+ assert ds.x.encoding == {}
def test_explicitly_omit_fill_value(self):
ds = Dataset({'x': ('y', [np.pi, -np.pi])})
@@ -817,7 +812,7 @@ def test_explicitly_omit_fill_value_via_encoding_kwarg(self):
kwargs = dict(encoding={'x': {'_FillValue': None}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
assert '_FillValue' not in actual.x.encoding
- self.assertEqual(ds.y.encoding, {})
+ assert ds.y.encoding == {}
def test_explicitly_omit_fill_value_in_coord(self):
ds = Dataset({'x': ('y', [np.pi, -np.pi])}, coords={'y': [0.0, 1.0]})
@@ -830,14 +825,14 @@ def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self):
kwargs = dict(encoding={'y': {'_FillValue': None}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
assert '_FillValue' not in actual.y.encoding
- self.assertEqual(ds.y.encoding, {})
+ assert ds.y.encoding == {}
def test_encoding_same_dtype(self):
ds = Dataset({'x': ('y', np.arange(10.0, dtype='f4'))})
kwargs = dict(encoding={'x': {'dtype': 'f4'}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.x.encoding['dtype'], 'f4')
- self.assertEqual(ds.x.encoding, {})
+ assert actual.x.encoding['dtype'] == 'f4'
+ assert ds.x.encoding == {}
def test_append_write(self):
# regression for GH1215
@@ -1015,7 +1010,7 @@ def test_default_to_char_arrays(self):
data = Dataset({'x': np.array(['foo', 'zzzz'], dtype='S')})
with self.roundtrip(data) as actual:
assert_identical(data, actual)
- self.assertEqual(actual['x'].dtype, np.dtype('S4'))
+ assert actual['x'].dtype == np.dtype('S4')
def test_open_encodings(self):
# Create a netCDF file with explicit time units
@@ -1040,15 +1035,15 @@ def test_open_encodings(self):
actual_encoding = dict((k, v) for k, v in
iteritems(actual['time'].encoding)
if k in expected['time'].encoding)
- self.assertDictEqual(actual_encoding,
- expected['time'].encoding)
+ assert actual_encoding == \
+ expected['time'].encoding
def test_dump_encodings(self):
# regression test for #709
ds = Dataset({'x': ('y', np.arange(10.0))})
kwargs = dict(encoding={'x': {'zlib': True}})
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertTrue(actual.x.encoding['zlib'])
+ assert actual.x.encoding['zlib']
def test_dump_and_open_encodings(self):
# Create a netCDF file with explicit time units
@@ -1066,8 +1061,7 @@ def test_dump_and_open_encodings(self):
with create_tmp_file() as tmp_file2:
xarray_dataset.to_netcdf(tmp_file2)
with nc4.Dataset(tmp_file2, 'r') as ds:
- self.assertEqual(
- ds.variables['time'].getncattr('units'), units)
+ assert ds.variables['time'].getncattr('units') == units
assert_array_equal(
ds.variables['time'], np.arange(10) + 4)
@@ -1080,7 +1074,7 @@ def test_compression_encoding(self):
'original_shape': data.var2.shape})
with self.roundtrip(data) as actual:
for k, v in iteritems(data['var2'].encoding):
- self.assertEqual(v, actual['var2'].encoding[k])
+ assert v == actual['var2'].encoding[k]
# regression test for #156
expected = data.isel(dim1=0)
@@ -1095,14 +1089,14 @@ def test_encoding_kwarg_compression(self):
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
assert_equal(actual, ds)
- self.assertEqual(actual.x.encoding['dtype'], 'f4')
- self.assertEqual(actual.x.encoding['zlib'], True)
- self.assertEqual(actual.x.encoding['complevel'], 9)
- self.assertEqual(actual.x.encoding['fletcher32'], True)
- self.assertEqual(actual.x.encoding['chunksizes'], (5,))
- self.assertEqual(actual.x.encoding['shuffle'], True)
+ assert actual.x.encoding['dtype'] == 'f4'
+ assert actual.x.encoding['zlib']
+ assert actual.x.encoding['complevel'] == 9
+ assert actual.x.encoding['fletcher32']
+ assert actual.x.encoding['chunksizes'] == (5,)
+ assert actual.x.encoding['shuffle']
- self.assertEqual(ds.x.encoding, {})
+ assert ds.x.encoding == {}
def test_encoding_chunksizes_unlimited(self):
# regression test for GH1225
@@ -1162,10 +1156,10 @@ def test_already_open_dataset(self):
v[...] = 42
nc = nc4.Dataset(tmp_file, mode='r')
- with backends.NetCDF4DataStore(nc, autoclose=False) as store:
- with open_dataset(store) as ds:
- expected = Dataset({'x': ((), 42)})
- assert_identical(expected, ds)
+ store = backends.NetCDF4DataStore(nc)
+ with open_dataset(store) as ds:
+ expected = Dataset({'x': ((), 42)})
+ assert_identical(expected, ds)
def test_read_variable_len_strings(self):
with create_tmp_file() as tmp_file:
@@ -1183,8 +1177,7 @@ def test_read_variable_len_strings(self):
@requires_netCDF4
-class NetCDF4DataTest(BaseNetCDF4Test, TestCase):
- autoclose = False
+class NetCDF4DataTest(BaseNetCDF4Test):
@contextlib.contextmanager
def create_store(self):
@@ -1201,7 +1194,7 @@ def test_variable_order(self):
ds.coords['c'] = 4
with self.roundtrip(ds) as actual:
- self.assertEqual(list(ds.variables), list(actual.variables))
+ assert list(ds.variables) == list(actual.variables)
def test_unsorted_index_raises(self):
# should be fixed in netcdf4 v1.2.1
@@ -1220,7 +1213,7 @@ def test_unsorted_index_raises(self):
try:
ds2.randovar.values
except IndexError as err:
- self.assertIn('first by calling .load', str(err))
+ assert 'first by calling .load' in str(err)
def test_88_character_filename_segmentation_fault(self):
# should be fixed in netcdf4 v1.3.1
@@ -1250,9 +1243,13 @@ def test_setncattr_string(self):
totest.attrs['bar'])
assert one_string == totest.attrs['baz']
-
-class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest):
- autoclose = True
+ def test_autoclose_future_warning(self):
+ data = create_test_data()
+ with create_tmp_file() as tmp_file:
+ self.save(data, tmp_file)
+ with pytest.warns(FutureWarning):
+ with self.open(tmp_file, autoclose=True) as actual:
+ assert_identical(data, actual)
@requires_netCDF4
@@ -1293,10 +1290,6 @@ def test_write_inconsistent_chunks(self):
assert actual['y'].encoding['chunksizes'] == (100, 50)
-class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest):
- autoclose = True
-
-
@requires_zarr
class BaseZarrTest(CFEncodedDataTest):
@@ -1335,17 +1328,17 @@ def test_auto_chunk(self):
original, open_kwargs={'auto_chunk': False}) as actual:
for k, v in actual.variables.items():
# only index variables should be in memory
- self.assertEqual(v._in_memory, k in actual.dims)
+ assert v._in_memory == (k in actual.dims)
# there should be no chunks
- self.assertEqual(v.chunks, None)
+ assert v.chunks is None
with self.roundtrip(
original, open_kwargs={'auto_chunk': True}) as actual:
for k, v in actual.variables.items():
# only index variables should be in memory
- self.assertEqual(v._in_memory, k in actual.dims)
+ assert v._in_memory == (k in actual.dims)
# chunk size should be the same as original
- self.assertEqual(v.chunks, original[k].chunks)
+ assert v.chunks == original[k].chunks
def test_write_uneven_dask_chunks(self):
# regression for GH#2225
@@ -1365,7 +1358,7 @@ def test_chunk_encoding(self):
data['var2'].encoding.update({'chunks': chunks})
with self.roundtrip(data) as actual:
- self.assertEqual(chunks, actual['var2'].encoding['chunks'])
+ assert chunks == actual['var2'].encoding['chunks']
# expect an error with non-integer chunks
data['var2'].encoding.update({'chunks': (5, 4.5)})
@@ -1382,7 +1375,7 @@ def test_chunk_encoding_with_dask(self):
# zarr automatically gets chunk information from dask chunks
ds_chunk4 = ds.chunk({'x': 4})
with self.roundtrip(ds_chunk4) as actual:
- self.assertEqual((4,), actual['var1'].encoding['chunks'])
+ assert (4,) == actual['var1'].encoding['chunks']
# should fail if dask_chunks are irregular...
ds_chunk_irreg = ds.chunk({'x': (5, 4, 3)})
@@ -1395,15 +1388,14 @@ def test_chunk_encoding_with_dask(self):
# ... except if the last chunk is smaller than the first
ds_chunk_irreg = ds.chunk({'x': (5, 5, 2)})
with self.roundtrip(ds_chunk_irreg) as actual:
- self.assertEqual((5,), actual['var1'].encoding['chunks'])
+ assert (5,) == actual['var1'].encoding['chunks']
# - encoding specified -
# specify compatible encodings
for chunk_enc in 4, (4, ):
ds_chunk4['var1'].encoding.update({'chunks': chunk_enc})
with self.roundtrip(ds_chunk4) as actual:
- self.assertEqual((4,), actual['var1'].encoding['chunks'])
-
+ assert (4,) == actual['var1'].encoding['chunks']
    # TODO: remove this failure once synchronized overlapping writes are
# supported by xarray
@@ -1533,14 +1525,14 @@ def test_encoding_chunksizes(self):
@requires_zarr
-class ZarrDictStoreTest(BaseZarrTest, TestCase):
+class ZarrDictStoreTest(BaseZarrTest):
@contextlib.contextmanager
def create_zarr_target(self):
yield {}
@requires_zarr
-class ZarrDirectoryStoreTest(BaseZarrTest, TestCase):
+class ZarrDirectoryStoreTest(BaseZarrTest):
@contextlib.contextmanager
def create_zarr_target(self):
with create_tmp_file(suffix='.zarr') as tmp:
@@ -1563,7 +1555,7 @@ def test_append_overwrite_values(self):
@requires_scipy
-class ScipyInMemoryDataTest(ScipyWriteTest, TestCase):
+class ScipyInMemoryDataTest(ScipyWriteTest):
engine = 'scipy'
@contextlib.contextmanager
@@ -1575,21 +1567,16 @@ def test_to_netcdf_explicit_engine(self):
# regression test for GH1321
Dataset({'foo': 42}).to_netcdf(engine='scipy')
- @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2')
- def test_bytesio_pickle(self):
+ def test_bytes_pickle(self):
data = Dataset({'foo': ('x', [1, 2, 3])})
- fobj = BytesIO(data.to_netcdf())
- with open_dataset(fobj, autoclose=self.autoclose) as ds:
+ fobj = data.to_netcdf()
+ with self.open(fobj) as ds:
unpickled = pickle.loads(pickle.dumps(ds))
assert_identical(unpickled, data)
-class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest):
- autoclose = True
-
-
@requires_scipy
-class ScipyFileObjectTest(ScipyWriteTest, TestCase):
+class ScipyFileObjectTest(ScipyWriteTest):
engine = 'scipy'
@contextlib.contextmanager
@@ -1617,7 +1604,7 @@ def test_pickle_dataarray(self):
@requires_scipy
-class ScipyFilePathTest(ScipyWriteTest, TestCase):
+class ScipyFilePathTest(ScipyWriteTest):
engine = 'scipy'
@contextlib.contextmanager
@@ -1641,7 +1628,7 @@ def test_netcdf3_endianness(self):
# regression test for GH416
expected = open_example_dataset('bears.nc', engine='scipy')
for var in expected.variables.values():
- self.assertTrue(var.dtype.isnative)
+ assert var.dtype.isnative
@requires_netCDF4
def test_nc4_scipy(self):
@@ -1653,12 +1640,8 @@ def test_nc4_scipy(self):
open_dataset(tmp_file, engine='scipy')
-class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest):
- autoclose = True
-
-
@requires_netCDF4
-class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase):
+class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only):
engine = 'netcdf4'
file_format = 'NETCDF3_CLASSIC'
@@ -1677,13 +1660,9 @@ def test_encoding_kwarg_vlen_string(self):
pass
-class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
- autoclose = True
-
-
@requires_netCDF4
class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only,
- TestCase):
+ object):
engine = 'netcdf4'
file_format = 'NETCDF4_CLASSIC'
@@ -1695,13 +1674,8 @@ def create_store(self):
yield store
-class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
- NetCDF4ClassicViaNetCDF4DataTest):
- autoclose = True
-
-
@requires_scipy_or_netCDF4
-class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase):
+class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only):
# verify that we can read and write netCDF3 files as long as we have scipy
# or netCDF4-python installed
file_format = 'netcdf3_64bit'
@@ -1755,34 +1729,30 @@ def test_encoding_unlimited_dims(self):
ds = Dataset({'x': ('y', np.arange(10.0))})
with self.roundtrip(ds,
save_kwargs=dict(unlimited_dims=['y'])) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
# Regression test for https://github.com/pydata/xarray/issues/2134
with self.roundtrip(ds,
save_kwargs=dict(unlimited_dims='y')) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
ds.encoding = {'unlimited_dims': ['y']}
with self.roundtrip(ds) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
# Regression test for https://github.com/pydata/xarray/issues/2134
ds.encoding = {'unlimited_dims': 'y'}
with self.roundtrip(ds) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
-class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest):
- autoclose = True
-
-
@requires_h5netcdf
@requires_netCDF4
-class H5NetCDFDataTest(BaseNetCDF4Test, TestCase):
+class H5NetCDFDataTest(BaseNetCDF4Test):
engine = 'h5netcdf'
@contextlib.contextmanager
@@ -1790,10 +1760,14 @@ def create_store(self):
with create_tmp_file() as tmp_file:
yield backends.H5NetCDFStore(tmp_file, 'w')
+ @pytest.mark.filterwarnings('ignore:complex dtypes are supported by h5py')
def test_complex(self):
expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))})
- with self.roundtrip(expected) as actual:
- assert_equal(expected, actual)
+ with pytest.warns(FutureWarning):
+ # TODO: make it possible to write invalid netCDF files from xarray
+ # without a warning
+ with self.roundtrip(expected) as actual:
+ assert_equal(expected, actual)
@pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535')
def test_cross_engine_read_write_netcdf4(self):
@@ -1822,11 +1796,11 @@ def test_encoding_unlimited_dims(self):
ds = Dataset({'x': ('y', np.arange(10.0))})
with self.roundtrip(ds,
save_kwargs=dict(unlimited_dims=['y'])) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
ds.encoding = {'unlimited_dims': ['y']}
with self.roundtrip(ds) as actual:
- self.assertEqual(actual.encoding['unlimited_dims'], set('y'))
+ assert actual.encoding['unlimited_dims'] == set('y')
assert_equal(ds, actual)
def test_compression_encoding_h5py(self):
@@ -1857,7 +1831,7 @@ def test_compression_encoding_h5py(self):
compr_out.update(compr_common)
with self.roundtrip(data) as actual:
for k, v in compr_out.items():
- self.assertEqual(v, actual['var2'].encoding[k])
+ assert v == actual['var2'].encoding[k]
def test_compression_check_encoding_h5py(self):
"""When mismatched h5py and NetCDF4-Python encodings are expressed
@@ -1898,20 +1872,14 @@ def test_dump_encodings_h5py(self):
kwargs = {'encoding': {'x': {
'compression': 'gzip', 'compression_opts': 9}}}
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.x.encoding['zlib'], True)
- self.assertEqual(actual.x.encoding['complevel'], 9)
+ assert actual.x.encoding['zlib']
+ assert actual.x.encoding['complevel'] == 9
kwargs = {'encoding': {'x': {
'compression': 'lzf', 'compression_opts': None}}}
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
- self.assertEqual(actual.x.encoding['compression'], 'lzf')
- self.assertEqual(actual.x.encoding['compression_opts'], None)
-
-
-# tests pending h5netcdf fix
-@unittest.skip
-class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest):
- autoclose = True
+ assert actual.x.encoding['compression'] == 'lzf'
+ assert actual.x.encoding['compression_opts'] is None
@pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio'])
@@ -1919,14 +1887,19 @@ def readengine(request):
return request.param
-@pytest.fixture(params=[1, 100])
+@pytest.fixture(params=[1, 20])
def nfiles(request):
return request.param
-@pytest.fixture(params=[True, False])
-def autoclose(request):
- return request.param
+@pytest.fixture(params=[5, None])
+def file_cache_maxsize(request):
+ maxsize = request.param
+ if maxsize is not None:
+ with set_options(file_cache_maxsize=maxsize):
+ yield maxsize
+ else:
+ yield maxsize
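
The `file_cache_maxsize` fixture above exercises the option that supersedes the removed `autoclose` argument. As a rough sketch of the same idea from the user side (the file paths below are hypothetical), the cap can be applied through `set_options`:

    import xarray as xr

    # Hypothetical input files; any collection of netCDF paths works the same way.
    paths = ['air_2000.nc', 'air_2001.nc']

    # Keep at most 16 files open at once; handles beyond the limit are
    # closed and transparently reopened by the global file cache.
    with xr.set_options(file_cache_maxsize=16):
        ds = xr.open_mfdataset(paths)
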
@pytest.fixture(params=[True, False])
@@ -1949,8 +1922,8 @@ def skip_if_not_engine(engine):
pytest.importorskip(engine)
-def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel,
- chunks):
+def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks,
+ file_cache_maxsize):
# skip certain combinations
skip_if_not_engine(readengine)
@@ -1958,9 +1931,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel,
if not has_dask and parallel:
pytest.skip('parallel requires dask')
- if readengine == 'h5netcdf' and autoclose:
- pytest.skip('h5netcdf does not support autoclose yet')
-
if ON_WINDOWS:
pytest.skip('Skipping on Windows')
@@ -1976,7 +1946,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel,
# check that calculation on opened datasets works properly
actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel,
- autoclose=autoclose, chunks=chunks)
+ chunks=chunks)
# check that using open_mfdataset returns dask arrays for variables
assert isinstance(actual['foo'].data, dask_array_type)
@@ -1985,7 +1955,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel,
@requires_scipy_or_netCDF4
-class OpenMFDatasetWithDataVarsAndCoordsKwTest(TestCase):
+class OpenMFDatasetWithDataVarsAndCoordsKwTest(object):
coord_name = 'lon'
var_name = 'v1'
@@ -2056,9 +2026,9 @@ def test_common_coord_when_datavars_all(self):
var_shape = ds[self.var_name].shape
- self.assertEqual(var_shape, coord_shape)
- self.assertNotEqual(coord_shape1, coord_shape)
- self.assertNotEqual(coord_shape2, coord_shape)
+ assert var_shape == coord_shape
+ assert coord_shape1 != coord_shape
+ assert coord_shape2 != coord_shape
def test_common_coord_when_datavars_minimal(self):
opt = 'minimal'
@@ -2073,9 +2043,9 @@ def test_common_coord_when_datavars_minimal(self):
var_shape = ds[self.var_name].shape
- self.assertNotEqual(var_shape, coord_shape)
- self.assertEqual(coord_shape1, coord_shape)
- self.assertEqual(coord_shape2, coord_shape)
+ assert var_shape != coord_shape
+ assert coord_shape1 == coord_shape
+ assert coord_shape2 == coord_shape
def test_invalid_data_vars_value_should_fail(self):
@@ -2093,7 +2063,7 @@ def test_invalid_data_vars_value_should_fail(self):
@requires_dask
@requires_scipy
@requires_netCDF4
-class DaskTest(TestCase, DatasetIOTestCases):
+class DaskTest(DatasetIOTestCases):
@contextlib.contextmanager
def create_store(self):
yield Dataset()
@@ -2133,10 +2103,10 @@ def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self):
with xr.set_options(enable_cftimeindex=True):
with self.roundtrip(expected) as actual:
abs_diff = abs(actual.t.values - expected_decoded_t)
- self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all())
+ assert (abs_diff <= np.timedelta64(1, 's')).all()
abs_diff = abs(actual.t0.values - expected_decoded_t0)
- self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all())
+ assert (abs_diff <= np.timedelta64(1, 's')).all()
def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self):
# Override method in DatasetIOTestCases - remove not applicable
@@ -2153,10 +2123,10 @@ def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self):
with xr.set_options(enable_cftimeindex=False):
with self.roundtrip(expected) as actual:
abs_diff = abs(actual.t.values - expected_decoded_t)
- self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all())
+ assert (abs_diff <= np.timedelta64(1, 's')).all()
abs_diff = abs(actual.t0.values - expected_decoded_t0)
- self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all())
+ assert (abs_diff <= np.timedelta64(1, 's')).all()
def test_write_store(self):
# Override method in DatasetIOTestCases - not applicable to dask
@@ -2175,22 +2145,20 @@ def test_open_mfdataset(self):
with create_tmp_file() as tmp2:
original.isel(x=slice(5)).to_netcdf(tmp1)
original.isel(x=slice(5, 10)).to_netcdf(tmp2)
- with open_mfdataset([tmp1, tmp2],
- autoclose=self.autoclose) as actual:
- self.assertIsInstance(actual.foo.variable.data, da.Array)
- self.assertEqual(actual.foo.variable.data.chunks,
- ((5, 5),))
+ with open_mfdataset([tmp1, tmp2]) as actual:
+ assert isinstance(actual.foo.variable.data, da.Array)
+ assert actual.foo.variable.data.chunks == \
+ ((5, 5),)
assert_identical(original, actual)
- with open_mfdataset([tmp1, tmp2], chunks={'x': 3},
- autoclose=self.autoclose) as actual:
- self.assertEqual(actual.foo.variable.data.chunks,
- ((3, 2, 3, 2),))
+ with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual:
+ assert actual.foo.variable.data.chunks == \
+ ((3, 2, 3, 2),)
with raises_regex(IOError, 'no files to open'):
- open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose)
+ open_mfdataset('foo-bar-baz-*.nc')
with raises_regex(ValueError, 'wild-card'):
- open_mfdataset('http://some/remote/uri', autoclose=self.autoclose)
+ open_mfdataset('http://some/remote/uri')
@requires_pathlib
def test_open_mfdataset_pathlib(self):
@@ -2201,8 +2169,7 @@ def test_open_mfdataset_pathlib(self):
tmp2 = Path(tmp2)
original.isel(x=slice(5)).to_netcdf(tmp1)
original.isel(x=slice(5, 10)).to_netcdf(tmp2)
- with open_mfdataset([tmp1, tmp2],
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp1, tmp2]) as actual:
assert_identical(original, actual)
def test_attrs_mfdataset(self):
@@ -2218,7 +2185,7 @@ def test_attrs_mfdataset(self):
with open_mfdataset([tmp1, tmp2]) as actual:
# presumes that attributes inherited from
# first dataset loaded
- self.assertEqual(actual.test1, ds1.test1)
+ assert actual.test1 == ds1.test1
# attributes from ds2 are not retained, e.g.,
with raises_regex(AttributeError,
'no attribute'):
@@ -2233,8 +2200,7 @@ def preprocess(ds):
return ds.assign_coords(z=0)
expected = preprocess(original)
- with open_mfdataset(tmp, preprocess=preprocess,
- autoclose=self.autoclose) as actual:
+ with open_mfdataset(tmp, preprocess=preprocess) as actual:
assert_identical(expected, actual)
def test_save_mfdataset_roundtrip(self):
@@ -2244,8 +2210,7 @@ def test_save_mfdataset_roundtrip(self):
with create_tmp_file() as tmp1:
with create_tmp_file() as tmp2:
save_mfdataset(datasets, [tmp1, tmp2])
- with open_mfdataset([tmp1, tmp2],
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp1, tmp2]) as actual:
assert_identical(actual, original)
def test_save_mfdataset_invalid(self):
@@ -2271,15 +2236,14 @@ def test_save_mfdataset_pathlib_roundtrip(self):
tmp1 = Path(tmp1)
tmp2 = Path(tmp2)
save_mfdataset(datasets, [tmp1, tmp2])
- with open_mfdataset([tmp1, tmp2],
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp1, tmp2]) as actual:
assert_identical(actual, original)
def test_open_and_do_math(self):
original = Dataset({'foo': ('x', np.random.randn(10))})
with create_tmp_file() as tmp:
original.to_netcdf(tmp)
- with open_mfdataset(tmp, autoclose=self.autoclose) as ds:
+ with open_mfdataset(tmp) as ds:
actual = 1.0 * ds
assert_allclose(original, actual, decode_bytes=False)
@@ -2289,8 +2253,7 @@ def test_open_mfdataset_concat_dim_none(self):
data = Dataset({'x': 0})
data.to_netcdf(tmp1)
Dataset({'x': np.nan}).to_netcdf(tmp2)
- with open_mfdataset([tmp1, tmp2], concat_dim=None,
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual:
assert_identical(data, actual)
def test_open_dataset(self):
@@ -2298,13 +2261,13 @@ def test_open_dataset(self):
with create_tmp_file() as tmp:
original.to_netcdf(tmp)
with open_dataset(tmp, chunks={'x': 5}) as actual:
- self.assertIsInstance(actual.foo.variable.data, da.Array)
- self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),))
+ assert isinstance(actual.foo.variable.data, da.Array)
+ assert actual.foo.variable.data.chunks == ((5, 5),)
assert_identical(original, actual)
with open_dataset(tmp, chunks=5) as actual:
assert_identical(original, actual)
with open_dataset(tmp) as actual:
- self.assertIsInstance(actual.foo.variable.data, np.ndarray)
+ assert isinstance(actual.foo.variable.data, np.ndarray)
assert_identical(original, actual)
def test_open_single_dataset(self):
@@ -2317,8 +2280,7 @@ def test_open_single_dataset(self):
{'baz': [100]})
with create_tmp_file() as tmp:
original.to_netcdf(tmp)
- with open_mfdataset([tmp], concat_dim=dim,
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp], concat_dim=dim) as actual:
assert_identical(expected, actual)
def test_dask_roundtrip(self):
@@ -2337,16 +2299,16 @@ def test_deterministic_names(self):
with create_tmp_file() as tmp:
data = create_test_data()
data.to_netcdf(tmp)
- with open_mfdataset(tmp, autoclose=self.autoclose) as ds:
+ with open_mfdataset(tmp) as ds:
original_names = dict((k, v.data.name)
for k, v in ds.data_vars.items())
- with open_mfdataset(tmp, autoclose=self.autoclose) as ds:
+ with open_mfdataset(tmp) as ds:
repeat_names = dict((k, v.data.name)
for k, v in ds.data_vars.items())
for var_name, dask_name in original_names.items():
- self.assertIn(var_name, dask_name)
- self.assertEqual(dask_name[:13], 'open_dataset-')
- self.assertEqual(original_names, repeat_names)
+ assert var_name in dask_name
+ assert dask_name[:13] == 'open_dataset-'
+ assert original_names == repeat_names
def test_dataarray_compute(self):
# Test DataArray.compute() on dask backend.
@@ -2354,48 +2316,29 @@ def test_dataarray_compute(self):
# however dask is the only tested backend which supports DataArrays
actual = DataArray([1, 2]).chunk()
computed = actual.compute()
- self.assertFalse(actual._in_memory)
- self.assertTrue(computed._in_memory)
+ assert not actual._in_memory
+ assert computed._in_memory
assert_allclose(actual, computed, decode_bytes=False)
- def test_to_netcdf_compute_false_roundtrip(self):
- from dask.delayed import Delayed
-
- original = create_test_data().chunk()
-
- with create_tmp_file() as tmp_file:
- # dataset, path, **kwargs):
- delayed_obj = self.save(original, tmp_file, compute=False)
- assert isinstance(delayed_obj, Delayed)
- delayed_obj.compute()
-
- with self.open(tmp_file) as actual:
- assert_identical(original, actual)
-
def test_save_mfdataset_compute_false_roundtrip(self):
from dask.delayed import Delayed
original = Dataset({'foo': ('x', np.random.randn(10))}).chunk()
datasets = [original.isel(x=slice(5)),
original.isel(x=slice(5, 10))]
- with create_tmp_file() as tmp1:
- with create_tmp_file() as tmp2:
+ with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1:
+ with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2:
delayed_obj = save_mfdataset(datasets, [tmp1, tmp2],
engine=self.engine, compute=False)
assert isinstance(delayed_obj, Delayed)
delayed_obj.compute()
- with open_mfdataset([tmp1, tmp2],
- autoclose=self.autoclose) as actual:
+ with open_mfdataset([tmp1, tmp2]) as actual:
assert_identical(actual, original)
-class DaskTestAutocloseTrue(DaskTest):
- autoclose = True
-
-
@requires_scipy_or_netCDF4
@requires_pydap
-class PydapTest(TestCase):
+class PydapTest(object):
def convert_to_pydap_dataset(self, original):
from pydap.model import GridType, BaseType, DatasetType
ds = DatasetType('bears', **original.attrs)
@@ -2427,8 +2370,8 @@ def test_cmp_local_file(self):
assert_equal(actual, expected)
# global attributes should be global attributes on the dataset
- self.assertNotIn('NC_GLOBAL', actual.attrs)
- self.assertIn('history', actual.attrs)
+ assert 'NC_GLOBAL' not in actual.attrs
+ assert 'history' in actual.attrs
# we don't check attributes exactly with assertDatasetIdentical()
# because the test DAP server seems to insert some extra
@@ -2436,8 +2379,7 @@ def test_cmp_local_file(self):
assert actual.attrs.keys() == expected.attrs.keys()
with self.create_datasets() as (actual, expected):
- assert_equal(
- actual.isel(l=2), expected.isel(l=2)) # noqa: E741
+ assert_equal(actual.isel(l=2), expected.isel(l=2)) # noqa
with self.create_datasets() as (actual, expected):
assert_equal(actual.isel(i=0, j=-1),
@@ -2497,15 +2439,14 @@ def test_session(self):
@requires_scipy
@requires_pynio
-class PyNioTest(ScipyWriteTest, TestCase):
+class PyNioTest(ScipyWriteTest):
def test_write_store(self):
# pynio is read-only for now
pass
@contextlib.contextmanager
def open(self, path, **kwargs):
- with open_dataset(path, engine='pynio', autoclose=self.autoclose,
- **kwargs) as ds:
+ with open_dataset(path, engine='pynio', **kwargs) as ds:
yield ds
def save(self, dataset, path, **kwargs):
@@ -2523,18 +2464,12 @@ def test_weakrefs(self):
assert_identical(actual, expected)
-class PyNioTestAutocloseTrue(PyNioTest):
- autoclose = True
-
-
@requires_pseudonetcdf
-class PseudoNetCDFFormatTest(TestCase):
- autoclose = True
+@pytest.mark.filterwarnings('ignore:IOAPI_ISPH is assumed to be 6370000')
+class PseudoNetCDFFormatTest(object):
def open(self, path, **kwargs):
- return open_dataset(path, engine='pseudonetcdf',
- autoclose=self.autoclose,
- **kwargs)
+ return open_dataset(path, engine='pseudonetcdf', **kwargs)
@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
@@ -2551,7 +2486,6 @@ def test_ict_format(self):
"""
ictfile = open_example_dataset('example.ict',
engine='pseudonetcdf',
- autoclose=False,
backend_kwargs={'format': 'ffi1001'})
stdattr = {
'fill_value': -9999.0,
@@ -2649,7 +2583,6 @@ def test_ict_format_write(self):
fmtkw = {'format': 'ffi1001'}
expected = open_example_dataset('example.ict',
engine='pseudonetcdf',
- autoclose=False,
backend_kwargs=fmtkw)
with self.roundtrip(expected, save_kwargs=fmtkw,
open_kwargs={'backend_kwargs': fmtkw}) as actual:
@@ -2659,14 +2592,10 @@ def test_uamiv_format_read(self):
"""
Open a CAMx file and test data variables
"""
- with warnings.catch_warnings():
- warnings.filterwarnings('ignore', category=UserWarning,
- message=('IOAPI_ISPH is assumed to be ' +
- '6370000.; consistent with WRF'))
- camxfile = open_example_dataset('example.uamiv',
- engine='pseudonetcdf',
- autoclose=True,
- backend_kwargs={'format': 'uamiv'})
+
+ camxfile = open_example_dataset('example.uamiv',
+ engine='pseudonetcdf',
+ backend_kwargs={'format': 'uamiv'})
data = np.arange(20, dtype='f').reshape(1, 1, 4, 5)
expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data,
dict(units='ppm', long_name='O3'.ljust(16),
@@ -2688,17 +2617,13 @@ def test_uamiv_format_mfread(self):
"""
Open a CAMx file and test data variables
"""
- with warnings.catch_warnings():
- warnings.filterwarnings('ignore', category=UserWarning,
- message=('IOAPI_ISPH is assumed to be ' +
- '6370000.; consistent with WRF'))
- camxfile = open_example_mfdataset(
- ['example.uamiv',
- 'example.uamiv'],
- engine='pseudonetcdf',
- autoclose=True,
- concat_dim='TSTEP',
- backend_kwargs={'format': 'uamiv'})
+
+ camxfile = open_example_mfdataset(
+ ['example.uamiv',
+ 'example.uamiv'],
+ engine='pseudonetcdf',
+ concat_dim='TSTEP',
+ backend_kwargs={'format': 'uamiv'})
data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5)
data = np.concatenate([data1] * 2, axis=0)
@@ -2710,30 +2635,28 @@ def test_uamiv_format_mfread(self):
data1 = np.array(['2002-06-03'], 'datetime64[ns]')
data = np.concatenate([data1] * 2, axis=0)
- expected = xr.Variable(('TSTEP',), data,
- dict(bounds='time_bounds',
- long_name=('synthesized time coordinate ' +
- 'from SDATE, STIME, STEP ' +
- 'global attributes')))
+ attrs = dict(bounds='time_bounds',
+ long_name=('synthesized time coordinate ' +
+ 'from SDATE, STIME, STEP ' +
+ 'global attributes'))
+ expected = xr.Variable(('TSTEP',), data, attrs)
actual = camxfile.variables['time']
assert_allclose(expected, actual)
camxfile.close()
def test_uamiv_format_write(self):
fmtkw = {'format': 'uamiv'}
- with warnings.catch_warnings():
- warnings.filterwarnings('ignore', category=UserWarning,
- message=('IOAPI_ISPH is assumed to be ' +
- '6370000.; consistent with WRF'))
- expected = open_example_dataset('example.uamiv',
- engine='pseudonetcdf',
- autoclose=False,
- backend_kwargs=fmtkw)
+
+ expected = open_example_dataset('example.uamiv',
+ engine='pseudonetcdf',
+ backend_kwargs=fmtkw)
with self.roundtrip(expected,
save_kwargs=fmtkw,
open_kwargs={'backend_kwargs': fmtkw}) as actual:
assert_identical(expected, actual)
+ expected.close()
+
def save(self, dataset, path, **save_kwargs):
import PseudoNetCDF as pnc
pncf = pnc.PseudoNetCDFFile()
@@ -2798,7 +2721,7 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3,
@requires_rasterio
-class TestRasterio(TestCase):
+class TestRasterio(object):
@requires_scipy_or_netCDF4
def test_serialization(self):
@@ -2843,7 +2766,8 @@ def test_non_rectilinear(self):
assert len(rioda.attrs['transform']) == 6
# See if a warning is raised if we force it
- with self.assertWarns("transformation isn't rectilinear"):
+ with pytest.warns(Warning,
+ match="transformation isn't rectilinear"):
with xr.open_rasterio(tmp_file,
parse_coordinates=True) as rioda:
assert 'x' not in rioda.coords
@@ -2934,6 +2858,10 @@ def test_indexing(self):
assert_allclose(expected.isel(**ind), actual.isel(**ind))
assert not actual.variable._in_memory
+ ind = {'band': 0, 'x': np.array([0, 0]), 'y': np.array([1, 1, 1])}
+ assert_allclose(expected.isel(**ind), actual.isel(**ind))
+ assert not actual.variable._in_memory
+
# minus-stepped slice
ind = {'band': np.array([2, 1, 0]),
'x': slice(-1, None, -1), 'y': 0}
@@ -3030,7 +2958,7 @@ def test_chunks(self):
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:
import dask.array as da
- self.assertIsInstance(actual.data, da.Array)
+ assert isinstance(actual.data, da.Array)
assert 'open_rasterio' in actual.data.name
# do some arithmetic
@@ -3111,7 +3039,7 @@ def test_no_mftime(self):
with mock.patch('os.path.getmtime', side_effect=OSError):
with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual:
import dask.array as da
- self.assertIsInstance(actual.data, da.Array)
+ assert isinstance(actual.data, da.Array)
assert_allclose(actual, expected)
@network
@@ -3124,10 +3052,10 @@ def test_http_url(self):
# make sure chunking works
with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual:
import dask.array as da
- self.assertIsInstance(actual.data, da.Array)
+ assert isinstance(actual.data, da.Array)
-class TestEncodingInvalid(TestCase):
+class TestEncodingInvalid(object):
def test_extract_nc4_variable_encoding(self):
var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'})
@@ -3136,12 +3064,12 @@ def test_extract_nc4_variable_encoding(self):
var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)})
encoding = _extract_nc4_variable_encoding(var)
- self.assertEqual({}, encoding)
+ assert {} == encoding
# regression test
var = xr.Variable(('x',), [1, 2, 3], {}, {'shuffle': True})
encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True)
- self.assertEqual({'shuffle': True}, encoding)
+ assert {'shuffle': True} == encoding
def test_extract_h5nc_encoding(self):
# not supported with h5netcdf (yet)
@@ -3156,7 +3084,7 @@ class MiscObject:
@requires_netCDF4
-class TestValidateAttrs(TestCase):
+class TestValidateAttrs(object):
def test_validating_attrs(self):
def new_dataset():
return Dataset({'data': ('y', np.arange(10.0))},
@@ -3256,7 +3184,7 @@ def new_dataset_and_coord_attrs():
@requires_scipy_or_netCDF4
-class TestDataArrayToNetCDF(TestCase):
+class TestDataArrayToNetCDF(object):
def test_dataarray_to_netcdf_no_name(self):
original_da = DataArray(np.arange(12).reshape((3, 4)))
@@ -3317,32 +3245,6 @@ def test_dataarray_to_netcdf_no_name_pathlib(self):
assert_identical(original_da, loaded_da)
-def test_pickle_reconstructor():
-
- lines = ['foo bar spam eggs']
-
- with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp:
- with open(tmp, 'w') as f:
- f.writelines(lines)
-
- obj = PickleByReconstructionWrapper(open, tmp)
-
- assert obj.value.readlines() == lines
-
- p_obj = pickle.dumps(obj)
- obj.value.close() # for windows
- obj2 = pickle.loads(p_obj)
-
- assert obj2.value.readlines() == lines
-
- # roundtrip again to make sure we can fully restore the state
- p_obj2 = pickle.dumps(obj2)
- obj2.value.close() # for windows
- obj3 = pickle.loads(p_obj2)
-
- assert obj3.value.readlines() == lines
-
-
@requires_scipy_or_netCDF4
def test_no_warning_from_dask_effective_get():
with create_tmp_file() as tmpfile:
diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py
new file mode 100644
index 00000000000..591c981cd45
--- /dev/null
+++ b/xarray/tests/test_backends_file_manager.py
@@ -0,0 +1,114 @@
+import pickle
+import threading
+try:
+ from unittest import mock
+except ImportError:
+ import mock # noqa: F401
+
+import pytest
+
+from xarray.backends.file_manager import CachingFileManager
+from xarray.backends.lru_cache import LRUCache
+
+
+@pytest.fixture(params=[1, 2, 3, None])
+def file_cache(request):
+ maxsize = request.param
+ if maxsize is None:
+ yield {}
+ else:
+ yield LRUCache(maxsize)
+
+
+def test_file_manager_mock_write(file_cache):
+ mock_file = mock.Mock()
+ opener = mock.Mock(spec=open, return_value=mock_file)
+ lock = mock.MagicMock(spec=threading.Lock())
+
+ manager = CachingFileManager(
+ opener, 'filename', lock=lock, cache=file_cache)
+ f = manager.acquire()
+ f.write('contents')
+ manager.close()
+
+ assert not file_cache
+ opener.assert_called_once_with('filename')
+ mock_file.write.assert_called_once_with('contents')
+ mock_file.close.assert_called_once_with()
+ lock.__enter__.assert_has_calls([mock.call(), mock.call()])
+
+
+def test_file_manager_write_consecutive(tmpdir, file_cache):
+ path1 = str(tmpdir.join('testing1.txt'))
+ path2 = str(tmpdir.join('testing2.txt'))
+ manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache)
+ manager2 = CachingFileManager(open, path2, mode='w', cache=file_cache)
+ f1a = manager1.acquire()
+ f1a.write('foo')
+ f1a.flush()
+ f2 = manager2.acquire()
+ f2.write('bar')
+ f2.flush()
+ f1b = manager1.acquire()
+ f1b.write('baz')
+ assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b)
+ manager1.close()
+ manager2.close()
+
+ with open(path1, 'r') as f:
+ assert f.read() == 'foobaz'
+ with open(path2, 'r') as f:
+ assert f.read() == 'bar'
+
+
+def test_file_manager_write_concurrent(tmpdir, file_cache):
+ path = str(tmpdir.join('testing.txt'))
+ manager = CachingFileManager(open, path, mode='w', cache=file_cache)
+ f1 = manager.acquire()
+ f2 = manager.acquire()
+ f3 = manager.acquire()
+ assert f1 is f2
+ assert f2 is f3
+ f1.write('foo')
+ f1.flush()
+ f2.write('bar')
+ f2.flush()
+ f3.write('baz')
+ f3.flush()
+ manager.close()
+
+ with open(path, 'r') as f:
+ assert f.read() == 'foobarbaz'
+
+
+def test_file_manager_write_pickle(tmpdir, file_cache):
+ path = str(tmpdir.join('testing.txt'))
+ manager = CachingFileManager(open, path, mode='w', cache=file_cache)
+ f = manager.acquire()
+ f.write('foo')
+ f.flush()
+ manager2 = pickle.loads(pickle.dumps(manager))
+ f2 = manager2.acquire()
+ f2.write('bar')
+ manager2.close()
+ manager.close()
+
+ with open(path, 'r') as f:
+ assert f.read() == 'foobar'
+
+
+def test_file_manager_read(tmpdir, file_cache):
+ path = str(tmpdir.join('testing.txt'))
+
+ with open(path, 'w') as f:
+ f.write('foobar')
+
+ manager = CachingFileManager(open, path, cache=file_cache)
+ f = manager.acquire()
+ assert f.read() == 'foobar'
+ manager.close()
+
+
+def test_file_manager_invalid_kwargs():
+ with pytest.raises(TypeError):
+ CachingFileManager(open, 'dummy', mode='w', invalid=True)
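
Taken together, these tests pin down the small surface of `CachingFileManager`: construct it with an opener and its arguments, `acquire()` a (possibly re-opened) file object, and `close()` when finished. A minimal sketch of that cycle, using a hypothetical path:

    from xarray.backends.file_manager import CachingFileManager

    manager = CachingFileManager(open, 'example.txt', mode='w')
    f = manager.acquire()   # opens the file, or returns the cached handle
    f.write('hello')
    manager.close()         # closes the underlying handle if it is still open
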
diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py
new file mode 100644
index 00000000000..5f83321802e
--- /dev/null
+++ b/xarray/tests/test_backends_locks.py
@@ -0,0 +1,13 @@
+import threading
+
+from xarray.backends import locks
+
+
+def test_threaded_lock():
+ lock1 = locks._get_threaded_lock('foo')
+ assert isinstance(lock1, type(threading.Lock()))
+ lock2 = locks._get_threaded_lock('foo')
+ assert lock1 is lock2
+
+ lock3 = locks._get_threaded_lock('bar')
+ assert lock1 is not lock3
diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py
new file mode 100644
index 00000000000..03eb6dcf208
--- /dev/null
+++ b/xarray/tests/test_backends_lru_cache.py
@@ -0,0 +1,91 @@
+try:
+ from unittest import mock
+except ImportError:
+ import mock # noqa: F401
+
+import pytest
+
+from xarray.backends.lru_cache import LRUCache
+
+
+def test_simple():
+ cache = LRUCache(maxsize=2)
+ cache['x'] = 1
+ cache['y'] = 2
+
+ assert cache['x'] == 1
+ assert cache['y'] == 2
+ assert len(cache) == 2
+ assert dict(cache) == {'x': 1, 'y': 2}
+ assert list(cache.keys()) == ['x', 'y']
+ assert list(cache.items()) == [('x', 1), ('y', 2)]
+
+ cache['z'] = 3
+ assert len(cache) == 2
+ assert list(cache.items()) == [('y', 2), ('z', 3)]
+
+
+def test_trivial():
+ cache = LRUCache(maxsize=0)
+ cache['x'] = 1
+ assert len(cache) == 0
+
+
+def test_invalid():
+ with pytest.raises(TypeError):
+ LRUCache(maxsize=None)
+ with pytest.raises(ValueError):
+ LRUCache(maxsize=-1)
+
+
+def test_update_priority():
+ cache = LRUCache(maxsize=2)
+ cache['x'] = 1
+ cache['y'] = 2
+ assert list(cache) == ['x', 'y']
+ assert 'x' in cache # contains
+ assert list(cache) == ['y', 'x']
+ assert cache['y'] == 2 # getitem
+ assert list(cache) == ['x', 'y']
+ cache['x'] = 3 # setitem
+ assert list(cache.items()) == [('y', 2), ('x', 3)]
+
+
+def test_del():
+ cache = LRUCache(maxsize=2)
+ cache['x'] = 1
+ cache['y'] = 2
+ del cache['x']
+ assert dict(cache) == {'y': 2}
+
+
+def test_on_evict():
+ on_evict = mock.Mock()
+ cache = LRUCache(maxsize=1, on_evict=on_evict)
+ cache['x'] = 1
+ cache['y'] = 2
+ on_evict.assert_called_once_with('x', 1)
+
+
+def test_on_evict_trivial():
+ on_evict = mock.Mock()
+ cache = LRUCache(maxsize=0, on_evict=on_evict)
+ cache['x'] = 1
+ on_evict.assert_called_once_with('x', 1)
+
+
+def test_resize():
+ cache = LRUCache(maxsize=2)
+ assert cache.maxsize == 2
+ cache['w'] = 0
+ cache['x'] = 1
+ cache['y'] = 2
+ assert list(cache.items()) == [('x', 1), ('y', 2)]
+ cache.maxsize = 10
+ cache['z'] = 3
+ assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)]
+ cache.maxsize = 1
+ assert list(cache.items()) == [('z', 3)]
+
+ with pytest.raises(ValueError):
+ cache.maxsize = -1
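
The `on_evict` hook tested above is what lets the backends close file handles as they drop out of the cache. A minimal sketch of that pattern, with hypothetical file names:

    from xarray.backends.lru_cache import LRUCache

    cache = LRUCache(maxsize=2, on_evict=lambda key, value: value.close())
    cache['a'] = open('a.txt', 'w')
    cache['b'] = open('b.txt', 'w')
    cache['c'] = open('c.txt', 'w')  # evicts 'a' and closes its handle
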
diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py
new file mode 100644
index 00000000000..7acd764cab3
--- /dev/null
+++ b/xarray/tests/test_cftime_offsets.py
@@ -0,0 +1,799 @@
+from itertools import product
+
+import numpy as np
+import pytest
+
+from xarray import CFTimeIndex
+from xarray.coding.cftime_offsets import (
+ _MONTH_ABBREVIATIONS, BaseCFTimeOffset, Day, Hour, Minute, MonthBegin,
+ MonthEnd, Second, YearBegin, YearEnd, _days_in_month, cftime_range,
+ get_date_type, to_cftime_datetime, to_offset)
+
+cftime = pytest.importorskip('cftime')
+
+
+_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap',
+ '366_day', 'gregorian', 'proleptic_gregorian', 'standard']
+
+
+def _id_func(param):
+ """Called on each parameter passed to pytest.mark.parametrize"""
+ return str(param)
+
+
+@pytest.fixture(params=_CFTIME_CALENDARS)
+def calendar(request):
+ return request.param
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected_n'),
+ [(BaseCFTimeOffset(), 1),
+ (YearBegin(), 1),
+ (YearEnd(), 1),
+ (BaseCFTimeOffset(n=2), 2),
+ (YearBegin(n=2), 2),
+ (YearEnd(n=2), 2)],
+ ids=_id_func
+)
+def test_cftime_offset_constructor_valid_n(offset, expected_n):
+ assert offset.n == expected_n
+
+
+@pytest.mark.parametrize(
+ ('offset', 'invalid_n'),
+ [(BaseCFTimeOffset, 1.5),
+ (YearBegin, 1.5),
+ (YearEnd, 1.5)],
+ ids=_id_func
+)
+def test_cftime_offset_constructor_invalid_n(offset, invalid_n):
+ with pytest.raises(TypeError):
+ offset(n=invalid_n)
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected_month'),
+ [(YearBegin(), 1),
+ (YearEnd(), 12),
+ (YearBegin(month=5), 5),
+ (YearEnd(month=5), 5)],
+ ids=_id_func
+)
+def test_year_offset_constructor_valid_month(offset, expected_month):
+ assert offset.month == expected_month
+
+
+@pytest.mark.parametrize(
+ ('offset', 'invalid_month', 'exception'),
+ [(YearBegin, 0, ValueError),
+ (YearEnd, 0, ValueError),
+     (YearBegin, 13, ValueError),
+ (YearEnd, 13, ValueError),
+ (YearBegin, 1.5, TypeError),
+ (YearEnd, 1.5, TypeError)],
+ ids=_id_func
+)
+def test_year_offset_constructor_invalid_month(
+ offset, invalid_month, exception):
+ with pytest.raises(exception):
+ offset(month=invalid_month)
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected'),
+ [(BaseCFTimeOffset(), None),
+ (MonthBegin(), 'MS'),
+ (YearBegin(), 'AS-JAN')],
+ ids=_id_func
+)
+def test_rule_code(offset, expected):
+ assert offset.rule_code() == expected
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected'),
+ [(BaseCFTimeOffset(), ''),
+ (YearBegin(), '')],
+ ids=_id_func
+)
+def test_str_and_repr(offset, expected):
+ assert str(offset) == expected
+ assert repr(offset) == expected
+
+
+@pytest.mark.parametrize(
+ 'offset',
+ [BaseCFTimeOffset(), MonthBegin(), YearBegin()],
+ ids=_id_func
+)
+def test_to_offset_offset_input(offset):
+ assert to_offset(offset) == offset
+
+
+@pytest.mark.parametrize(
+ ('freq', 'expected'),
+ [('M', MonthEnd()),
+ ('2M', MonthEnd(n=2)),
+ ('MS', MonthBegin()),
+ ('2MS', MonthBegin(n=2)),
+ ('D', Day()),
+ ('2D', Day(n=2)),
+ ('H', Hour()),
+ ('2H', Hour(n=2)),
+ ('T', Minute()),
+ ('2T', Minute(n=2)),
+ ('min', Minute()),
+ ('2min', Minute(n=2)),
+ ('S', Second()),
+ ('2S', Second(n=2))],
+ ids=_id_func
+)
+def test_to_offset_sub_annual(freq, expected):
+ assert to_offset(freq) == expected
+
+
+_ANNUAL_OFFSET_TYPES = {
+ 'A': YearEnd,
+ 'AS': YearBegin
+}
+
+
+@pytest.mark.parametrize(('month_int', 'month_label'),
+ list(_MONTH_ABBREVIATIONS.items()) + [('', '')])
+@pytest.mark.parametrize('multiple', [None, 2])
+@pytest.mark.parametrize('offset_str', ['AS', 'A'])
+def test_to_offset_annual(month_label, month_int, multiple, offset_str):
+ freq = offset_str
+ offset_type = _ANNUAL_OFFSET_TYPES[offset_str]
+ if month_label:
+ freq = '-'.join([freq, month_label])
+ if multiple:
+ freq = '{}'.format(multiple) + freq
+ result = to_offset(freq)
+
+ if multiple and month_int:
+ expected = offset_type(n=multiple, month=month_int)
+ elif multiple:
+ expected = offset_type(n=multiple)
+ elif month_int:
+ expected = offset_type(month=month_int)
+ else:
+ expected = offset_type()
+ assert result == expected
+
+
+@pytest.mark.parametrize('freq', ['Z', '7min2', 'AM', 'M-', 'AS-', '1H1min'])
+def test_invalid_to_offset_str(freq):
+ with pytest.raises(ValueError):
+ to_offset(freq)
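
As the parametrized cases above assert, `to_offset` parses pandas-style frequency strings into cftime offset objects, for example:

    from xarray.coding.cftime_offsets import MonthEnd, YearBegin, to_offset

    assert to_offset('2M') == MonthEnd(n=2)
    assert to_offset('AS-JUL') == YearBegin(month=7)
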
+
+
+@pytest.mark.parametrize(
+ ('argument', 'expected_date_args'),
+ [('2000-01-01', (2000, 1, 1)),
+ ((2000, 1, 1), (2000, 1, 1))],
+ ids=_id_func
+)
+def test_to_cftime_datetime(calendar, argument, expected_date_args):
+ date_type = get_date_type(calendar)
+ expected = date_type(*expected_date_args)
+ if isinstance(argument, tuple):
+ argument = date_type(*argument)
+ result = to_cftime_datetime(argument, calendar=calendar)
+ assert result == expected
+
+
+def test_to_cftime_datetime_error_no_calendar():
+ with pytest.raises(ValueError):
+ to_cftime_datetime('2000')
+
+
+def test_to_cftime_datetime_error_type_error():
+ with pytest.raises(TypeError):
+ to_cftime_datetime(1)
+
+
+_EQ_TESTS_A = [
+ BaseCFTimeOffset(), YearBegin(), YearEnd(), YearBegin(month=2),
+ YearEnd(month=2), MonthBegin(), MonthEnd(), Day(), Hour(), Minute(),
+ Second()
+]
+_EQ_TESTS_B = [
+ BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2),
+ YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2),
+ MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2)
+]
+
+
+@pytest.mark.parametrize(
+ ('a', 'b'), product(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func
+)
+def test_neq(a, b):
+ assert a != b
+
+
+_EQ_TESTS_B_COPY = [
+ BaseCFTimeOffset(n=2), YearBegin(n=2), YearEnd(n=2),
+ YearBegin(n=2, month=2), YearEnd(n=2, month=2), MonthBegin(n=2),
+ MonthEnd(n=2), Day(n=2), Hour(n=2), Minute(n=2), Second(n=2)
+]
+
+
+@pytest.mark.parametrize(
+ ('a', 'b'), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func
+)
+def test_eq(a, b):
+ assert a == b
+
+
+_MUL_TESTS = [
+ (BaseCFTimeOffset(), BaseCFTimeOffset(n=3)),
+ (YearEnd(), YearEnd(n=3)),
+ (YearBegin(), YearBegin(n=3)),
+ (MonthEnd(), MonthEnd(n=3)),
+ (MonthBegin(), MonthBegin(n=3)),
+ (Day(), Day(n=3)),
+ (Hour(), Hour(n=3)),
+ (Minute(), Minute(n=3)),
+ (Second(), Second(n=3))
+]
+
+
+@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func)
+def test_mul(offset, expected):
+ assert offset * 3 == expected
+
+
+@pytest.mark.parametrize(('offset', 'expected'), _MUL_TESTS, ids=_id_func)
+def test_rmul(offset, expected):
+ assert 3 * offset == expected
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected'),
+ [(BaseCFTimeOffset(), BaseCFTimeOffset(n=-1)),
+ (YearEnd(), YearEnd(n=-1)),
+ (YearBegin(), YearBegin(n=-1)),
+ (MonthEnd(), MonthEnd(n=-1)),
+ (MonthBegin(), MonthBegin(n=-1)),
+ (Day(), Day(n=-1)),
+ (Hour(), Hour(n=-1)),
+ (Minute(), Minute(n=-1)),
+ (Second(), Second(n=-1))],
+ ids=_id_func)
+def test_neg(offset, expected):
+ assert -offset == expected
+
+
+_ADD_TESTS = [
+ (Day(n=2), (1, 1, 3)),
+ (Hour(n=2), (1, 1, 1, 2)),
+ (Minute(n=2), (1, 1, 1, 0, 2)),
+ (Second(n=2), (1, 1, 1, 0, 0, 2))
+]
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected_date_args'),
+ _ADD_TESTS,
+ ids=_id_func
+)
+def test_add_sub_monthly(offset, expected_date_args, calendar):
+ date_type = get_date_type(calendar)
+ initial = date_type(1, 1, 1)
+ expected = date_type(*expected_date_args)
+ result = offset + initial
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected_date_args'),
+ _ADD_TESTS,
+ ids=_id_func
+)
+def test_radd_sub_monthly(offset, expected_date_args, calendar):
+ date_type = get_date_type(calendar)
+ initial = date_type(1, 1, 1)
+ expected = date_type(*expected_date_args)
+ result = initial + offset
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('offset', 'expected_date_args'),
+ [(Day(n=2), (1, 1, 1)),
+ (Hour(n=2), (1, 1, 2, 22)),
+ (Minute(n=2), (1, 1, 2, 23, 58)),
+ (Second(n=2), (1, 1, 2, 23, 59, 58))],
+ ids=_id_func
+)
+def test_rsub_sub_monthly(offset, expected_date_args, calendar):
+ date_type = get_date_type(calendar)
+ initial = date_type(1, 1, 3)
+ expected = date_type(*expected_date_args)
+ result = initial - offset
+ assert result == expected
+
+
+@pytest.mark.parametrize('offset', _EQ_TESTS_A, ids=_id_func)
+def test_sub_error(offset, calendar):
+ date_type = get_date_type(calendar)
+ initial = date_type(1, 1, 1)
+ with pytest.raises(TypeError):
+ offset - initial
+
+
+@pytest.mark.parametrize(
+ ('a', 'b'),
+ zip(_EQ_TESTS_A, _EQ_TESTS_B),
+ ids=_id_func
+)
+def test_minus_offset(a, b):
+ result = b - a
+ expected = a
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('a', 'b'),
+ list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) +
+ [(YearEnd(month=1), YearEnd(month=2))],
+ ids=_id_func
+)
+def test_minus_offset_error(a, b):
+ with pytest.raises(TypeError):
+ b - a
+
+
+def test_days_in_month_non_december(calendar):
+ date_type = get_date_type(calendar)
+ reference = date_type(1, 4, 1)
+ assert _days_in_month(reference) == 30
+
+
+def test_days_in_month_december(calendar):
+ if calendar == '360_day':
+ expected = 30
+ else:
+ expected = 31
+ date_type = get_date_type(calendar)
+ reference = date_type(1, 12, 5)
+ assert _days_in_month(reference) == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_date_args', 'offset', 'expected_date_args'),
+ [((1, 1, 1), MonthBegin(), (1, 2, 1)),
+ ((1, 1, 1), MonthBegin(n=2), (1, 3, 1)),
+ ((1, 1, 7), MonthBegin(), (1, 2, 1)),
+ ((1, 1, 7), MonthBegin(n=2), (1, 3, 1)),
+ ((1, 3, 1), MonthBegin(n=-1), (1, 2, 1)),
+ ((1, 3, 1), MonthBegin(n=-2), (1, 1, 1)),
+ ((1, 3, 3), MonthBegin(n=-1), (1, 3, 1)),
+ ((1, 3, 3), MonthBegin(n=-2), (1, 2, 1)),
+ ((1, 2, 1), MonthBegin(n=14), (2, 4, 1)),
+ ((2, 4, 1), MonthBegin(n=-14), (1, 2, 1)),
+ ((1, 1, 1, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)),
+ ((1, 1, 3, 5, 5, 5, 5), MonthBegin(), (1, 2, 1, 5, 5, 5, 5)),
+ ((1, 1, 3, 5, 5, 5, 5), MonthBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_month_begin(
+ calendar, initial_date_args, offset, expected_date_args):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_date_args', 'offset', 'expected_year_month',
+ 'expected_sub_day'),
+ [((1, 1, 1), MonthEnd(), (1, 1), ()),
+ ((1, 1, 1), MonthEnd(n=2), (1, 2), ()),
+ ((1, 3, 1), MonthEnd(n=-1), (1, 2), ()),
+ ((1, 3, 1), MonthEnd(n=-2), (1, 1), ()),
+ ((1, 2, 1), MonthEnd(n=14), (2, 3), ()),
+ ((2, 4, 1), MonthEnd(n=-14), (1, 2), ()),
+ ((1, 1, 1, 5, 5, 5, 5), MonthEnd(), (1, 1), (5, 5, 5, 5)),
+ ((1, 2, 1, 5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_month_end(
+ calendar, initial_date_args, offset, expected_year_month,
+ expected_sub_day
+):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ reference_args = expected_year_month + (1,)
+ reference = date_type(*reference_args)
+
+    # The last day of the expected month varies based on the calendar used
+ expected_date_args = (expected_year_month +
+ (_days_in_month(reference),) + expected_sub_day)
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month',
+ 'expected_sub_day'),
+ [((1, 1), (), MonthEnd(), (1, 2), ()),
+ ((1, 1), (), MonthEnd(n=2), (1, 3), ()),
+ ((1, 3), (), MonthEnd(n=-1), (1, 2), ()),
+ ((1, 3), (), MonthEnd(n=-2), (1, 1), ()),
+ ((1, 2), (), MonthEnd(n=14), (2, 4), ()),
+ ((2, 4), (), MonthEnd(n=-14), (1, 2), ()),
+ ((1, 1), (5, 5, 5, 5), MonthEnd(), (1, 2), (5, 5, 5, 5)),
+ ((1, 2), (5, 5, 5, 5), MonthEnd(n=-1), (1, 1), (5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_month_end_onOffset(
+ calendar, initial_year_month, initial_sub_day, offset, expected_year_month,
+ expected_sub_day
+):
+ date_type = get_date_type(calendar)
+ reference_args = initial_year_month + (1,)
+ reference = date_type(*reference_args)
+ initial_date_args = (initial_year_month + (_days_in_month(reference),) +
+ initial_sub_day)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ reference_args = expected_year_month + (1,)
+ reference = date_type(*reference_args)
+
+    # The last day of the expected month varies based on the calendar used
+ expected_date_args = (expected_year_month +
+ (_days_in_month(reference),) + expected_sub_day)
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_date_args', 'offset', 'expected_date_args'),
+ [((1, 1, 1), YearBegin(), (2, 1, 1)),
+ ((1, 1, 1), YearBegin(n=2), (3, 1, 1)),
+ ((1, 1, 1), YearBegin(month=2), (1, 2, 1)),
+ ((1, 1, 7), YearBegin(n=2), (3, 1, 1)),
+ ((2, 2, 1), YearBegin(n=-1), (2, 1, 1)),
+ ((1, 1, 2), YearBegin(n=-1), (1, 1, 1)),
+ ((1, 1, 1, 5, 5, 5, 5), YearBegin(), (2, 1, 1, 5, 5, 5, 5)),
+ ((2, 1, 1, 5, 5, 5, 5), YearBegin(n=-1), (1, 1, 1, 5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_year_begin(calendar, initial_date_args, offset,
+ expected_date_args):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_date_args', 'offset', 'expected_year_month',
+ 'expected_sub_day'),
+ [((1, 1, 1), YearEnd(), (1, 12), ()),
+ ((1, 1, 1), YearEnd(n=2), (2, 12), ()),
+ ((1, 1, 1), YearEnd(month=1), (1, 1), ()),
+ ((2, 3, 1), YearEnd(n=-1), (1, 12), ()),
+ ((1, 3, 1), YearEnd(n=-1, month=2), (1, 2), ()),
+ ((1, 1, 1, 5, 5, 5, 5), YearEnd(), (1, 12), (5, 5, 5, 5)),
+ ((1, 1, 1, 5, 5, 5, 5), YearEnd(n=2), (2, 12), (5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_year_end(
+ calendar, initial_date_args, offset, expected_year_month,
+ expected_sub_day
+):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ reference_args = expected_year_month + (1,)
+ reference = date_type(*reference_args)
+
+    # The last day of the expected month varies based on the calendar used
+ expected_date_args = (expected_year_month +
+ (_days_in_month(reference),) + expected_sub_day)
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('initial_year_month', 'initial_sub_day', 'offset', 'expected_year_month',
+ 'expected_sub_day'),
+ [((1, 12), (), YearEnd(), (2, 12), ()),
+ ((1, 12), (), YearEnd(n=2), (3, 12), ()),
+ ((2, 12), (), YearEnd(n=-1), (1, 12), ()),
+ ((3, 12), (), YearEnd(n=-2), (1, 12), ()),
+ ((1, 1), (), YearEnd(month=2), (1, 2), ()),
+ ((1, 12), (5, 5, 5, 5), YearEnd(), (2, 12), (5, 5, 5, 5)),
+ ((2, 12), (5, 5, 5, 5), YearEnd(n=-1), (1, 12), (5, 5, 5, 5))],
+ ids=_id_func
+)
+def test_add_year_end_onOffset(
+ calendar, initial_year_month, initial_sub_day, offset, expected_year_month,
+ expected_sub_day
+):
+ date_type = get_date_type(calendar)
+ reference_args = initial_year_month + (1,)
+ reference = date_type(*reference_args)
+ initial_date_args = (initial_year_month + (_days_in_month(reference),) +
+ initial_sub_day)
+ initial = date_type(*initial_date_args)
+ result = initial + offset
+ reference_args = expected_year_month + (1,)
+ reference = date_type(*reference_args)
+
+    # The last day of the expected month varies based on the calendar used
+ expected_date_args = (expected_year_month +
+ (_days_in_month(reference),) + expected_sub_day)
+ expected = date_type(*expected_date_args)
+ assert result == expected
+
+
+# Note: for all sub-monthly offsets, pandas always returns True for onOffset
+@pytest.mark.parametrize(
+ ('date_args', 'offset', 'expected'),
+ [((1, 1, 1), MonthBegin(), True),
+ ((1, 1, 1, 1), MonthBegin(), True),
+ ((1, 1, 5), MonthBegin(), False),
+ ((1, 1, 5), MonthEnd(), False),
+ ((1, 1, 1), YearBegin(), True),
+ ((1, 1, 1, 1), YearBegin(), True),
+ ((1, 1, 5), YearBegin(), False),
+ ((1, 12, 1), YearEnd(), False),
+ ((1, 1, 1), Day(), True),
+ ((1, 1, 1, 1), Day(), True),
+ ((1, 1, 1), Hour(), True),
+ ((1, 1, 1), Minute(), True),
+ ((1, 1, 1), Second(), True)],
+ ids=_id_func
+)
+def test_onOffset(calendar, date_args, offset, expected):
+ date_type = get_date_type(calendar)
+ date = date_type(*date_args)
+ result = offset.onOffset(date)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('year_month_args', 'sub_day_args', 'offset'),
+ [((1, 1), (), MonthEnd()),
+ ((1, 1), (1,), MonthEnd()),
+ ((1, 12), (), YearEnd()),
+ ((1, 1), (), YearEnd(month=1))],
+ ids=_id_func
+)
+def test_onOffset_month_or_year_end(
+ calendar, year_month_args, sub_day_args, offset):
+ date_type = get_date_type(calendar)
+ reference_args = year_month_args + (1,)
+ reference = date_type(*reference_args)
+ date_args = year_month_args + (_days_in_month(reference),) + sub_day_args
+ date = date_type(*date_args)
+ result = offset.onOffset(date)
+ assert result
+
+
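+# In the rollforward/rollback tests below, the expected day is the first of
+# the month for Begin offsets and the calendar-dependent last day of the
+# month for End offsets; sub-monthly offsets leave the date unchanged.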
+@pytest.mark.parametrize(
+ ('offset', 'initial_date_args', 'partial_expected_date_args'),
+ [(YearBegin(), (1, 3, 1), (2, 1)),
+ (YearBegin(), (1, 1, 1), (1, 1)),
+ (YearBegin(n=2), (1, 3, 1), (2, 1)),
+ (YearBegin(n=2, month=2), (1, 3, 1), (2, 2)),
+ (YearEnd(), (1, 3, 1), (1, 12)),
+ (YearEnd(n=2), (1, 3, 1), (1, 12)),
+ (YearEnd(n=2, month=2), (1, 3, 1), (2, 2)),
+ (YearEnd(n=2, month=4), (1, 4, 30), (1, 4)),
+ (MonthBegin(), (1, 3, 2), (1, 4)),
+ (MonthBegin(), (1, 3, 1), (1, 3)),
+ (MonthBegin(n=2), (1, 3, 2), (1, 4)),
+ (MonthEnd(), (1, 3, 2), (1, 3)),
+ (MonthEnd(), (1, 4, 30), (1, 4)),
+ (MonthEnd(n=2), (1, 3, 2), (1, 3)),
+ (Day(), (1, 3, 2, 1), (1, 3, 2, 1)),
+ (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)),
+ (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)),
+ (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))],
+ ids=_id_func
+)
+def test_rollforward(calendar, offset, initial_date_args,
+ partial_expected_date_args):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ if isinstance(offset, (MonthBegin, YearBegin)):
+ expected_date_args = partial_expected_date_args + (1,)
+ elif isinstance(offset, (MonthEnd, YearEnd)):
+ reference_args = partial_expected_date_args + (1,)
+ reference = date_type(*reference_args)
+ expected_date_args = (partial_expected_date_args +
+ (_days_in_month(reference),))
+ else:
+ expected_date_args = partial_expected_date_args
+ expected = date_type(*expected_date_args)
+ result = offset.rollforward(initial)
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ ('offset', 'initial_date_args', 'partial_expected_date_args'),
+ [(YearBegin(), (1, 3, 1), (1, 1)),
+ (YearBegin(n=2), (1, 3, 1), (1, 1)),
+ (YearBegin(n=2, month=2), (1, 3, 1), (1, 2)),
+ (YearBegin(), (1, 1, 1), (1, 1)),
+ (YearBegin(n=2, month=2), (1, 2, 1), (1, 2)),
+ (YearEnd(), (2, 3, 1), (1, 12)),
+ (YearEnd(n=2), (2, 3, 1), (1, 12)),
+ (YearEnd(n=2, month=2), (2, 3, 1), (2, 2)),
+ (YearEnd(month=4), (1, 4, 30), (1, 4)),
+ (MonthBegin(), (1, 3, 2), (1, 3)),
+ (MonthBegin(n=2), (1, 3, 2), (1, 3)),
+ (MonthBegin(), (1, 3, 1), (1, 3)),
+ (MonthEnd(), (1, 3, 2), (1, 2)),
+ (MonthEnd(n=2), (1, 3, 2), (1, 2)),
+ (MonthEnd(), (1, 4, 30), (1, 4)),
+ (Day(), (1, 3, 2, 1), (1, 3, 2, 1)),
+ (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)),
+ (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)),
+ (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1))],
+ ids=_id_func
+)
+def test_rollback(calendar, offset, initial_date_args,
+ partial_expected_date_args):
+ date_type = get_date_type(calendar)
+ initial = date_type(*initial_date_args)
+ if isinstance(offset, (MonthBegin, YearBegin)):
+ expected_date_args = partial_expected_date_args + (1,)
+ elif isinstance(offset, (MonthEnd, YearEnd)):
+ reference_args = partial_expected_date_args + (1,)
+ reference = date_type(*reference_args)
+ expected_date_args = (partial_expected_date_args +
+ (_days_in_month(reference),))
+ else:
+ expected_date_args = partial_expected_date_args
+ expected = date_type(*expected_date_args)
+ result = offset.rollback(initial)
+ assert result == expected
+
+
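+# Each entry is (start, end, periods, freq, closed, normalize,
+# expected_date_args).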
+_CFTIME_RANGE_TESTS = [
+ ('0001-01-01', '0001-01-04', None, 'D', None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ('0001-01-01', '0001-01-04', None, 'D', 'left', False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3)]),
+ ('0001-01-01', '0001-01-04', None, 'D', 'right', False,
+ [(1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, False,
+ [(1, 1, 1, 1), (1, 1, 2, 1), (1, 1, 3, 1)]),
+ ('0001-01-01T01:00:00', '0001-01-04', None, 'D', None, True,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ('0001-01-01', None, 4, 'D', None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ (None, '0001-01-04', 4, 'D', None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ((1, 1, 1), '0001-01-04', None, 'D', None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ((1, 1, 1), (1, 1, 4), None, 'D', None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)]),
+ ('0001-01-30', '0011-02-01', None, '3AS-JUN', None, False,
+ [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)]),
+ ('0001-01-04', '0001-01-01', None, 'D', None, False,
+ []),
+ ('0010', None, 4, YearBegin(n=-2), None, False,
+ [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)]),
+ ('0001-01-01', '0001-01-04', 4, None, None, False,
+ [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4)])
+]
+
+
+@pytest.mark.parametrize(
+ ('start', 'end', 'periods', 'freq', 'closed', 'normalize',
+ 'expected_date_args'),
+ _CFTIME_RANGE_TESTS, ids=_id_func
+)
+def test_cftime_range(
+ start, end, periods, freq, closed, normalize, calendar,
+ expected_date_args):
+ date_type = get_date_type(calendar)
+ expected_dates = [date_type(*args) for args in expected_date_args]
+
+ if isinstance(start, tuple):
+ start = date_type(*start)
+ if isinstance(end, tuple):
+ end = date_type(*end)
+
+ result = cftime_range(
+ start=start, end=end, periods=periods, freq=freq, closed=closed,
+ normalize=normalize, calendar=calendar)
+ resulting_dates = result.values
+
+ assert isinstance(result, CFTimeIndex)
+
+ if freq is not None:
+ np.testing.assert_equal(resulting_dates, expected_dates)
+ else:
+ # If we create a linear range of dates using cftime.num2date
+ # we will not get exact round number dates. This is because
+ # datetime arithmetic in cftime is accurate approximately to
+ # 1 millisecond (see https://unidata.github.io/cftime/api.html).
+ deltas = resulting_dates - expected_dates
+ deltas = np.array([delta.total_seconds() for delta in deltas])
+ assert np.max(np.abs(deltas)) < 0.001
+
+
+def test_cftime_range_name():
+ result = cftime_range(start='2000', periods=4, name='foo')
+ assert result.name == 'foo'
+
+ result = cftime_range(start='2000', periods=4)
+ assert result.name is None
+
+
+@pytest.mark.parametrize(
+ ('start', 'end', 'periods', 'freq', 'closed'),
+ [(None, None, 5, 'A', None),
+ ('2000', None, None, 'A', None),
+ (None, '2000', None, 'A', None),
+ ('2000', '2001', None, None, None),
+ (None, None, None, None, None),
+ ('2000', '2001', None, 'A', 'up'),
+ ('2000', '2001', 5, 'A', None)]
+)
+def test_invalid_cftime_range_inputs(start, end, periods, freq, closed):
+ with pytest.raises(ValueError):
+ cftime_range(start, end, periods, freq, closed=closed)
+
+
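+# Expected (month, day) pairs at the end of every second month of the leap
+# year 2000 for each calendar.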
+_CALENDAR_SPECIFIC_MONTH_END_TESTS = [
+ ('2M', 'noleap',
+ [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]),
+ ('2M', 'all_leap',
+ [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]),
+ ('2M', '360_day',
+ [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]),
+ ('2M', 'standard',
+ [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]),
+ ('2M', 'gregorian',
+ [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]),
+ ('2M', 'julian',
+ [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)])
+]
+
+
+@pytest.mark.parametrize(
+ ('freq', 'calendar', 'expected_month_day'),
+ _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func
+)
+def test_calendar_specific_month_end(freq, calendar, expected_month_day):
+ year = 2000 # Use a leap-year to highlight calendar differences
+ result = cftime_range(
+ start='2000-02', end='2001', freq=freq, calendar=calendar).values
+ date_type = get_date_type(calendar)
+ expected = [date_type(year, *args) for args in expected_month_day]
+ np.testing.assert_equal(result, expected)
+
+
+@pytest.mark.parametrize(
+ ('calendar', 'start', 'end', 'expected_number_of_days'),
+ [('noleap', '2000', '2001', 365),
+ ('all_leap', '2000', '2001', 366),
+ ('360_day', '2000', '2001', 360),
+ ('standard', '2000', '2001', 366),
+ ('gregorian', '2000', '2001', 366),
+ ('julian', '2000', '2001', 366),
+ ('noleap', '2001', '2002', 365),
+ ('all_leap', '2001', '2002', 366),
+ ('360_day', '2001', '2002', 360),
+ ('standard', '2001', '2002', 365),
+ ('gregorian', '2001', '2002', 365),
+ ('julian', '2001', '2002', 365)]
+)
+def test_calendar_year_length(
+ calendar, start, end, expected_number_of_days):
+ result = cftime_range(start, end, freq='D', closed='left',
+ calendar=calendar)
+ assert len(result) == expected_number_of_days
diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py
index 6f102b60b9d..d1726ab3313 100644
--- a/xarray/tests/test_cftimeindex.py
+++ b/xarray/tests/test_cftimeindex.py
@@ -1,17 +1,18 @@
from __future__ import absolute_import
-import pytest
+from datetime import timedelta
+import numpy as np
import pandas as pd
-import xarray as xr
+import pytest
-from datetime import timedelta
+import xarray as xr
from xarray.coding.cftimeindex import (
- parse_iso8601, CFTimeIndex, assert_all_valid_date_type,
- _parsed_string_to_bounds, _parse_iso8601_with_reso)
+ CFTimeIndex, _parse_array_of_cftime_strings, _parse_iso8601_with_reso,
+ _parsed_string_to_bounds, assert_all_valid_date_type, parse_iso8601)
from xarray.tests import assert_array_equal, assert_identical
-from . import has_cftime, has_cftime_or_netCDF4
+from . import has_cftime, has_cftime_or_netCDF4, requires_cftime
from .test_coding_times import _all_cftime_date_types
@@ -121,22 +122,42 @@ def dec_days(date_type):
return 31
+@pytest.fixture
+def index_with_name(date_type):
+ dates = [date_type(1, 1, 1), date_type(1, 2, 1),
+ date_type(2, 1, 1), date_type(2, 2, 1)]
+ return CFTimeIndex(dates, name='foo')
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+@pytest.mark.parametrize(
+ ('name', 'expected_name'),
+ [('bar', 'bar'),
+ (None, 'foo')])
+def test_constructor_with_name(index_with_name, name, expected_name):
+ result = CFTimeIndex(index_with_name, name=name).name
+ assert result == expected_name
+
+
@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
def test_assert_all_valid_date_type(date_type, index):
import cftime
if date_type is cftime.DatetimeNoLeap:
- mixed_date_types = [date_type(1, 1, 1),
- cftime.DatetimeAllLeap(1, 2, 1)]
+ mixed_date_types = np.array(
+ [date_type(1, 1, 1),
+ cftime.DatetimeAllLeap(1, 2, 1)])
else:
- mixed_date_types = [date_type(1, 1, 1),
- cftime.DatetimeNoLeap(1, 2, 1)]
+ mixed_date_types = np.array(
+ [date_type(1, 1, 1),
+ cftime.DatetimeNoLeap(1, 2, 1)])
with pytest.raises(TypeError):
assert_all_valid_date_type(mixed_date_types)
with pytest.raises(TypeError):
- assert_all_valid_date_type([1, date_type(1, 1, 1)])
+ assert_all_valid_date_type(np.array([1, date_type(1, 1, 1)]))
- assert_all_valid_date_type([date_type(1, 1, 1), date_type(1, 2, 1)])
+ assert_all_valid_date_type(
+ np.array([date_type(1, 1, 1), date_type(1, 2, 1)]))
@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
@@ -589,3 +610,95 @@ def test_concat_cftimeindex(date_type, enable_cftimeindex):
else:
assert isinstance(da.indexes['time'], pd.Index)
assert not isinstance(da.indexes['time'], CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_empty_cftimeindex():
+ index = CFTimeIndex([])
+ assert index.date_type is None
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_add(index):
+ date_type = index.date_type
+ expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2),
+ date_type(2, 1, 2), date_type(2, 2, 2)]
+ expected = CFTimeIndex(expected_dates)
+ result = index + timedelta(days=1)
+ assert result.equals(expected)
+ assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_radd(index):
+ date_type = index.date_type
+ expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2),
+ date_type(2, 1, 2), date_type(2, 2, 2)]
+ expected = CFTimeIndex(expected_dates)
+ result = timedelta(days=1) + index
+ assert result.equals(expected)
+ assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_sub(index):
+ date_type = index.date_type
+ expected_dates = [date_type(1, 1, 2), date_type(1, 2, 2),
+ date_type(2, 1, 2), date_type(2, 2, 2)]
+ expected = CFTimeIndex(expected_dates)
+ result = index + timedelta(days=2)
+ result = result - timedelta(days=1)
+ assert result.equals(expected)
+ assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_rsub(index):
+ with pytest.raises(TypeError):
+ timedelta(days=1) - index
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+@pytest.mark.parametrize('freq', ['D', timedelta(days=1)])
+def test_cftimeindex_shift(index, freq):
+ date_type = index.date_type
+ expected_dates = [date_type(1, 1, 3), date_type(1, 2, 3),
+ date_type(2, 1, 3), date_type(2, 2, 3)]
+ expected = CFTimeIndex(expected_dates)
+ result = index.shift(2, freq)
+ assert result.equals(expected)
+ assert isinstance(result, CFTimeIndex)
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_shift_invalid_n():
+ index = xr.cftime_range('2000', periods=3)
+ with pytest.raises(TypeError):
+ index.shift('a', 'D')
+
+
+@pytest.mark.skipif(not has_cftime, reason='cftime not installed')
+def test_cftimeindex_shift_invalid_freq():
+ index = xr.cftime_range('2000', periods=3)
+ with pytest.raises(TypeError):
+ index.shift(1, 1)
+
+
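+# _parse_array_of_cftime_strings should vectorize date-string parsing over
+# arrays of any shape, including zero-dimensional (scalar) arrays.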
+@requires_cftime
+def test_parse_array_of_cftime_strings():
+ from cftime import DatetimeNoLeap
+
+ strings = np.array([['2000-01-01', '2000-01-02'],
+ ['2000-01-03', '2000-01-04']])
+ expected = np.array(
+ [[DatetimeNoLeap(2000, 1, 1), DatetimeNoLeap(2000, 1, 2)],
+ [DatetimeNoLeap(2000, 1, 3), DatetimeNoLeap(2000, 1, 4)]])
+
+ result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap)
+ np.testing.assert_array_equal(result, expected)
+
+ # Test scalar array case
+ strings = np.array('2000-01-01')
+ expected = np.array(DatetimeNoLeap(2000, 1, 1))
+ result = _parse_array_of_cftime_strings(strings, DatetimeNoLeap)
+ np.testing.assert_array_equal(result, expected)
diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py
index 53d028e164b..ca138ca8362 100644
--- a/xarray/tests/test_coding_strings.py
+++ b/xarray/tests/test_coding_strings.py
@@ -5,13 +5,13 @@
import pytest
from xarray import Variable
-from xarray.core.pycompat import bytes_type, unicode_type, suppress
from xarray.coding import strings
from xarray.core import indexing
+from xarray.core.pycompat import bytes_type, suppress, unicode_type
-from . import (IndexerMaker, assert_array_equal, assert_identical,
- raises_regex, requires_dask)
-
+from . import (
+ IndexerMaker, assert_array_equal, assert_identical, raises_regex,
+ requires_dask)
with suppress(ImportError):
import dask.array as da
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py
index e763af4984c..10a1a956b27 100644
--- a/xarray/tests/test_coding_times.py
+++ b/xarray/tests/test_coding_times.py
@@ -1,20 +1,20 @@
from __future__ import absolute_import, division, print_function
-from itertools import product
import warnings
+from itertools import product
import numpy as np
import pandas as pd
import pytest
-from xarray import Variable, coding, set_options, DataArray, decode_cf
+from xarray import DataArray, Variable, coding, decode_cf, set_options
from xarray.coding.times import _import_cftime
from xarray.coding.variables import SerializationWarning
from xarray.core.common import contains_cftime_datetimes
-from . import (assert_array_equal, has_cftime_or_netCDF4,
- requires_cftime_or_netCDF4, has_cftime, has_dask)
-
+from . import (
+ assert_array_equal, has_cftime, has_cftime_or_netCDF4, has_dask,
+ requires_cftime_or_netCDF4)
_NON_STANDARD_CALENDARS_SET = {'noleap', '365_day', '360_day',
'julian', 'all_leap', '366_day'}
@@ -538,7 +538,8 @@ def test_cf_datetime_nan(num_dates, units, expected_list):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', 'All-NaN')
actual = coding.times.decode_cf_datetime(num_dates, units)
- expected = np.array(expected_list, dtype='datetime64[ns]')
+ # use pandas because numpy will deprecate timezone-aware conversions
+ expected = pd.to_datetime(expected_list)
assert_array_equal(expected, actual)
diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py
index 482a280b355..2004b1e660f 100644
--- a/xarray/tests/test_combine.py
+++ b/xarray/tests/test_combine.py
@@ -10,12 +10,12 @@
from xarray.core.pycompat import OrderedDict, iteritems
from . import (
- InaccessibleArray, TestCase, assert_array_equal, assert_equal,
- assert_identical, raises_regex, requires_dask)
+ InaccessibleArray, assert_array_equal, assert_equal, assert_identical,
+ raises_regex, requires_dask)
from .test_dataset import create_test_data
-class TestConcatDataset(TestCase):
+class TestConcatDataset(object):
def test_concat(self):
# TODO: simplify and split this test case
@@ -235,7 +235,7 @@ def test_concat_multiindex(self):
assert isinstance(actual.x.to_index(), pd.MultiIndex)
-class TestConcatDataArray(TestCase):
+class TestConcatDataArray(object):
def test_concat(self):
ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))),
'bar': (['x', 'y'], np.random.random((2, 3)))},
@@ -295,7 +295,7 @@ def test_concat_lazy(self):
assert combined.dims == ('z', 'x', 'y')
-class TestAutoCombine(TestCase):
+class TestAutoCombine(object):
@requires_dask # only for toolz
def test_auto_combine(self):
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index ca8e4e59737..1003c531018 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -15,7 +15,7 @@
join_dict_keys, ordered_set_intersection, ordered_set_union, result_name,
unified_dim_sizes)
-from . import raises_regex, requires_dask, has_dask
+from . import has_dask, raises_regex, requires_dask
def assert_identical(a, b):
diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py
index 5ed482ed2bd..a067d01a308 100644
--- a/xarray/tests/test_conventions.py
+++ b/xarray/tests/test_conventions.py
@@ -8,20 +8,20 @@
import pandas as pd
import pytest
-from xarray import (Dataset, Variable, SerializationWarning, coding,
- conventions, open_dataset)
+from xarray import (
+ Dataset, SerializationWarning, Variable, coding, conventions, open_dataset)
from xarray.backends.common import WritableCFDataStore
from xarray.backends.memory import InMemoryDataStore
from xarray.conventions import decode_cf
from xarray.testing import assert_identical
from . import (
- TestCase, assert_array_equal, raises_regex, requires_netCDF4,
- requires_cftime_or_netCDF4, unittest, requires_dask)
+ assert_array_equal, raises_regex, requires_cftime_or_netCDF4,
+ requires_dask, requires_netCDF4)
from .test_backends import CFEncodedDataTest
-class TestBoolTypeArray(TestCase):
+class TestBoolTypeArray(object):
def test_booltype_array(self):
x = np.array([1, 0, 1, 1, 0], dtype='i1')
bx = conventions.BoolTypeArray(x)
@@ -30,7 +30,7 @@ def test_booltype_array(self):
dtype=np.bool))
-class TestNativeEndiannessArray(TestCase):
+class TestNativeEndiannessArray(object):
def test(self):
x = np.arange(5, dtype='>i8')
expected = np.arange(5, dtype='int64')
@@ -69,7 +69,7 @@ def test_decode_cf_with_conflicting_fill_missing_value():
@requires_cftime_or_netCDF4
-class TestEncodeCFVariable(TestCase):
+class TestEncodeCFVariable(object):
def test_incompatible_attributes(self):
invalid_vars = [
Variable(['t'], pd.date_range('2000-01-01', periods=3),
@@ -134,7 +134,7 @@ def test_string_object_warning(self):
@requires_cftime_or_netCDF4
-class TestDecodeCF(TestCase):
+class TestDecodeCF(object):
def test_dataset(self):
original = Dataset({
't': ('t', [0, 1, 2], {'units': 'days since 2000-01-01'}),
@@ -255,7 +255,7 @@ def encode_variable(self, var):
@requires_netCDF4
-class TestCFEncodedDataStore(CFEncodedDataTest, TestCase):
+class TestCFEncodedDataStore(CFEncodedDataTest):
@contextlib.contextmanager
def create_store(self):
yield CFEncodedInMemoryStore()
@@ -267,9 +267,10 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={},
data.dump_to_store(store, **save_kwargs)
yield open_dataset(store, **open_kwargs)
+ @pytest.mark.skip('cannot roundtrip coordinates yet for '
+ 'CFEncodedInMemoryStore')
def test_roundtrip_coordinates(self):
- raise unittest.SkipTest('cannot roundtrip coordinates yet for '
- 'CFEncodedInMemoryStore')
+ pass
def test_invalid_dataarray_names_raise(self):
# only relevant for on-disk file formats
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index f6c47cce8d8..e56f751bef9 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function
import pickle
+from distutils.version import LooseVersion
from textwrap import dedent
import numpy as np
@@ -14,18 +15,22 @@
from xarray.tests import mock
from . import (
- TestCase, assert_allclose, assert_array_equal, assert_equal,
- assert_frame_equal, assert_identical, raises_regex)
+ assert_allclose, assert_array_equal, assert_equal, assert_frame_equal,
+ assert_identical, raises_regex)
dask = pytest.importorskip('dask')
da = pytest.importorskip('dask.array')
dd = pytest.importorskip('dask.dataframe')
-class DaskTestCase(TestCase):
+class DaskTestCase(object):
def assertLazyAnd(self, expected, actual, test):
- with dask.set_options(get=dask.get):
+
+ with (dask.config.set(get=dask.get)
+ if LooseVersion(dask.__version__) >= LooseVersion('0.18.0')
+ else dask.set_options(get=dask.get)):
test(actual, expected)
+
if isinstance(actual, Dataset):
for k, v in actual.variables.items():
if k in actual.dims:
@@ -52,6 +57,7 @@ def assertLazyAndIdentical(self, expected, actual):
def assertLazyAndAllClose(self, expected, actual):
self.assertLazyAnd(expected, actual, assert_allclose)
+ @pytest.fixture(autouse=True)
def setUp(self):
self.values = np.random.RandomState(0).randn(4, 6)
self.data = da.from_array(self.values, chunks=(2, 2))
@@ -196,11 +202,13 @@ def test_missing_methods(self):
except NotImplementedError as err:
assert 'dask' in str(err)
+ @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning')
def test_univariate_ufunc(self):
u = self.eager_var
v = self.lazy_var
self.assertLazyAndAllClose(np.sin(u), xu.sin(v))
+ @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning')
def test_bivariate_ufunc(self):
u = self.eager_var
v = self.lazy_var
@@ -242,6 +250,7 @@ def assertLazyAndAllClose(self, expected, actual):
def assertLazyAndEqual(self, expected, actual):
self.assertLazyAnd(expected, actual, assert_equal)
+ @pytest.fixture(autouse=True)
def setUp(self):
self.values = np.random.randn(4, 6)
self.data = da.from_array(self.values, chunks=(2, 2))
@@ -378,8 +387,8 @@ def test_groupby(self):
u = self.eager_array
v = self.lazy_array
- expected = u.groupby('x').mean()
- actual = v.groupby('x').mean()
+ expected = u.groupby('x').mean(xr.ALL_DIMS)
+ actual = v.groupby('x').mean(xr.ALL_DIMS)
self.assertLazyAndAllClose(expected, actual)
def test_groupby_first(self):
@@ -421,6 +430,7 @@ def duplicate_and_merge(array):
actual = duplicate_and_merge(self.lazy_array)
self.assertLazyAndEqual(expected, actual)
+ @pytest.mark.filterwarnings('ignore::PendingDeprecationWarning')
def test_ufuncs(self):
u = self.eager_array
v = self.lazy_array
@@ -573,7 +583,7 @@ def test_from_dask_variable(self):
self.assertLazyAndIdentical(self.lazy_array, a)
-class TestToDaskDataFrame(TestCase):
+class TestToDaskDataFrame(object):
def test_to_dask_dataframe(self):
# Test conversion of Datasets to dask DataFrames
@@ -821,7 +831,9 @@ def test_basic_compute():
dask.multiprocessing.get,
dask.local.get_sync,
None]:
- with dask.set_options(get=get):
+ with (dask.config.set(get=get)
+ if LooseVersion(dask.__version__) >= LooseVersion('0.18.0')
+ else dask.set_options(get=get)):
ds.compute()
ds.foo.compute()
ds.foo.variable.compute()
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 2950e97cc75..d15a0bb6081 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -1,9 +1,9 @@
from __future__ import absolute_import, division, print_function
import pickle
+import warnings
from copy import deepcopy
from textwrap import dedent
-import warnings
import numpy as np
import pandas as pd
@@ -12,19 +12,20 @@
import xarray as xr
from xarray import (
DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options)
-from xarray.convert import from_cdms2
from xarray.coding.times import CFDatetimeCoder, _import_cftime
-from xarray.core.common import full_like
+from xarray.convert import from_cdms2
+from xarray.core.common import ALL_DIMS, full_like
from xarray.core.pycompat import OrderedDict, iteritems
from xarray.tests import (
- ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal,
+ ReturnItem, assert_allclose, assert_array_equal, assert_equal,
assert_identical, raises_regex, requires_bottleneck, requires_cftime,
requires_dask, requires_iris, requires_np113, requires_scipy,
- source_ndarray, unittest)
+ source_ndarray)
-class TestDataArray(TestCase):
- def setUp(self):
+class TestDataArray(object):
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.attrs = {'attr1': 'value1', 'attr2': 2929}
self.x = np.random.random((10, 20))
self.v = Variable(['x', 'y'], self.x)
@@ -440,7 +441,7 @@ def test_getitem(self):
assert_identical(self.ds['x'], x)
assert_identical(self.ds['y'], y)
- I = ReturnItem() # noqa: E741 # allow ambiguous name
+ I = ReturnItem() # noqa
for i in [I[:], I[...], I[x.values], I[x.variable], I[x], I[x, y],
I[x.values > -1], I[x.variable > -1], I[x > -1],
I[x > -1, y > -1]]:
@@ -672,6 +673,7 @@ def test_isel_types(self):
assert_identical(da.isel(x=np.array([0], dtype="int64")),
da.isel(x=np.array([0])))
+ @pytest.mark.filterwarnings('ignore::DeprecationWarning')
def test_isel_fancy(self):
shape = (10, 7, 6)
np_array = np.random.random(shape)
@@ -845,6 +847,7 @@ def test_isel_drop(self):
selected = data.isel(x=0, drop=False)
assert_identical(expected, selected)
+ @pytest.mark.filterwarnings("ignore:Dataset.isel_points")
def test_isel_points(self):
shape = (10, 5, 6)
np_array = np.random.random(shape)
@@ -999,7 +1002,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
assert da.dims[0] == renamed_dim
da = da.rename({renamed_dim: 'x'})
assert_identical(da.variable, expected_da.variable)
- self.assertVariableNotEqual(da['x'], expected_da['x'])
+ assert not da['x'].equals(expected_da['x'])
test_sel(('a', 1, -1), 0)
test_sel(('b', 2, -2), -1)
@@ -1237,6 +1240,7 @@ def test_reindex_like_no_index(self):
ValueError, 'different size for unlabeled'):
foo.reindex_like(bar)
+ @pytest.mark.filterwarnings('ignore:Indexer has dimensions')
def test_reindex_regressions(self):
# regression test for #279
expected = DataArray(np.random.randn(5), coords=[("time", range(5))])
@@ -1286,7 +1290,7 @@ def test_swap_dims(self):
def test_expand_dims_error(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
- coords={'x': np.linspace(0.0, 1.0, 3.0)},
+ coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
with raises_regex(ValueError, 'dim should be str or'):
@@ -1660,9 +1664,23 @@ def test_dataset_math(self):
def test_stack_unstack(self):
orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], attrs={'foo': 2})
+ assert_identical(orig, orig.unstack())
+
actual = orig.stack(z=['x', 'y']).unstack('z').drop(['x', 'y'])
assert_identical(orig, actual)
+ dims = ['a', 'b', 'c', 'd', 'e']
+ orig = xr.DataArray(np.random.rand(1, 2, 3, 2, 1), dims=dims)
+ stacked = orig.stack(ab=['a', 'b'], cd=['c', 'd'])
+
+ unstacked = stacked.unstack(['ab', 'cd'])
+ roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims)
+ assert_identical(orig, roundtripped)
+
+ unstacked = stacked.unstack()
+ roundtripped = unstacked.drop(['a', 'b', 'c', 'd']).transpose(*dims)
+ assert_identical(orig, roundtripped)
+
def test_stack_unstack_decreasing_coordinate(self):
# regression test for GH980
orig = DataArray(np.random.rand(3, 4), dims=('y', 'x'),
@@ -1983,15 +2001,15 @@ def test_groupby_sum(self):
self.x[:, 10:].sum(),
self.x[:, 9:10].sum()]).T),
'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo']
- assert_allclose(expected_sum_all, grouped.reduce(np.sum))
- assert_allclose(expected_sum_all, grouped.sum())
+ assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=ALL_DIMS))
+ assert_allclose(expected_sum_all, grouped.sum(ALL_DIMS))
expected = DataArray([array['y'].values[idx].sum() for idx
in [slice(9), slice(10, None), slice(9, 10)]],
[['a', 'b', 'c']], ['abc'])
actual = array['y'].groupby('abc').apply(np.sum)
assert_allclose(expected, actual)
- actual = array['y'].groupby('abc').sum()
+ actual = array['y'].groupby('abc').sum(ALL_DIMS)
assert_allclose(expected, actual)
expected_sum_axis1 = Dataset(
@@ -2002,6 +2020,29 @@ def test_groupby_sum(self):
assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, 'y'))
assert_allclose(expected_sum_axis1, grouped.sum('y'))
+ def test_groupby_warning(self):
+ array = self.make_groupby_example_array()
+ grouped = array.groupby('y')
+ with pytest.warns(FutureWarning):
+ grouped.sum()
+
+ # Currently disabled due to https://github.com/pydata/xarray/issues/2468
+ # @pytest.mark.skipif(LooseVersion(xr.__version__) < LooseVersion('0.12'),
+ # reason="not to forget the behavior change")
+ @pytest.mark.skip
+ def test_groupby_sum_default(self):
+ array = self.make_groupby_example_array()
+ grouped = array.groupby('abc')
+
+ expected_sum_all = Dataset(
+ {'foo': Variable(['x', 'abc'],
+ np.array([self.x[:, :9].sum(axis=-1),
+ self.x[:, 10:].sum(axis=-1),
+ self.x[:, 9:10].sum(axis=-1)]).T),
+ 'abc': Variable(['abc'], np.array(['a', 'b', 'c']))})['foo']
+
+ assert_allclose(expected_sum_all, grouped.sum())
+
def test_groupby_count(self):
array = DataArray(
[0, 0, np.nan, np.nan, 0, 0],
@@ -2011,7 +2052,7 @@ def test_groupby_count(self):
expected = DataArray([1, 1, 2], coords=[('cat', ['a', 'b', 'c'])])
assert_identical(actual, expected)
- @unittest.skip('needs to be fixed for shortcut=False, keep_attrs=False')
+ @pytest.mark.skip('needs to be fixed for shortcut=False, keep_attrs=False')
def test_groupby_reduce_attrs(self):
array = self.make_groupby_example_array()
array.attrs['foo'] = 'bar'
@@ -2082,9 +2123,9 @@ def test_groupby_math(self):
assert_identical(expected, actual)
grouped = array.groupby('abc')
- expected_agg = (grouped.mean() - np.arange(3)).rename(None)
+ expected_agg = (grouped.mean(ALL_DIMS) - np.arange(3)).rename(None)
actual = grouped - DataArray(range(3), [('abc', ['a', 'b', 'c'])])
- actual_agg = actual.groupby('abc').mean()
+ actual_agg = actual.groupby('abc').mean(ALL_DIMS)
assert_allclose(expected_agg, actual_agg)
with raises_regex(TypeError, 'only support binary ops'):
@@ -2158,7 +2199,7 @@ def test_groupby_multidim(self):
('lon', DataArray([5, 28, 23],
coords=[('lon', [30., 40., 50.])])),
('lat', DataArray([16, 40], coords=[('lat', [10., 20.])]))]:
- actual_sum = array.groupby(dim).sum()
+ actual_sum = array.groupby(dim).sum(ALL_DIMS)
assert_identical(expected_sum, actual_sum)
def test_groupby_multidim_apply(self):
@@ -2787,7 +2828,7 @@ def test_to_and_from_series(self):
def test_series_categorical_index(self):
# regression test for GH700
if not hasattr(pd, 'CategoricalIndex'):
- raise unittest.SkipTest('requires pandas with CategoricalIndex')
+ pytest.skip('requires pandas with CategoricalIndex')
s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list('aabbc')))
arr = DataArray(s)
@@ -2915,7 +2956,6 @@ def test_to_masked_array(self):
ma = da.to_masked_array()
assert len(ma.mask) == N
- @pytest.mark.xfail # GH:2332 TODO fix this in upstream?
def test_to_and_from_cdms2_classic(self):
"""Classic with 1D axes"""
pytest.importorskip('cdms2')
@@ -2928,9 +2968,9 @@ def test_to_and_from_cdms2_classic(self):
expected_coords = [IndexVariable('distance', [-2, 2]),
IndexVariable('time', [0, 1, 2])]
actual = original.to_cdms2()
- assert_array_equal(actual, original)
+ assert_array_equal(actual.asma(), original)
assert actual.id == original.name
- self.assertItemsEqual(actual.getAxisIds(), original.dims)
+ assert tuple(actual.getAxisIds()) == original.dims
for axis, coord in zip(actual.getAxisList(), expected_coords):
assert axis.id == coord.name
assert_array_equal(axis, coord.values)
@@ -2944,13 +2984,12 @@ def test_to_and_from_cdms2_classic(self):
assert_identical(original, roundtripped)
back = from_cdms2(actual)
- self.assertItemsEqual(original.dims, back.dims)
- self.assertItemsEqual(original.coords.keys(), back.coords.keys())
+ assert original.dims == back.dims
+ assert original.coords.keys() == back.coords.keys()
for coord_name in original.coords.keys():
assert_array_equal(original.coords[coord_name],
back.coords[coord_name])
- @pytest.mark.xfail # GH:2332 TODO fix this in upstream?
def test_to_and_from_cdms2_sgrid(self):
"""Curvilinear (structured) grid
@@ -2967,17 +3006,18 @@ def test_to_and_from_cdms2_sgrid(self):
coords=OrderedDict(x=x, y=y, lon=lon, lat=lat),
name='sst')
actual = original.to_cdms2()
- self.assertItemsEqual(actual.getAxisIds(), original.dims)
- assert_array_equal(original.coords['lon'], actual.getLongitude())
- assert_array_equal(original.coords['lat'], actual.getLatitude())
+ assert tuple(actual.getAxisIds()) == original.dims
+ assert_array_equal(original.coords['lon'],
+ actual.getLongitude().asma())
+ assert_array_equal(original.coords['lat'],
+ actual.getLatitude().asma())
back = from_cdms2(actual)
- self.assertItemsEqual(original.dims, back.dims)
- self.assertItemsEqual(original.coords.keys(), back.coords.keys())
+ assert original.dims == back.dims
+ assert set(original.coords.keys()) == set(back.coords.keys())
assert_array_equal(original.coords['lat'], back.coords['lat'])
assert_array_equal(original.coords['lon'], back.coords['lon'])
- @pytest.mark.xfail # GH:2332 TODO fix this in upstream?
def test_to_and_from_cdms2_ugrid(self):
"""Unstructured grid"""
pytest.importorskip('cdms2')
@@ -2988,13 +3028,15 @@ def test_to_and_from_cdms2_ugrid(self):
original = DataArray(np.arange(5), dims=['cell'],
coords={'lon': lon, 'lat': lat, 'cell': cell})
actual = original.to_cdms2()
- self.assertItemsEqual(actual.getAxisIds(), original.dims)
- assert_array_equal(original.coords['lon'], actual.getLongitude())
- assert_array_equal(original.coords['lat'], actual.getLatitude())
+ assert tuple(actual.getAxisIds()) == original.dims
+ assert_array_equal(original.coords['lon'],
+ actual.getLongitude().getValue())
+ assert_array_equal(original.coords['lat'],
+ actual.getLatitude().getValue())
back = from_cdms2(actual)
- self.assertItemsEqual(original.dims, back.dims)
- self.assertItemsEqual(original.coords.keys(), back.coords.keys())
+ assert set(original.dims) == set(back.dims)
+ assert set(original.coords.keys()) == set(back.coords.keys())
assert_array_equal(original.coords['lat'], back.coords['lat'])
assert_array_equal(original.coords['lon'], back.coords['lon'])
@@ -3087,24 +3129,51 @@ def test_coordinate_diff(self):
actual = lon.diff('lon')
assert_equal(expected, actual)
- def test_shift(self):
+ @pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5])
+ def test_shift(self, offset):
arr = DataArray([1, 2, 3], dims='x')
actual = arr.shift(x=1)
expected = DataArray([np.nan, 1, 2], dims='x')
assert_identical(expected, actual)
arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])])
- for offset in [-5, -2, -1, 0, 1, 2, 5]:
- expected = DataArray(arr.to_pandas().shift(offset))
- actual = arr.shift(x=offset)
- assert_identical(expected, actual)
+ expected = DataArray(arr.to_pandas().shift(offset))
+ actual = arr.shift(x=offset)
+ assert_identical(expected, actual)
- def test_roll(self):
+ def test_roll_coords(self):
arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x')
- actual = arr.roll(x=1)
+ actual = arr.roll(x=1, roll_coords=True)
expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])])
assert_identical(expected, actual)
+ def test_roll_no_coords(self):
+ arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x')
+ actual = arr.roll(x=1, roll_coords=False)
+ expected = DataArray([3, 1, 2], coords=[('x', [0, 1, 2])])
+ assert_identical(expected, actual)
+
+ def test_roll_coords_none(self):
+ arr = DataArray([1, 2, 3], coords={'x': range(3)}, dims='x')
+
+ with pytest.warns(FutureWarning):
+ actual = arr.roll(x=1, roll_coords=None)
+
+ expected = DataArray([3, 1, 2], coords=[('x', [2, 0, 1])])
+ assert_identical(expected, actual)
+
+ def test_copy_with_data(self):
+ orig = DataArray(np.random.random(size=(2, 2)),
+ dims=('x', 'y'),
+ attrs={'attr1': 'value1'},
+ coords={'x': [4, 3]},
+ name='helloworld')
+ new_data = np.arange(4).reshape(2, 2)
+ actual = orig.copy(data=new_data)
+ expected = orig.copy()
+ expected.data = new_data
+ assert_identical(expected, actual)
+
def test_real_and_imag(self):
array = DataArray(1 + 2j)
assert_identical(array.real, DataArray(1))
@@ -3329,7 +3398,9 @@ def test_isin(da):
def test_rolling_iter(da):
rolling_obj = da.rolling(time=7)
- rolling_obj_mean = rolling_obj.mean()
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'Mean of empty slice')
+ rolling_obj_mean = rolling_obj.mean()
assert len(rolling_obj.window_labels) == len(da['time'])
assert_identical(rolling_obj.window_labels, da['time'])
@@ -3337,8 +3408,10 @@ def test_rolling_iter(da):
for i, (label, window_da) in enumerate(rolling_obj):
assert label == da['time'].isel(time=i)
- actual = rolling_obj_mean.isel(time=i)
- expected = window_da.mean('time')
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'Mean of empty slice')
+ actual = rolling_obj_mean.isel(time=i)
+ expected = window_da.mean('time')
# TODO add assert_allclose_with_nan, which compares nan position
# as well as the closeness of the values.
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 08d71d462d8..89704653e92 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
+import sys
+import warnings
from copy import copy, deepcopy
from io import StringIO
from textwrap import dedent
-import warnings
-import sys
import numpy as np
import pandas as pd
@@ -13,17 +13,18 @@
import xarray as xr
from xarray import (
- DataArray, Dataset, IndexVariable, MergeError, Variable, align, backends,
- broadcast, open_dataset, set_options)
-from xarray.core import indexing, utils
+ ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align,
+ backends, broadcast, open_dataset, set_options)
+from xarray.core import indexing, npcompat, utils
from xarray.core.common import full_like
from xarray.core.pycompat import (
OrderedDict, integer_types, iteritems, unicode_type)
from . import (
- InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose,
- assert_array_equal, assert_equal, assert_identical, has_dask, raises_regex,
- requires_bottleneck, requires_dask, requires_scipy, source_ndarray)
+ InaccessibleArray, UnexpectedDataAccess, assert_allclose,
+ assert_array_equal, assert_equal, assert_identical, has_cftime, has_dask,
+ raises_regex, requires_bottleneck, requires_dask, requires_scipy,
+ source_ndarray)
try:
import cPickle as pickle
@@ -63,8 +64,8 @@ def create_test_multiindex():
class InaccessibleVariableDataStore(backends.InMemoryDataStore):
- def __init__(self, writer=None):
- super(InaccessibleVariableDataStore, self).__init__(writer)
+ def __init__(self):
+ super(InaccessibleVariableDataStore, self).__init__()
self._indexvars = set()
def store(self, variables, *args, **kwargs):
@@ -85,7 +86,7 @@ def lazy_inaccessible(k, v):
k, v in iteritems(self._variables))
-class TestDataset(TestCase):
+class TestDataset(object):
def test_repr(self):
data = create_test_data(seed=123)
data.attrs['foo'] = 'bar'
@@ -398,7 +399,7 @@ def test_constructor_with_coords(self):
ds = Dataset({}, {'a': ('x', [1])})
assert not ds.data_vars
- self.assertItemsEqual(ds.coords.keys(), ['a'])
+ assert list(ds.coords.keys()) == ['a']
mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]],
names=('level_1', 'level_2'))
@@ -420,9 +421,9 @@ def test_properties(self):
assert type(ds.dims.mapping.mapping) is dict # noqa
with pytest.warns(FutureWarning):
- self.assertItemsEqual(ds, list(ds.variables))
+ assert list(ds) == list(ds.variables)
with pytest.warns(FutureWarning):
- self.assertItemsEqual(ds.keys(), list(ds.variables))
+ assert list(ds.keys()) == list(ds.variables)
assert 'aasldfjalskdfj' not in ds.variables
assert 'dim1' in repr(ds.variables)
with pytest.warns(FutureWarning):
@@ -430,18 +431,18 @@ def test_properties(self):
with pytest.warns(FutureWarning):
assert bool(ds)
- self.assertItemsEqual(ds.data_vars, ['var1', 'var2', 'var3'])
- self.assertItemsEqual(ds.data_vars.keys(), ['var1', 'var2', 'var3'])
+ assert list(ds.data_vars) == ['var1', 'var2', 'var3']
+ assert list(ds.data_vars.keys()) == ['var1', 'var2', 'var3']
assert 'var1' in ds.data_vars
assert 'dim1' not in ds.data_vars
assert 'numbers' not in ds.data_vars
assert len(ds.data_vars) == 3
- self.assertItemsEqual(ds.indexes, ['dim2', 'dim3', 'time'])
+ assert set(ds.indexes) == {'dim2', 'dim3', 'time'}
assert len(ds.indexes) == 3
assert 'dim2' in repr(ds.indexes)
- self.assertItemsEqual(ds.coords, ['time', 'dim2', 'dim3', 'numbers'])
+ assert list(ds.coords) == ['time', 'dim2', 'dim3', 'numbers']
assert 'dim2' in ds.coords
assert 'numbers' in ds.coords
assert 'var1' not in ds.coords
@@ -534,7 +535,7 @@ def test_coords_properties(self):
assert 4 == len(data.coords)
- self.assertItemsEqual(['x', 'y', 'a', 'b'], list(data.coords))
+ assert ['x', 'y', 'a', 'b'] == list(data.coords)
assert_identical(data.coords['x'].variable, data['x'].variable)
assert_identical(data.coords['y'].variable, data['y'].variable)
@@ -830,7 +831,7 @@ def test_isel(self):
ret = data.isel(**slicers)
# Verify that only the specified dimension was altered
- self.assertItemsEqual(data.dims, ret.dims)
+ assert list(data.dims) == list(ret.dims)
for d in data.dims:
if d in slicers:
assert ret.dims[d] == \
@@ -856,21 +857,21 @@ def test_isel(self):
ret = data.isel(dim1=0)
assert {'time': 20, 'dim2': 9, 'dim3': 10} == ret.dims
- self.assertItemsEqual(data.data_vars, ret.data_vars)
- self.assertItemsEqual(data.coords, ret.coords)
- self.assertItemsEqual(data.indexes, ret.indexes)
+ assert set(data.data_vars) == set(ret.data_vars)
+ assert set(data.coords) == set(ret.coords)
+ assert set(data.indexes) == set(ret.indexes)
ret = data.isel(time=slice(2), dim1=0, dim2=slice(5))
assert {'time': 2, 'dim2': 5, 'dim3': 10} == ret.dims
- self.assertItemsEqual(data.data_vars, ret.data_vars)
- self.assertItemsEqual(data.coords, ret.coords)
- self.assertItemsEqual(data.indexes, ret.indexes)
+ assert set(data.data_vars) == set(ret.data_vars)
+ assert set(data.coords) == set(ret.coords)
+ assert set(data.indexes) == set(ret.indexes)
ret = data.isel(time=0, dim1=0, dim2=slice(5))
- self.assertItemsEqual({'dim2': 5, 'dim3': 10}, ret.dims)
- self.assertItemsEqual(data.data_vars, ret.data_vars)
- self.assertItemsEqual(data.coords, ret.coords)
- self.assertItemsEqual(data.indexes, list(ret.indexes) + ['time'])
+ assert {'dim2': 5, 'dim3': 10} == ret.dims
+ assert set(data.data_vars) == set(ret.data_vars)
+ assert set(data.coords) == set(ret.coords)
+ assert set(data.indexes) == set(list(ret.indexes) + ['time'])
def test_isel_fancy(self):
# isel with fancy indexing.
@@ -1240,6 +1241,7 @@ def test_isel_drop(self):
selected = data.isel(x=0, drop=False)
assert_identical(expected, selected)
+ @pytest.mark.filterwarnings("ignore:Dataset.isel_points")
def test_isel_points(self):
data = create_test_data()
@@ -1317,6 +1319,8 @@ def test_isel_points(self):
dim2=stations['dim2s'],
dim=np.array([4, 5, 6]))
+ @pytest.mark.filterwarnings("ignore:Dataset.sel_points")
+ @pytest.mark.filterwarnings("ignore:Dataset.isel_points")
def test_sel_points(self):
data = create_test_data()
@@ -1347,6 +1351,7 @@ def test_sel_points(self):
with pytest.raises(KeyError):
data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3)
+ @pytest.mark.filterwarnings('ignore::DeprecationWarning')
def test_sel_fancy(self):
data = create_test_data()
@@ -1477,7 +1482,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
ds = ds.rename({renamed_dim: 'x'})
assert_identical(ds['var'].variable,
expected_ds['var'].variable)
- self.assertVariableNotEqual(ds['x'], expected_ds['x'])
+ assert not ds['x'].equals(expected_ds['x'])
test_sel(('a', 1, -1), 0)
test_sel(('b', 2, -2), -1)
@@ -1888,6 +1893,27 @@ def test_copy(self):
v1 = copied.variables[k]
assert v0 is not v1
+ def test_copy_with_data(self):
+ orig = create_test_data()
+ new_data = {k: np.random.randn(*v.shape)
+ for k, v in iteritems(orig.data_vars)}
+ actual = orig.copy(data=new_data)
+
+ expected = orig.copy()
+ for k, v in new_data.items():
+ expected[k].data = v
+ assert_identical(expected, actual)
+
+ def test_copy_with_data_errors(self):
+ orig = create_test_data()
+ new_var1 = np.arange(orig['var1'].size).reshape(orig['var1'].shape)
+ with raises_regex(ValueError, 'Data must be dict-like'):
+ orig.copy(data=new_var1)
+ with raises_regex(ValueError, 'only contain variables in original'):
+ orig.copy(data={'not_in_original': new_var1})
+ with raises_regex(ValueError, 'contain all variables in original'):
+ orig.copy(data={'var1': new_var1})
+
def test_rename(self):
data = create_test_data()
newnames = {'var1': 'renamed_var1', 'dim2': 'renamed_dim2'}
@@ -2103,17 +2129,18 @@ def test_unstack(self):
expected = Dataset({'b': (('x', 'y'), [[0, 1], [2, 3]]),
'x': [0, 1],
'y': ['a', 'b']})
- actual = ds.unstack('z')
- assert_identical(actual, expected)
+ for dim in ['z', ['z'], None]:
+ actual = ds.unstack(dim)
+ assert_identical(actual, expected)
def test_unstack_errors(self):
ds = Dataset({'x': [1, 2, 3]})
- with raises_regex(ValueError, 'invalid dimension'):
+ with raises_regex(ValueError, 'does not contain the dimensions'):
ds.unstack('foo')
- with raises_regex(ValueError, 'does not have a MultiIndex'):
+ with raises_regex(ValueError, 'do not have a MultiIndex'):
ds.unstack('x')
- def test_stack_unstack(self):
+ def test_stack_unstack_fast(self):
ds = Dataset({'a': ('x', [0, 1]),
'b': (('x', 'y'), [[0, 1], [2, 3]]),
'x': [0, 1],
@@ -2124,6 +2151,19 @@ def test_stack_unstack(self):
actual = ds[['b']].stack(z=['x', 'y']).unstack('z')
assert actual.identical(ds[['b']])
+ def test_stack_unstack_slow(self):
+ ds = Dataset({'a': ('x', [0, 1]),
+ 'b': (('x', 'y'), [[0, 1], [2, 3]]),
+ 'x': [0, 1],
+ 'y': ['a', 'b']})
+ stacked = ds.stack(z=['x', 'y'])
+ actual = stacked.isel(z=slice(None, None, -1)).unstack('z')
+ assert actual.broadcast_equals(ds)
+
+ stacked = ds[['b']].stack(z=['x', 'y'])
+ actual = stacked.isel(z=slice(None, None, -1)).unstack('z')
+ assert actual.identical(ds[['b']])
+
def test_update(self):
data = create_test_data(seed=0)
expected = data.copy()
@@ -2506,12 +2546,11 @@ def test_setitem_multiindex_level(self):
def test_delitem(self):
data = create_test_data()
all_items = set(data.variables)
- self.assertItemsEqual(data.variables, all_items)
+ assert set(data.variables) == all_items
del data['var1']
- self.assertItemsEqual(data.variables, all_items - set(['var1']))
+ assert set(data.variables) == all_items - set(['var1'])
del data['numbers']
- self.assertItemsEqual(data.variables,
- all_items - set(['var1', 'numbers']))
+ assert set(data.variables) == all_items - set(['var1', 'numbers'])
assert 'numbers' not in data.coords
def test_squeeze(self):
@@ -2609,20 +2648,28 @@ def test_groupby_reduce(self):
expected = data.mean('y')
expected['yonly'] = expected['yonly'].variable.set_dims({'x': 3})
- actual = data.groupby('x').mean()
+ actual = data.groupby('x').mean(ALL_DIMS)
assert_allclose(expected, actual)
actual = data.groupby('x').mean('y')
assert_allclose(expected, actual)
letters = data['letters']
- expected = Dataset({'xy': data['xy'].groupby(letters).mean(),
+ expected = Dataset({'xy': data['xy'].groupby(letters).mean(ALL_DIMS),
'xonly': (data['xonly'].mean().variable
.set_dims({'letters': 2})),
'yonly': data['yonly'].groupby(letters).mean()})
- actual = data.groupby('letters').mean()
+ actual = data.groupby('letters').mean(ALL_DIMS)
assert_allclose(expected, actual)
+ def test_groupby_warn(self):
+ data = Dataset({'xy': (['x', 'y'], np.random.randn(3, 4)),
+ 'xonly': ('x', np.random.randn(3)),
+ 'yonly': ('y', np.random.randn(4)),
+ 'letters': ('y', ['a', 'a', 'b', 'b'])})
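+ # mean() without an explicit dimension should raise a FutureWarning about the changing default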
+ with pytest.warns(FutureWarning):
+ data.groupby('x').mean()
+
def test_groupby_math(self):
def reorder_dims(x):
return x.transpose('dim1', 'dim2', 'dim3', 'time')
@@ -2677,7 +2724,7 @@ def test_groupby_math_virtual(self):
ds = Dataset({'x': ('t', [1, 2, 3])},
{'t': pd.date_range('20100101', periods=3)})
grouped = ds.groupby('t.day')
- actual = grouped - grouped.mean()
+ actual = grouped - grouped.mean(ALL_DIMS)
expected = Dataset({'x': ('t', [0, 0, 0])},
ds[['t', 't.day']])
assert_identical(actual, expected)
@@ -2686,18 +2733,17 @@ def test_groupby_nan(self):
# nan should be excluded from groupby
ds = Dataset({'foo': ('x', [1, 2, 3, 4])},
{'bar': ('x', [1, 1, 2, np.nan])})
- actual = ds.groupby('bar').mean()
+ actual = ds.groupby('bar').mean(ALL_DIMS)
expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]})
assert_identical(actual, expected)
def test_groupby_order(self):
# groupby should preserve variables order
-
ds = Dataset()
for vn in ['a', 'b', 'c']:
ds[vn] = DataArray(np.arange(10), dims=['t'])
data_vars_ref = list(ds.data_vars.keys())
- ds = ds.groupby('t').mean()
+ ds = ds.groupby('t').mean(ALL_DIMS)
data_vars = list(ds.data_vars.keys())
assert data_vars == data_vars_ref
# coords are now at the end of the list, so the test below fails
@@ -2727,6 +2773,20 @@ def test_resample_and_first(self):
result = actual.reduce(method)
assert_equal(expected, result)
+ def test_resample_min_count(self):
+ times = pd.date_range('2000-01-01', freq='6H', periods=10)
+ ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)),
+ 'bar': ('time', np.random.randn(10), {'meta': 'data'}),
+ 'time': times})
+ # inject nan
+ ds['foo'] = xr.where(ds['foo'] > 2.0, np.nan, ds['foo'])
+
+ actual = ds.resample(time='1D').sum(min_count=1)
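+ # ten six-hourly samples fall into three daily bins of 4, 4 and 2 values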
+ expected = xr.concat([
+ ds.isel(time=slice(i * 4, (i + 1) * 4)).sum('time', min_count=1)
+ for i in range(3)], dim=actual['time'])
+ assert_equal(expected, actual)
+
def test_resample_by_mean_with_keep_attrs(self):
times = pd.date_range('2000-01-01', freq='6H', periods=10)
ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)),
@@ -3364,9 +3424,8 @@ def test_reduce(self):
(['dim2', 'time'], ['dim1', 'dim3']),
(('dim2', 'time'), ['dim1', 'dim3']),
((), ['dim1', 'dim2', 'dim3', 'time'])]:
- actual = data.min(dim=reduct).dims
- print(reduct, actual, expected)
- self.assertItemsEqual(actual, expected)
+ actual = list(data.min(dim=reduct).dims)
+ assert actual == expected
assert_equal(data.mean(dim=[]), data)
@@ -3420,8 +3479,7 @@ def test_reduce_cumsum_test_dims(self):
('time', ['dim1', 'dim2', 'dim3'])
]:
actual = getattr(data, cumfunc)(dim=reduct).dims
- print(reduct, actual, expected)
- self.assertItemsEqual(actual, expected)
+ assert list(actual) == expected
def test_reduce_non_numeric(self):
data1 = create_test_data(seed=44)
@@ -3559,14 +3617,14 @@ def test_rank(self):
ds = create_test_data(seed=1234)
# only ds.var3 depends on dim3
z = ds.rank('dim3')
- self.assertItemsEqual(['var3'], list(z.data_vars))
+ assert ['var3'] == list(z.data_vars)
# same as dataarray version
x = z.var3
y = ds.var3.rank('dim3')
assert_equal(x, y)
# coordinates stick
- self.assertItemsEqual(list(z.coords), list(ds.coords))
- self.assertItemsEqual(list(x.coords), list(y.coords))
+ assert list(z.coords) == list(ds.coords)
+ assert list(x.coords) == list(y.coords)
# invalid dim
with raises_regex(ValueError, 'does not contain'):
x.rank('invalid_dim')
@@ -3849,18 +3907,52 @@ def test_shift(self):
with raises_regex(ValueError, 'dimensions'):
ds.shift(foo=123)
- def test_roll(self):
+ def test_roll_coords(self):
coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]}
attrs = {'meta': 'data'}
ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs)
- actual = ds.roll(x=1)
+ actual = ds.roll(x=1, roll_coords=True)
ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]}
expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs)
assert_identical(expected, actual)
with raises_regex(ValueError, 'dimensions'):
- ds.roll(foo=123)
+ ds.roll(foo=123, roll_coords=True)
+
+ def test_roll_no_coords(self):
+ coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]}
+ attrs = {'meta': 'data'}
+ ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs)
+ actual = ds.roll(x=1, roll_coords=False)
+
+ expected = Dataset({'foo': ('x', [3, 1, 2])}, coords, attrs)
+ assert_identical(expected, actual)
+
+ with raises_regex(ValueError, 'dimensions'):
+ ds.roll(abc=321, roll_coords=False)
+
+ def test_roll_coords_none(self):
+ coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]}
+ attrs = {'meta': 'data'}
+ ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs)
+
+ with pytest.warns(FutureWarning):
+ actual = ds.roll(x=1, roll_coords=None)
+
+ ex_coords = {'bar': ('x', list('cab')), 'x': [2, -4, 3]}
+ expected = Dataset({'foo': ('x', [3, 1, 2])}, ex_coords, attrs)
+ assert_identical(expected, actual)
+
+ def test_roll_multidim(self):
+ # regression test for 2445
+ arr = xr.DataArray(
+ [[1, 2, 3], [4, 5, 6]], coords={'x': range(3), 'y': range(2)},
+ dims=('y', 'x'))
+ actual = arr.roll(x=1, roll_coords=True)
+ expected = xr.DataArray([[3, 1, 2], [6, 4, 5]],
+ coords=[('y', [0, 1]), ('x', [2, 0, 1])])
+ assert_identical(expected, actual)
def test_real_and_imag(self):
attrs = {'foo': 'bar'}
@@ -3920,6 +4012,26 @@ def test_filter_by_attrs(self):
for var in new_ds.data_vars:
assert new_ds[var].height == '10 m'
+ # Test return empty Dataset due to conflicting filters
+ new_ds = ds.filter_by_attrs(
+ standard_name='convective_precipitation_flux',
+ height='0 m')
+ assert not bool(new_ds.data_vars)
+
+ # Test return one DataArray with two filter conditions
+ new_ds = ds.filter_by_attrs(
+ standard_name='air_potential_temperature',
+ height='0 m')
+ for var in new_ds.data_vars:
+ assert new_ds[var].standard_name == 'air_potential_temperature'
+ assert new_ds[var].height == '0 m'
+ assert new_ds[var].height != '10 m'
+
+ # Test return empty Dataset due to conflicting callables
+ new_ds = ds.filter_by_attrs(standard_name=lambda v: False,
+ height=lambda v: True)
+ assert not bool(new_ds.data_vars)
+
def test_binary_op_join_setting(self):
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({'x': [0, 1]})
@@ -4418,3 +4530,107 @@ def test_raise_no_warning_for_nan_in_binary_ops():
with pytest.warns(None) as record:
Dataset(data_vars={'x': ('y', [1, 2, np.NaN])}) > 0
assert len(record) == 0
+
+
+@pytest.mark.parametrize('dask', [True, False])
+@pytest.mark.parametrize('edge_order', [1, 2])
+def test_differentiate(dask, edge_order):
+ rs = np.random.RandomState(42)
+ coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8]
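+ # deliberately non-uniform spacing, so the gradient must use the coordinate values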
+
+ da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'],
+ coords={'x': coord,
+ 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))})
+ if dask and has_dask:
+ da = da.chunk({'x': 4})
+
+ ds = xr.Dataset({'var': da})
+
+ # along x
+ actual = da.differentiate('x', edge_order)
+ expected_x = xr.DataArray(
+ npcompat.gradient(da, da['x'], axis=0, edge_order=edge_order),
+ dims=da.dims, coords=da.coords)
+ assert_equal(expected_x, actual)
+ assert_equal(ds['var'].differentiate('x', edge_order=edge_order),
+ ds.differentiate('x', edge_order=edge_order)['var'])
+ # coordinate should not change
+ assert_equal(da['x'], actual['x'])
+
+ # along y
+ actual = da.differentiate('y', edge_order)
+ expected_y = xr.DataArray(
+ npcompat.gradient(da, da['y'], axis=1, edge_order=edge_order),
+ dims=da.dims, coords=da.coords)
+ assert_equal(expected_y, actual)
+ assert_equal(actual, ds.differentiate('y', edge_order=edge_order)['var'])
+ assert_equal(ds['var'].differentiate('y', edge_order=edge_order),
+ ds.differentiate('y', edge_order=edge_order)['var'])
+
+ with pytest.raises(ValueError):
+ da.differentiate('x2d')
+
+
+@pytest.mark.parametrize('dask', [True, False])
+def test_differentiate_datetime(dask):
+ rs = np.random.RandomState(42)
+ coord = np.array(
+ ['2004-07-13', '2006-01-13', '2010-08-13', '2010-09-13',
+ '2010-10-11', '2010-12-13', '2011-02-13', '2012-08-13'],
+ dtype='datetime64')
+
+ da = xr.DataArray(rs.randn(8, 6), dims=['x', 'y'],
+ coords={'x': coord,
+ 'z': 3, 'x2d': (('x', 'y'), rs.randn(8, 6))})
+ if dask and has_dask:
+ da = da.chunk({'x': 4})
+
+ # along x
+ actual = da.differentiate('x', edge_order=1, datetime_unit='D')
+ expected_x = xr.DataArray(
+ npcompat.gradient(
+ da, utils.datetime_to_numeric(da['x'], datetime_unit='D'),
+ axis=0, edge_order=1), dims=da.dims, coords=da.coords)
+ assert_equal(expected_x, actual)
+
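+ # the per-hour derivative is 1/24 of the per-day derivative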
+ actual2 = da.differentiate('x', edge_order=1, datetime_unit='h')
+ assert np.allclose(actual, actual2 * 24)
+
+ # for datetime variable
+ actual = da['x'].differentiate('x', edge_order=1, datetime_unit='D')
+ assert np.allclose(actual, 1.0)
+
+ # with different date unit
+ da = xr.DataArray(coord.astype('datetime64[ms]'), dims=['x'],
+ coords={'x': coord})
+ actual = da.differentiate('x', edge_order=1)
+ assert np.allclose(actual, 1.0)
+
+
+@pytest.mark.skipif(not has_cftime, reason='Test requires cftime.')
+@pytest.mark.parametrize('dask', [True, False])
+def test_differentiate_cftime(dask):
+ rs = np.random.RandomState(42)
+ coord = xr.cftime_range('2000', periods=8, freq='2M')
+
+ da = xr.DataArray(
+ rs.randn(8, 6),
+ coords={'time': coord, 'z': 3, 't2d': (('time', 'y'), rs.randn(8, 6))},
+ dims=['time', 'y'])
+
+ if dask and has_dask:
+ da = da.chunk({'time': 4})
+
+ actual = da.differentiate('time', edge_order=1, datetime_unit='D')
+ expected_data = npcompat.gradient(
+ da, utils.datetime_to_numeric(da['time'], datetime_unit='D'),
+ axis=0, edge_order=1)
+ expected = xr.DataArray(expected_data, coords=da.coords, dims=da.dims)
+ assert_equal(expected, actual)
+
+ actual2 = da.differentiate('time', edge_order=1, datetime_unit='h')
+ assert_allclose(actual, actual2 * 24)
+
+ # Test the differentiation of datetimes themselves
+ actual = da['time'].differentiate('time', edge_order=1, datetime_unit='D')
+ assert_allclose(actual, xr.ones_like(da['time']).astype(float))
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 32035afdc57..7c77a62d3c9 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -15,12 +15,13 @@
from distributed.utils_test import cluster, gen_cluster
from distributed.utils_test import loop # flake8: noqa
from distributed.client import futures_of
+import numpy as np
import xarray as xr
+from xarray.backends.locks import HDF5_LOCK, CombinedLock
from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file,
create_tmp_geotiff)
from xarray.tests.test_dataset import create_test_data
-from xarray.backends.common import HDF5_LOCK, CombinedLock
from . import (
assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy,
@@ -33,6 +34,11 @@
da = pytest.importorskip('dask.array')
+@pytest.fixture
+def tmp_netcdf_filename(tmpdir):
+ return str(tmpdir.join('testfile.nc'))
+
+
ENGINES = []
if has_scipy:
ENGINES.append('scipy')
@@ -45,81 +51,69 @@
'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'],
'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'],
'h5netcdf': ['NETCDF4']}
-TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4']
+ENGINES_AND_FORMATS = [
+ ('netcdf4', 'NETCDF3_CLASSIC'),
+ ('netcdf4', 'NETCDF4_CLASSIC'),
+ ('netcdf4', 'NETCDF4'),
+ ('h5netcdf', 'NETCDF4'),
+ ('scipy', 'NETCDF3_64BIT'),
+]
-@pytest.mark.xfail(sys.platform == 'win32',
- reason='https://github.com/pydata/xarray/issues/1738')
-@pytest.mark.parametrize('engine', ['netcdf4'])
-@pytest.mark.parametrize('autoclose', [True, False])
-@pytest.mark.parametrize('nc_format', TEST_FORMATS)
-def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop,
- engine, autoclose, nc_format):
- monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE')
-
- chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6}
-
- with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename:
- with cluster() as (s, [a, b]):
- with Client(s['address'], loop=loop) as c:
+@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS)
+def test_dask_distributed_netcdf_roundtrip(
+ loop, tmp_netcdf_filename, engine, nc_format):
- original = create_test_data().chunk(chunks)
- original.to_netcdf(filename, engine=engine, format=nc_format)
-
- with xr.open_dataset(filename,
- chunks=chunks,
- engine=engine,
- autoclose=autoclose) as restored:
- assert isinstance(restored.var1.data, da.Array)
- computed = restored.compute()
- assert_allclose(original, computed)
+ if engine not in ENGINES:
+ pytest.skip('engine not available')
+ chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6}
-@pytest.mark.xfail(sys.platform == 'win32',
- reason='https://github.com/pydata/xarray/issues/1738')
-@pytest.mark.parametrize('engine', ENGINES)
-@pytest.mark.parametrize('autoclose', [True, False])
-@pytest.mark.parametrize('nc_format', TEST_FORMATS)
-def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose,
- nc_format):
+ with cluster() as (s, [a, b]):
+ with Client(s['address'], loop=loop) as c:
- if engine == 'h5netcdf' and autoclose:
- pytest.skip('h5netcdf does not support autoclose')
+ original = create_test_data().chunk(chunks)
- if nc_format not in NC_FORMATS[engine]:
- pytest.skip('invalid format for engine')
+ if engine == 'scipy':
+ with pytest.raises(NotImplementedError):
+ original.to_netcdf(tmp_netcdf_filename,
+ engine=engine, format=nc_format)
+ return
- chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6}
+ original.to_netcdf(tmp_netcdf_filename,
+ engine=engine, format=nc_format)
- with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename:
- with cluster() as (s, [a, b]):
- with Client(s['address'], loop=loop) as c:
+ with xr.open_dataset(tmp_netcdf_filename,
+ chunks=chunks, engine=engine) as restored:
+ assert isinstance(restored.var1.data, da.Array)
+ computed = restored.compute()
+ assert_allclose(original, computed)
- original = create_test_data()
- original.to_netcdf(filename, engine=engine, format=nc_format)
- with xr.open_dataset(filename,
- chunks=chunks,
- engine=engine,
- autoclose=autoclose) as restored:
- assert isinstance(restored.var1.data, da.Array)
- computed = restored.compute()
- assert_allclose(original, computed)
+@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS)
+def test_dask_distributed_read_netcdf_integration_test(
+ loop, tmp_netcdf_filename, engine, nc_format):
+ if engine not in ENGINES:
+ pytest.skip('engine not available')
-@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy'])
-def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine):
chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6}
- with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename:
- with cluster() as (s, [a, b]):
- with Client(s['address'], loop=loop) as c:
+ with cluster() as (s, [a, b]):
+ with Client(s['address'], loop=loop) as c:
+
+ original = create_test_data()
+ original.to_netcdf(tmp_netcdf_filename,
+ engine=engine, format=nc_format)
- original = create_test_data().chunk(chunks)
+ with xr.open_dataset(tmp_netcdf_filename,
+ chunks=chunks,
+ engine=engine) as restored:
+ assert isinstance(restored.var1.data, da.Array)
+ computed = restored.compute()
+ assert_allclose(original, computed)
- with raises_regex(NotImplementedError, 'distributed'):
- original.to_netcdf(filename, engine=engine)
@requires_zarr
diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py
index 833df85f8af..292c60b4d05 100644
--- a/xarray/tests/test_dtypes.py
+++ b/xarray/tests/test_dtypes.py
@@ -50,3 +50,39 @@ def error():
def test_inf(obj):
assert dtypes.INF > obj
assert dtypes.NINF < obj
+
+
+@pytest.mark.parametrize("kind, expected", [
+ ('a', (np.dtype('O'), 'nan')), # dtype('S')
+ ('b', (np.float32, 'nan')), # dtype('int8')
+ ('B', (np.float32, 'nan')), # dtype('uint8')
+ ('c', (np.dtype('O'), 'nan')), # dtype('S1')
+ ('D', (np.complex128, '(nan+nanj)')), # dtype('complex128')
+ ('d', (np.float64, 'nan')), # dtype('float64')
+ ('e', (np.float16, 'nan')), # dtype('float16')
+ ('F', (np.complex64, '(nan+nanj)')), # dtype('complex64')
+ ('f', (np.float32, 'nan')), # dtype('float32')
+ ('h', (np.float32, 'nan')), # dtype('int16')
+ ('H', (np.float32, 'nan')), # dtype('uint16')
+ ('i', (np.float64, 'nan')), # dtype('int32')
+ ('I', (np.float64, 'nan')), # dtype('uint32')
+ ('l', (np.float64, 'nan')), # dtype('int64')
+ ('L', (np.float64, 'nan')), # dtype('uint64')
+ ('m', (np.timedelta64, 'NaT')), # dtype('<m8')
+ ('M', (np.datetime64, 'NaT')), # dtype('<M8')
+])
+def test_maybe_promote(kind, expected):
+ actual = dtypes.maybe_promote(np.dtype(kind))
+ assert actual[0] == expected[0]
+ assert str(actual[1]) == expected[1]
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
+def assert_dask_array(da, dask):
+ if dask and da.ndim > 0:
+ assert isinstance(da.data, dask_array_type)
+
+
@pytest.mark.parametrize('dim_num', [1, 2])
@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_])
@pytest.mark.parametrize('dask', [False, True])
@pytest.mark.parametrize('func', ['sum', 'min', 'max', 'mean', 'var'])
+# TODO test cumsum, cumprod
@pytest.mark.parametrize('skipna', [False, True])
@pytest.mark.parametrize('aggdim', [None, 'x'])
def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
@@ -251,6 +269,9 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
if dask and not has_dask:
pytest.skip('requires dask')
+ if dask and skipna is False and dtype in [np.bool_]:
+ pytest.skip('dask does not compute object-typed array')
+
rtol = 1e-04 if dtype == np.float32 else 1e-05
da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask)
@@ -259,6 +280,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
# TODO: remove these after resolving
# https://github.com/dask/dask/issues/3245
with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', 'Mean of empty slice')
warnings.filterwarnings('ignore', 'All-NaN slice')
warnings.filterwarnings('ignore', 'invalid value encountered in')
@@ -272,6 +294,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
expected = getattr(np, func)(da.values, axis=axis)
actual = getattr(da, func)(skipna=skipna, dim=aggdim)
+ assert_dask_array(actual, dask)
assert np.allclose(actual.values, np.array(expected),
rtol=1.0e-4, equal_nan=True)
except (TypeError, AttributeError, ZeroDivisionError):
@@ -279,14 +302,21 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
# nanmean for object dtype
pass
- # make sure the compatiblility with pandas' results.
actual = getattr(da, func)(skipna=skipna, dim=aggdim)
- if func == 'var':
+
+ # for the dask case, make sure the result matches the numpy backend
+ expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim)
+ assert_allclose(actual, expected, rtol=rtol)
+
+ # make sure of the compatibility with pandas' results.
+ if func in ['var', 'std']:
expected = series_reduce(da, func, skipna=skipna, dim=aggdim,
ddof=0)
assert_allclose(actual, expected, rtol=rtol)
# also check ddof!=0 case
actual = getattr(da, func)(skipna=skipna, dim=aggdim, ddof=5)
+ if dask:
+ assert isinstance(da.data, dask_array_type)
expected = series_reduce(da, func, skipna=skipna, dim=aggdim,
ddof=5)
assert_allclose(actual, expected, rtol=rtol)
@@ -297,11 +327,14 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):
# make sure the dtype argument
if func not in ['max', 'min']:
actual = getattr(da, func)(skipna=skipna, dim=aggdim, dtype=float)
+ assert_dask_array(actual, dask)
assert actual.dtype == float
# without nan
da = construct_dataarray(dim_num, dtype, contains_nan=False, dask=dask)
actual = getattr(da, func)(skipna=skipna)
+ if dask:
+ assert isinstance(da.data, dask_array_type)
expected = getattr(np, 'nan{}'.format(func))(da.values)
if actual.dtype == object:
assert actual.values == np.array(expected)
@@ -338,13 +371,6 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', 'All-NaN slice')
- if aggdim == 'y' and contains_nan and skipna:
- with pytest.raises(ValueError):
- actual = da.isel(**{
- aggdim: getattr(da, 'arg' + func)(
- dim=aggdim, skipna=skipna).compute()})
- return
-
actual = da.isel(**{aggdim: getattr(da, 'arg' + func)
(dim=aggdim, skipna=skipna).compute()})
expected = getattr(da, func)(dim=aggdim, skipna=skipna)
@@ -354,6 +380,7 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim):
def test_argmin_max_error():
da = construct_dataarray(2, np.bool_, contains_nan=True, dask=False)
+ da[0] = np.nan
with pytest.raises(ValueError):
da.argmin(dim='y')
@@ -388,3 +415,139 @@ def test_dask_rolling(axis, window, center):
with pytest.raises(ValueError):
rolling_window(dx, axis=axis, window=100, center=center,
fill_value=np.nan)
+
+
+@pytest.mark.skipif(not has_dask, reason='This is for dask.')
+@pytest.mark.parametrize('axis', [0, -1, 1])
+@pytest.mark.parametrize('edge_order', [1, 2])
+def test_dask_gradient(axis, edge_order):
+ import dask.array as da
+
+ array = np.array(np.random.randn(100, 5, 40))
+ x = np.exp(np.linspace(0, 1, array.shape[axis]))
+
+ darray = da.from_array(array, chunks=[(6, 30, 30, 20, 14), 5, 8])
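+ # uneven chunks along the first axis (6 + 30 + 30 + 20 + 14 == 100)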
+ expected = gradient(array, x, axis=axis, edge_order=edge_order)
+ actual = gradient(darray, x, axis=axis, edge_order=edge_order)
+
+ assert isinstance(actual, da.Array)
+ assert_array_equal(actual, expected)
+
+
+@pytest.mark.parametrize('dim_num', [1, 2])
+@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_])
+@pytest.mark.parametrize('dask', [False, True])
+@pytest.mark.parametrize('func', ['sum', 'prod'])
+@pytest.mark.parametrize('aggdim', [None, 'x'])
+def test_min_count(dim_num, dtype, dask, func, aggdim):
+ if dask and not has_dask:
+ pytest.skip('requires dask')
+
+ da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask)
+ min_count = 3
+
+ actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count)
+
+ if LooseVersion(pd.__version__) >= LooseVersion('0.22.0'):
+ # min_count is only implemented in pandas >= 0.22
+ expected = series_reduce(da, func, skipna=True, dim=aggdim,
+ min_count=min_count)
+ assert_allclose(actual, expected)
+
+ assert_dask_array(actual, dask)
+
+
+@pytest.mark.parametrize('func', ['sum', 'prod'])
+def test_min_count_dataset(func):
+ da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False)
+ ds = Dataset({'var1': da}, coords={'scalar': 0})
+ actual = getattr(ds, func)(dim='x', skipna=True, min_count=3)['var1']
+ expected = getattr(ds['var1'], func)(dim='x', skipna=True, min_count=3)
+ assert_allclose(actual, expected)
+
+
+@pytest.mark.parametrize('dtype', [float, int, np.float32, np.bool_])
+@pytest.mark.parametrize('dask', [False, True])
+@pytest.mark.parametrize('func', ['sum', 'prod'])
+def test_multiple_dims(dtype, dask, func):
+ if dask and not has_dask:
+ pytest.skip('requires dask')
+ da = construct_dataarray(3, dtype, contains_nan=True, dask=dask)
+
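+ # reducing over two dims at once should match reducing them one after another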
+ actual = getattr(da, func)(('x', 'y'))
+ expected = getattr(getattr(da, func)('x'), func)('y')
+ assert_allclose(actual, expected)
+
+
+def test_docs():
+ # with min_count
+ actual = DataArray.sum.__doc__
+ expected = dedent("""\
+ Reduce this DataArray's data by applying `sum` along some dimension(s).
+
+ Parameters
+ ----------
+ dim : str or sequence of str, optional
+ Dimension(s) over which to apply `sum`.
+ axis : int or sequence of int, optional
+ Axis(es) over which to apply `sum`. Only one of the 'dim'
+ and 'axis' arguments can be supplied. If neither are supplied, then
+ `sum` is calculated over axes.
+ skipna : bool, optional
+ If True, skip missing values (as marked by NaN). By default, only
+ skips missing values for float dtypes; other dtypes either do not
+ have a sentinel missing value (int) or skipna=True has not been
+ implemented (object, datetime64 or timedelta64).
+ min_count : int, default None
+ The required number of valid values to perform the operation.
+ If fewer than min_count non-NA values are present the result will
+ be NA. New in version 0.10.8: Added with the default being None.
+ keep_attrs : bool, optional
+ If True, the attributes (`attrs`) will be copied from the original
+ object to the new one. If False (default), the new object will be
+ returned without attributes.
+ **kwargs : dict
+ Additional keyword arguments passed on to the appropriate array
+ function for calculating `sum` on this object's data.
+
+ Returns
+ -------
+ reduced : DataArray
+ New DataArray object with `sum` applied to its data and the
+ indicated dimension(s) removed.
+ """)
+ assert actual == expected
+
+ # without min_count
+ actual = DataArray.std.__doc__
+ expected = dedent("""\
+ Reduce this DataArray's data by applying `std` along some dimension(s).
+
+ Parameters
+ ----------
+ dim : str or sequence of str, optional
+ Dimension(s) over which to apply `std`.
+ axis : int or sequence of int, optional
+ Axis(es) over which to apply `std`. Only one of the 'dim'
+ and 'axis' arguments can be supplied. If neither are supplied, then
+ `std` is calculated over axes.
+ skipna : bool, optional
+ If True, skip missing values (as marked by NaN). By default, only
+ skips missing values for float dtypes; other dtypes either do not
+ have a sentinel missing value (int) or skipna=True has not been
+ implemented (object, datetime64 or timedelta64).
+ keep_attrs : bool, optional
+ If True, the attributes (`attrs`) will be copied from the original
+ object to the new one. If False (default), the new object will be
+ returned without attributes.
+ **kwargs : dict
+ Additional keyword arguments passed on to the appropriate array
+ function for calculating `std` on this object's data.
+
+ Returns
+ -------
+ reduced : DataArray
+ New DataArray object with `std` applied to its data and the
+ indicated dimension(s) removed.
+ """)
+ assert actual == expected
diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py
index 24b710ae223..ffefa78aa34 100644
--- a/xarray/tests/test_extensions.py
+++ b/xarray/tests/test_extensions.py
@@ -4,7 +4,7 @@
import xarray as xr
-from . import TestCase, raises_regex
+from . import raises_regex
try:
import cPickle as pickle
@@ -21,7 +21,7 @@ def __init__(self, xarray_obj):
self.obj = xarray_obj
-class TestAccessor(TestCase):
+class TestAccessor(object):
def test_register(self):
@xr.register_dataset_accessor('demo')
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 8a1003f1ced..024c669bed9 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -7,10 +7,10 @@
from xarray.core import formatting
from xarray.core.pycompat import PY3
-from . import TestCase, raises_regex
+from . import raises_regex
-class TestFormatting(TestCase):
+class TestFormatting(object):
def test_get_indexer_at_least_n_items(self):
cases = [
@@ -45,7 +45,7 @@ def test_first_n_items(self):
for n in [3, 10, 13, 100, 200]:
actual = formatting.first_n_items(array, n)
expected = array.flat[:n]
- self.assertItemsEqual(expected, actual)
+ assert (expected == actual).all()
with raises_regex(ValueError, 'at least one item'):
formatting.first_n_items(array, 0)
@@ -55,7 +55,7 @@ def test_last_n_items(self):
for n in [3, 10, 13, 100, 200]:
actual = formatting.last_n_items(array, n)
expected = array.flat[-n:]
- self.assertItemsEqual(expected, actual)
+ assert (expected == actual).all()
with raises_regex(ValueError, 'at least one item'):
formatting.first_n_items(array, 0)
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 6dd14f5d6ad..8ace55be66b 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -5,9 +5,10 @@
import pytest
import xarray as xr
-from . import assert_identical
from xarray.core.groupby import _consolidate_slices
+from . import assert_identical
+
def test_consolidate_slices():
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 0d1045d35c0..701eefcb462 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -10,13 +10,12 @@
from xarray.core import indexing, nputils
from xarray.core.pycompat import native_int_types
-from . import (
- IndexerMaker, ReturnItem, TestCase, assert_array_equal, raises_regex)
+from . import IndexerMaker, ReturnItem, assert_array_equal, raises_regex
B = IndexerMaker(indexing.BasicIndexer)
-class TestIndexers(TestCase):
+class TestIndexers(object):
def set_to_zero(self, x, i):
x = x.copy()
x[i] = 0
@@ -25,7 +24,7 @@ def set_to_zero(self, x, i):
def test_expanded_indexer(self):
x = np.random.randn(10, 11, 12, 13, 14)
y = np.arange(5)
- I = ReturnItem() # noqa: E741 # allow ambiguous name
+ I = ReturnItem() # noqa
for i in [I[:], I[...], I[0, :, 10], I[..., 10], I[:5, ..., 0],
I[..., 0, :], I[y], I[y, y], I[..., y, y],
I[..., 0, 1, 2, 3, 4]]:
@@ -133,7 +132,7 @@ def test_indexer(data, x, expected_pos, expected_idx=None):
pd.MultiIndex.from_product([[1, 2], [-1, -2]]))
-class TestLazyArray(TestCase):
+class TestLazyArray(object):
def test_slice_slice(self):
I = ReturnItem() # noqa: E741 # allow ambiguous name
for size in [100, 99]:
@@ -248,7 +247,7 @@ def check_indexing(v_eager, v_lazy, indexers):
check_indexing(v_eager, v_lazy, indexers)
-class TestCopyOnWriteArray(TestCase):
+class TestCopyOnWriteArray(object):
def test_setitem(self):
original = np.arange(10)
wrapped = indexing.CopyOnWriteArray(original)
@@ -272,7 +271,7 @@ def test_index_scalar(self):
assert np.array(x[B[0]][B[()]]) == 'foo'
-class TestMemoryCachedArray(TestCase):
+class TestMemoryCachedArray(object):
def test_wrapper(self):
original = indexing.LazilyOuterIndexedArray(np.arange(10))
wrapped = indexing.MemoryCachedArray(original)
@@ -385,8 +384,9 @@ def test_vectorized_indexer():
np.arange(5, dtype=np.int64)))
-class Test_vectorized_indexer(TestCase):
- def setUp(self):
+class Test_vectorized_indexer(object):
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.data = indexing.NumpyIndexingAdapter(np.random.randn(10, 12, 13))
self.indexers = [np.array([[0, 3, 2], ]),
np.array([[0, 3, 3], [4, 6, 7]]),
diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py
index 4a8f4e6eedf..624879cce1f 100644
--- a/xarray/tests/test_interp.py
+++ b/xarray/tests/test_interp.py
@@ -5,8 +5,11 @@
import pytest
import xarray as xr
-from xarray.tests import assert_allclose, assert_equal, requires_scipy
+from xarray.tests import (
+ assert_allclose, assert_equal, requires_cftime, requires_scipy)
+
from . import has_dask, has_scipy
+from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .test_dataset import create_test_data
try:
@@ -490,3 +493,83 @@ def test_datetime_single_string():
expected = xr.DataArray(0.5)
assert_allclose(actual.drop('time'), expected)
+
+
+@requires_cftime
+@requires_scipy
+def test_cftime():
+ times = xr.cftime_range('2000', periods=24, freq='D')
+ da = xr.DataArray(np.arange(24), coords=[times], dims='time')
+
+ times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D')
+ actual = da.interp(time=times_new)
+ expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new], dims=['time'])
+
+ assert_allclose(actual, expected)
+
+
+@requires_cftime
+@requires_scipy
+def test_cftime_type_error():
+ times = xr.cftime_range('2000', periods=24, freq='D')
+ da = xr.DataArray(np.arange(24), coords=[times], dims='time')
+
+ times_new = xr.cftime_range('2000-01-01T12:00:00', periods=3, freq='D',
+ calendar='noleap')
+ with pytest.raises(TypeError):
+ da.interp(time=times_new)
+
+
+@requires_cftime
+@requires_scipy
+def test_cftime_list_of_strings():
+ from cftime import DatetimeProlepticGregorian
+
+ times = xr.cftime_range('2000', periods=24, freq='D')
+ da = xr.DataArray(np.arange(24), coords=[times], dims='time')
+
+ times_new = ['2000-01-01T12:00', '2000-01-02T12:00', '2000-01-03T12:00']
+ actual = da.interp(time=times_new)
+
+ times_new_array = _parse_array_of_cftime_strings(
+ np.array(times_new), DatetimeProlepticGregorian)
+ expected = xr.DataArray([0.5, 1.5, 2.5], coords=[times_new_array],
+ dims=['time'])
+
+ assert_allclose(actual, expected)
+
+
+@requires_cftime
+@requires_scipy
+def test_cftime_single_string():
+ from cftime import DatetimeProlepticGregorian
+
+ times = xr.cftime_range('2000', periods=24, freq='D')
+ da = xr.DataArray(np.arange(24), coords=[times], dims='time')
+
+ times_new = '2000-01-01T12:00'
+ actual = da.interp(time=times_new)
+
+ times_new_array = _parse_array_of_cftime_strings(
+ np.array(times_new), DatetimeProlepticGregorian)
+ expected = xr.DataArray(0.5, coords={'time': times_new_array})
+
+ assert_allclose(actual, expected)
+
+
+@requires_scipy
+def test_datetime_to_non_datetime_error():
+ da = xr.DataArray(np.arange(24), dims='time',
+ coords={'time': pd.date_range('2000-01-01', periods=24)})
+ with pytest.raises(TypeError):
+ da.interp(time=0.5)
+
+
+@requires_cftime
+@requires_scipy
+def test_cftime_to_non_cftime_error():
+ times = xr.cftime_range('2000', periods=24, freq='D')
+ da = xr.DataArray(np.arange(24), coords=[times], dims='time')
+
+ with pytest.raises(TypeError):
+ da.interp(time=0.5)
diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py
index 4d89be8ce55..300c490cff6 100644
--- a/xarray/tests/test_merge.py
+++ b/xarray/tests/test_merge.py
@@ -6,11 +6,11 @@
import xarray as xr
from xarray.core import merge
-from . import TestCase, raises_regex
+from . import raises_regex
from .test_dataset import create_test_data
-class TestMergeInternals(TestCase):
+class TestMergeInternals(object):
def test_broadcast_dimension_size(self):
actual = merge.broadcast_dimension_size(
[xr.Variable('x', [1]), xr.Variable('y', [2, 1])])
@@ -25,7 +25,7 @@ def test_broadcast_dimension_size(self):
[xr.Variable(('x', 'y'), [[1, 2]]), xr.Variable('y', [2])])
-class TestMergeFunction(TestCase):
+class TestMergeFunction(object):
def test_merge_arrays(self):
data = create_test_data()
actual = xr.merge([data.var1, data.var2])
@@ -130,7 +130,7 @@ def test_merge_no_conflicts_broadcast(self):
assert expected.identical(actual)
-class TestMergeMethod(TestCase):
+class TestMergeMethod(object):
def test_merge(self):
data = create_test_data()
@@ -195,7 +195,7 @@ def test_merge_compat(self):
with pytest.raises(xr.MergeError):
ds1.merge(ds2, compat='identical')
- with raises_regex(ValueError, 'compat=\S+ invalid'):
+ with raises_regex(ValueError, 'compat=.* invalid'):
ds1.merge(ds2, compat='foobar')
def test_merge_auto_align(self):
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index 5c7e384c789..47224e55473 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -93,14 +93,14 @@ def test_interpolate_pd_compat():
@requires_scipy
-def test_scipy_methods_function():
- for method in ['barycentric', 'krog', 'pchip', 'spline', 'akima']:
- kwargs = {}
- # Note: Pandas does some wacky things with these methods and the full
- # integration tests wont work.
- da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True)
- actual = da.interpolate_na(method=method, dim='time', **kwargs)
- assert (da.count('time') <= actual.count('time')).all()
+@pytest.mark.parametrize('method', ['barycentric', 'krog',
+ 'pchip', 'spline', 'akima'])
+def test_scipy_methods_function(method):
+ # Note: Pandas does some wacky things with these methods and the full
+ # integration tests won't work.
+ da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True)
+ actual = da.interpolate_na(method=method, dim='time')
+ assert (da.count('time') <= actual.count('time')).all()
@requires_scipy
diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py
index aed96f1acb6..4441375a1b1 100644
--- a/xarray/tests/test_options.py
+++ b/xarray/tests/test_options.py
@@ -4,6 +4,7 @@
import xarray
from xarray.core.options import OPTIONS
+from xarray.backends.file_manager import FILE_CACHE
def test_invalid_option_raises():
@@ -11,6 +12,38 @@ def test_invalid_option_raises():
xarray.set_options(not_a_valid_options=True)
+def test_display_width():
+ with pytest.raises(ValueError):
+ xarray.set_options(display_width=0)
+ with pytest.raises(ValueError):
+ xarray.set_options(display_width=-10)
+ with pytest.raises(ValueError):
+ xarray.set_options(display_width=3.5)
+
+
+def test_arithmetic_join():
+ with pytest.raises(ValueError):
+ xarray.set_options(arithmetic_join='invalid')
+ with xarray.set_options(arithmetic_join='exact'):
+ assert OPTIONS['arithmetic_join'] == 'exact'
+
+
+def test_enable_cftimeindex():
+ with pytest.raises(ValueError):
+ xarray.set_options(enable_cftimeindex=None)
+ with xarray.set_options(enable_cftimeindex=True):
+ assert OPTIONS['enable_cftimeindex']
+
+
+def test_file_cache_maxsize():
+ with pytest.raises(ValueError):
+ xarray.set_options(file_cache_maxsize=0)
+ original_size = FILE_CACHE.maxsize
+ with xarray.set_options(file_cache_maxsize=123):
+ assert FILE_CACHE.maxsize == 123
+ assert FILE_CACHE.maxsize == original_size
+
+
def test_nested_options():
original = OPTIONS['display_width']
with xarray.set_options(display_width=1):
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
index 51450c2838f..0af67532416 100644
--- a/xarray/tests/test_plot.py
+++ b/xarray/tests/test_plot.py
@@ -7,6 +7,7 @@
import pandas as pd
import pytest
+import xarray as xr
import xarray.plot as xplt
from xarray import DataArray
from xarray.coding.times import _import_cftime
@@ -16,9 +17,8 @@
import_seaborn, label_from_attrs)
from . import (
- TestCase, assert_array_equal, assert_equal, raises_regex,
- requires_matplotlib, requires_matplotlib2, requires_seaborn,
- requires_cftime)
+ assert_array_equal, assert_equal, raises_regex, requires_cftime,
+ requires_matplotlib, requires_matplotlib2, requires_seaborn)
# import mpl and change the backend before other mpl imports
try:
@@ -64,8 +64,10 @@ def easy_array(shape, start=0, stop=1):
@requires_matplotlib
-class PlotTestCase(TestCase):
- def tearDown(self):
+class PlotTestCase(object):
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ yield
# Remove all matplotlib figures
plt.close('all')
@@ -87,7 +89,8 @@ def contourf_called(self, plotmethod):
class TestPlot(PlotTestCase):
- def setUp(self):
+ @pytest.fixture(autouse=True)
+ def setup_array(self):
self.darray = DataArray(easy_array((2, 3, 4)))
def test_label_from_attrs(self):
@@ -159,8 +162,8 @@ def test_2d_line_accepts_legend_kw(self):
self.darray[:, :, 0].plot.line(x='dim_0', add_legend=True)
assert plt.gca().get_legend()
# check whether legend title is set
- assert plt.gca().get_legend().get_title().get_text() \
- == 'dim_1'
+ assert (plt.gca().get_legend().get_title().get_text()
+ == 'dim_1')
def test_2d_line_accepts_x_kw(self):
self.darray[:, :, 0].plot.line(x='dim_0')
@@ -171,12 +174,12 @@ def test_2d_line_accepts_x_kw(self):
def test_2d_line_accepts_hue_kw(self):
self.darray[:, :, 0].plot.line(hue='dim_0')
- assert plt.gca().get_legend().get_title().get_text() \
- == 'dim_0'
+ assert (plt.gca().get_legend().get_title().get_text()
+ == 'dim_0')
plt.cla()
self.darray[:, :, 0].plot.line(hue='dim_1')
- assert plt.gca().get_legend().get_title().get_text() \
- == 'dim_1'
+ assert (plt.gca().get_legend().get_title().get_text()
+ == 'dim_1')
def test_2d_before_squeeze(self):
a = DataArray(easy_array((1, 5)))
@@ -267,6 +270,7 @@ def test_datetime_dimension(self):
assert ax.has_data()
@pytest.mark.slow
+ @pytest.mark.filterwarnings('ignore:tight_layout cannot')
def test_convenient_facetgrid(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'])
@@ -328,6 +332,7 @@ def test_plot_size(self):
self.darray.plot(aspect=1)
@pytest.mark.slow
+ @pytest.mark.filterwarnings('ignore:tight_layout cannot')
def test_convenient_facetgrid_4d(self):
a = easy_array((10, 15, 2, 3))
d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
@@ -346,6 +351,7 @@ def test_coord_with_interval(self):
class TestPlot1D(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
d = [0, 1.1, 0, 2]
self.darray = DataArray(
@@ -358,7 +364,7 @@ def test_xlabel_is_index_name(self):
def test_no_label_name_on_x_axis(self):
self.darray.plot(y='period')
- self.assertEqual('', plt.gca().get_xlabel())
+ assert '' == plt.gca().get_xlabel()
def test_no_label_name_on_y_axis(self):
self.darray.plot()
@@ -431,6 +437,7 @@ def test_coord_with_interval_step(self):
class TestPlotHistogram(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
self.darray = DataArray(easy_array((2, 3, 4)))
@@ -470,7 +477,8 @@ def test_hist_coord_with_interval(self):
@requires_matplotlib
-class TestDetermineCmapParams(TestCase):
+class TestDetermineCmapParams(object):
+ @pytest.fixture(autouse=True)
def setUp(self):
self.data = np.linspace(0, 1, num=100)
@@ -491,6 +499,21 @@ def test_center(self):
assert cmap_params['levels'] is None
assert cmap_params['norm'] is None
+ def test_cmap_sequential_option(self):
+ with xr.set_options(cmap_sequential='magma'):
+ cmap_params = _determine_cmap_params(self.data)
+ assert cmap_params['cmap'] == 'magma'
+
+ def test_cmap_sequential_explicit_option(self):
+ with xr.set_options(cmap_sequential=mpl.cm.magma):
+ cmap_params = _determine_cmap_params(self.data)
+ assert cmap_params['cmap'] == mpl.cm.magma
+
+ def test_cmap_divergent_option(self):
+ with xr.set_options(cmap_divergent='magma'):
+ cmap_params = _determine_cmap_params(self.data, center=0.5)
+ assert cmap_params['cmap'] == 'magma'
+
def test_nan_inf_are_ignored(self):
cmap_params1 = _determine_cmap_params(self.data)
data = self.data
@@ -626,9 +649,30 @@ def test_divergentcontrol(self):
assert cmap_params['vmax'] == 0.6
assert cmap_params['cmap'] == "viridis"
+ def test_norm_sets_vmin_vmax(self):
+ vmin = self.data.min()
+ vmax = self.data.max()
+
+ for norm, extend in zip([mpl.colors.LogNorm(),
+ mpl.colors.LogNorm(vmin + 1, vmax - 1),
+ mpl.colors.LogNorm(None, vmax - 1),
+ mpl.colors.LogNorm(vmin + 1, None)],
+ ['neither', 'both', 'max', 'min']):
+
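+ # the expected 'extend' reflects which ends of the data range the norm clips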
+ test_min = vmin if norm.vmin is None else norm.vmin
+ test_max = vmax if norm.vmax is None else norm.vmax
+
+ cmap_params = _determine_cmap_params(self.data, norm=norm)
+
+ assert cmap_params['vmin'] == test_min
+ assert cmap_params['vmax'] == test_max
+ assert cmap_params['extend'] == extend
+ assert cmap_params['norm'] == norm
+
@requires_matplotlib
-class TestDiscreteColorMap(TestCase):
+class TestDiscreteColorMap(object):
+ @pytest.fixture(autouse=True)
def setUp(self):
x = np.arange(start=0, stop=10, step=2)
y = np.arange(start=9, stop=-7, step=-3)
@@ -662,10 +706,10 @@ def test_build_discrete_cmap(self):
@pytest.mark.slow
def test_discrete_colormap_list_of_levels(self):
- for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both',
- [2, 5, 10, 11]),
- ('neither', [0, 5, 10, 15]), ('min',
- [2, 5, 10, 15])]:
+ for extend, levels in [('max', [-1, 2, 4, 8, 10]),
+ ('both', [2, 5, 10, 11]),
+ ('neither', [0, 5, 10, 15]),
+ ('min', [2, 5, 10, 15])]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(levels=levels)
assert_array_equal(levels, primitive.norm.boundaries)
@@ -679,10 +723,10 @@ def test_discrete_colormap_list_of_levels(self):
@pytest.mark.slow
def test_discrete_colormap_int_levels(self):
- for extend, levels, vmin, vmax in [('neither', 7, None,
- None), ('neither', 7, None, 20),
- ('both', 7, 4, 8), ('min', 10, 4,
- 15)]:
+ for extend, levels, vmin, vmax in [('neither', 7, None, None),
+ ('neither', 7, None, 20),
+ ('both', 7, 4, 8),
+ ('min', 10, 4, 15)]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(
levels=levels, vmin=vmin, vmax=vmax)
@@ -708,8 +752,13 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self):
assert primitive.norm.vmax == max(levels)
assert primitive.norm.vmin == min(levels)
+ def test_discrete_colormap_provided_boundary_norm(self):
+ norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
+ primitive = self.darray.plot.contourf(norm=norm)
+ np.testing.assert_allclose(primitive.levels, norm.boundaries)
+
-class Common2dMixin:
+class Common2dMixin(object):
"""
Common tests for 2d plotting go here.
@@ -717,6 +766,7 @@ class Common2dMixin:
Should have the same name as the method.
"""
+ @pytest.fixture(autouse=True)
def setUp(self):
da = DataArray(easy_array((10, 15), start=-1),
dims=['y', 'x'],
@@ -765,6 +815,24 @@ def test_nonnumeric_index_raises_typeerror(self):
def test_can_pass_in_axis(self):
self.pass_in_axis(self.plotmethod)
+ def test_xyincrease_defaults(self):
+
+ # With default settings the axis must be ordered regardless
+ # of the coords order.
+ self.plotfunc(DataArray(easy_array((3, 2)), coords=[[1, 2, 3],
+ [1, 2]]))
+ bounds = plt.gca().get_ylim()
+ assert bounds[0] < bounds[1]
+ bounds = plt.gca().get_xlim()
+ assert bounds[0] < bounds[1]
+ # Inverted coords
+ self.plotfunc(DataArray(easy_array((3, 2)), coords=[[3, 2, 1],
+ [2, 1]]))
+ bounds = plt.gca().get_ylim()
+ assert bounds[0] < bounds[1]
+ bounds = plt.gca().get_xlim()
+ assert bounds[0] < bounds[1]
+
def test_xyincrease_false_changes_axes(self):
self.plotmethod(xincrease=False, yincrease=False)
xlim = plt.gca().get_xlim()
@@ -796,10 +864,13 @@ def test_plot_nans(self):
clim2 = self.plotfunc(x2).get_clim()
assert clim1 == clim2
+ @pytest.mark.filterwarnings('ignore::UserWarning')
+ @pytest.mark.filterwarnings('ignore:invalid value encountered')
def test_can_plot_all_nans(self):
# regression test for issue #1780
self.plotfunc(DataArray(np.full((2, 2), np.nan)))
+ @pytest.mark.filterwarnings('ignore: Attempting to set')
def test_can_plot_axis_size_one(self):
if self.plotfunc.__name__ not in ('contour', 'contourf'):
self.plotfunc(DataArray(np.ones((1, 1))))
@@ -991,6 +1062,7 @@ def test_2d_function_and_method_signature_same(self):
del func_sig['darray']
assert func_sig == method_sig
+ @pytest.mark.filterwarnings('ignore:tight_layout cannot')
def test_convenient_facetgrid(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'])
@@ -1022,6 +1094,7 @@ def test_convenient_facetgrid(self):
else:
assert '' == ax.get_xlabel()
+ @pytest.mark.filterwarnings('ignore:tight_layout cannot')
def test_convenient_facetgrid_4d(self):
a = easy_array((10, 15, 2, 3))
d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
@@ -1031,6 +1104,19 @@ def test_convenient_facetgrid_4d(self):
for ax in g.axes.flat:
assert ax.has_data()
+ @pytest.mark.filterwarnings('ignore:This figure includes')
+ def test_facetgrid_map_only_appends_mappables(self):
+ a = easy_array((10, 15, 2, 3))
+ d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
+ g = self.plotfunc(d, x='x', y='y', col='columns', row='rows')
+
+ expected = g._mappables
+
+ g.map(lambda: plt.plot(1, 1))
+ actual = g._mappables
+
+ assert expected == actual
+
def test_facetgrid_cmap(self):
# Regression test for GH592
data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12))
@@ -1051,6 +1137,15 @@ def test_2d_coord_with_interval(self):
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
getattr(gp.plot, kind)()
+ def test_colormap_error_norm_and_vmin_vmax(self):
+ norm = mpl.colors.LogNorm(0.1, 1e1)
+
+ with pytest.raises(ValueError):
+ self.darray.plot(norm=norm, vmin=2)
+
+ with pytest.raises(ValueError):
+ self.darray.plot(norm=norm, vmax=2)
+
@pytest.mark.slow
class TestContourf(Common2dMixin, PlotTestCase):
@@ -1113,23 +1208,23 @@ def test_colors(self):
def _color_as_tuple(c):
return tuple(c[:3])
+ # with a single color, we don't want an rgb array
artist = self.plotmethod(colors='k')
- assert _color_as_tuple(artist.cmap.colors[0]) == \
- (0.0, 0.0, 0.0)
+ assert artist.cmap.colors[0] == 'k'
artist = self.plotmethod(colors=['k', 'b'])
- assert _color_as_tuple(artist.cmap.colors[1]) == \
- (0.0, 0.0, 1.0)
+ assert (_color_as_tuple(artist.cmap.colors[1]) ==
+ (0.0, 0.0, 1.0))
artist = self.darray.plot.contour(
levels=[-0.5, 0., 0.5, 1.], colors=['k', 'r', 'w', 'b'])
- assert _color_as_tuple(artist.cmap.colors[1]) == \
- (1.0, 0.0, 0.0)
- assert _color_as_tuple(artist.cmap.colors[2]) == \
- (1.0, 1.0, 1.0)
+ assert (_color_as_tuple(artist.cmap.colors[1]) ==
+ (1.0, 0.0, 0.0))
+ assert (_color_as_tuple(artist.cmap.colors[2]) ==
+ (1.0, 1.0, 1.0))
# the last color is now under "over"
- assert _color_as_tuple(artist.cmap._rgba_over) == \
- (0.0, 0.0, 1.0)
+ assert (_color_as_tuple(artist.cmap._rgba_over) ==
+ (0.0, 0.0, 1.0))
def test_cmap_and_color_both(self):
with pytest.raises(ValueError):
@@ -1306,13 +1401,26 @@ def test_imshow_rgb_values_in_valid_range(self):
assert out.dtype == np.uint8
assert (out[..., :3] == da.values).all() # Compare without added alpha
+ @pytest.mark.filterwarnings('ignore:Several dimensions of this array')
def test_regression_rgb_imshow_dim_size_one(self):
# Regression: https://github.com/pydata/xarray/issues/1966
da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0))
da.plot.imshow()
+ def test_origin_overrides_xyincrease(self):
+ da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]])
+ da.plot.imshow(origin='upper')
+ assert plt.xlim()[0] < 0
+ assert plt.ylim()[1] < 0
+
+ plt.clf()
+ da.plot.imshow(origin='lower')
+ assert plt.xlim()[0] < 0
+ assert plt.ylim()[0] < 0
+
class TestFacetGrid(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
d = easy_array((10, 15, 3))
self.darray = DataArray(
@@ -1484,7 +1592,9 @@ def test_num_ticks(self):
@pytest.mark.slow
def test_map(self):
+ assert self.g._finalized is False
self.g.map(plt.contourf, 'x', 'y', Ellipsis)
+ assert self.g._finalized is True
self.g.map(lambda: None)
@pytest.mark.slow
@@ -1538,7 +1648,9 @@ def test_facetgrid_polar(self):
sharey=False)
+@pytest.mark.filterwarnings('ignore:tight_layout cannot')
class TestFacetGrid4d(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
a = easy_array((10, 15, 3, 2))
darray = DataArray(a, dims=['y', 'x', 'col', 'row'])
@@ -1565,7 +1677,9 @@ def test_default_labels(self):
assert substring_in_axes(label, ax)
+@pytest.mark.filterwarnings('ignore:tight_layout cannot')
class TestFacetedLinePlots(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
self.darray = DataArray(np.random.randn(10, 6, 3, 4),
dims=['hue', 'x', 'col', 'row'],
@@ -1646,6 +1760,7 @@ def test_wrong_num_of_dimensions(self):
class TestDatetimePlot(PlotTestCase):
+ @pytest.fixture(autouse=True)
def setUp(self):
'''
Create a DataArray with a time-axis that contains datetime objects.
diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py
index d550a85e8ce..083ec5ee72f 100644
--- a/xarray/tests/test_tutorial.py
+++ b/xarray/tests/test_tutorial.py
@@ -2,15 +2,17 @@
import os
+import pytest
+
from xarray import DataArray, tutorial
from xarray.core.pycompat import suppress
-from . import TestCase, assert_identical, network
+from . import assert_identical, network
@network
-class TestLoadDataset(TestCase):
-
+class TestLoadDataset(object):
+ @pytest.fixture(autouse=True)
def setUp(self):
self.testfile = 'tiny'
self.testfilepath = os.path.expanduser(os.sep.join(
diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py
index 195bb36e36e..6941efb1c6e 100644
--- a/xarray/tests/test_ufuncs.py
+++ b/xarray/tests/test_ufuncs.py
@@ -8,9 +8,9 @@
import xarray as xr
import xarray.ufuncs as xu
-from . import (
- assert_array_equal, assert_identical as assert_identical_, mock,
- raises_regex, requires_np113)
+from . import assert_array_equal
+from . import assert_identical as assert_identical_
+from . import mock, raises_regex, requires_np113
def assert_identical(a, b):
diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py
index ed8045b78e4..34f401dd243 100644
--- a/xarray/tests/test_utils.py
+++ b/xarray/tests/test_utils.py
@@ -6,19 +6,21 @@
import pandas as pd
import pytest
+import xarray as xr
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core import duck_array_ops, utils
from xarray.core.options import set_options
from xarray.core.pycompat import OrderedDict
from xarray.core.utils import either_dict_or_kwargs
+from xarray.testing import assert_identical
from . import (
- TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4,
+ assert_array_equal, has_cftime, has_cftime_or_netCDF4, requires_cftime,
requires_dask)
from .test_coding_times import _all_cftime_date_types
-class TestAlias(TestCase):
+class TestAlias(object):
def test(self):
def new_method():
pass
@@ -96,7 +98,7 @@ def test_multiindex_from_product_levels_non_unique():
np.testing.assert_array_equal(result.levels[1], [1, 2])
-class TestArrayEquiv(TestCase):
+class TestArrayEquiv(object):
def test_0d(self):
# verify our work around for pd.isnull not working for 0-dimensional
# object arrays
@@ -106,8 +108,9 @@ def test_0d(self):
assert not duck_array_ops.array_equiv(0, np.array(1, dtype=object))
-class TestDictionaries(TestCase):
- def setUp(self):
+class TestDictionaries(object):
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.x = {'a': 'A', 'b': 'B'}
self.y = {'c': 'C', 'b': 'B'}
self.z = {'a': 'Z'}
@@ -174,7 +177,7 @@ def test_frozen(self):
def test_sorted_keys_dict(self):
x = {'a': 1, 'b': 2, 'c': 3}
y = utils.SortedKeysDict(x)
- self.assertItemsEqual(y, ['a', 'b', 'c'])
+ assert list(y) == ['a', 'b', 'c']
assert repr(utils.SortedKeysDict()) == \
"SortedKeysDict({})"
@@ -189,7 +192,7 @@ def test_chain_map(self):
m['x'] = 100
assert m['x'] == 100
assert m.maps[0]['x'] == 100
- self.assertItemsEqual(['x', 'y', 'z'], m)
+ assert set(m) == {'x', 'y', 'z'}
def test_repr_object():
@@ -197,7 +200,7 @@ def test_repr_object():
assert repr(obj) == 'foo'
-class Test_is_uniform_and_sorted(TestCase):
+class Test_is_uniform_and_sorted(object):
def test_sorted_uniform(self):
assert utils.is_uniform_spaced(np.arange(5))
@@ -218,7 +221,7 @@ def test_relative_tolerance(self):
assert utils.is_uniform_spaced([0, 0.97, 2], rtol=0.1)
-class Test_hashable(TestCase):
+class Test_hashable(object):
def test_hashable(self):
for v in [False, 1, (2, ), (3, 4), 'four']:
@@ -263,3 +266,42 @@ def test_either_dict_or_kwargs():
with pytest.raises(ValueError, match=r'foo'):
result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo')
+
+
+def test_datetime_to_numeric_datetime64():
+ times = pd.date_range('2000', periods=5, freq='7D')
+ da = xr.DataArray(times, coords=[times], dims=['time'])
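+ # weekly spacing: offsets from the first time are multiples of 7 days (168 hours)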
+ result = utils.datetime_to_numeric(da, datetime_unit='h')
+ expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords)
+ assert_identical(result, expected)
+
+ offset = da.isel(time=1)
+ result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h')
+ expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords)
+ assert_identical(result, expected)
+
+ dtype = np.float32
+ result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype)
+ expected = 24 * xr.DataArray(
+ np.arange(0, 35, 7), coords=da.coords).astype(dtype)
+ assert_identical(result, expected)
+
+
+@requires_cftime
+def test_datetime_to_numeric_cftime():
+ times = xr.cftime_range('2000', periods=5, freq='7D')
+ da = xr.DataArray(times, coords=[times], dims=['time'])
+ result = utils.datetime_to_numeric(da, datetime_unit='h')
+ expected = 24 * xr.DataArray(np.arange(0, 35, 7), coords=da.coords)
+ assert_identical(result, expected)
+
+ offset = da.isel(time=1)
+ result = utils.datetime_to_numeric(da, offset=offset, datetime_unit='h')
+ expected = 24 * xr.DataArray(np.arange(-7, 28, 7), coords=da.coords)
+ assert_identical(result, expected)
+
+ dtype = np.float32
+ result = utils.datetime_to_numeric(da, datetime_unit='h', dtype=dtype)
+ expected = 24 * xr.DataArray(
+ np.arange(0, 35, 7), coords=da.coords).astype(dtype)
+ assert_identical(result, expected)
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index cdb578aff6c..52289a15d72 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -1,12 +1,11 @@
from __future__ import absolute_import, division, print_function
-from collections import namedtuple
+import warnings
from copy import copy, deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from textwrap import dedent
-import warnings
import numpy as np
import pandas as pd
@@ -26,11 +25,11 @@
from xarray.tests import requires_bottleneck
from . import (
- TestCase, assert_allclose, assert_array_equal, assert_equal,
- assert_identical, raises_regex, requires_dask, source_ndarray)
+ assert_allclose, assert_array_equal, assert_equal, assert_identical,
+ raises_regex, requires_dask, source_ndarray)
-class VariableSubclassTestCases(object):
+class VariableSubclassobjects(object):
def test_properties(self):
data = 0.5 * np.arange(10)
v = self.cls(['time'], data, {'foo': 'bar'})
@@ -480,20 +479,20 @@ def test_concat_mixed_dtypes(self):
assert_identical(expected, actual)
assert actual.dtype == object
- def test_copy(self):
+ @pytest.mark.parametrize('deep', [True, False])
+ def test_copy(self, deep):
v = self.cls('x', 0.5 * np.arange(10), {'foo': 'bar'})
- for deep in [True, False]:
- w = v.copy(deep=deep)
- assert type(v) is type(w)
- assert_identical(v, w)
- assert v.dtype == w.dtype
- if self.cls is Variable:
- if deep:
- assert source_ndarray(v.values) is not \
- source_ndarray(w.values)
- else:
- assert source_ndarray(v.values) is \
- source_ndarray(w.values)
+ w = v.copy(deep=deep)
+ assert type(v) is type(w)
+ assert_identical(v, w)
+ assert v.dtype == w.dtype
+ if self.cls is Variable:
+ if deep:
+ assert (source_ndarray(v.values) is not
+ source_ndarray(w.values))
+ else:
+ assert (source_ndarray(v.values) is
+ source_ndarray(w.values))
assert_identical(v, copy(v))
def test_copy_index(self):
@@ -506,6 +505,34 @@ def test_copy_index(self):
assert isinstance(w.to_index(), pd.MultiIndex)
assert_array_equal(v._data.array, w._data.array)
+ def test_copy_with_data(self):
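+ # copy(data=...) should replace the underlying values while keeping dims and attrs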
+ orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'})
+ new_data = np.array([[2.5, 5.0], [7.1, 43]])
+ actual = orig.copy(data=new_data)
+ expected = orig.copy()
+ expected.data = new_data
+ assert_identical(expected, actual)
+
+ def test_copy_with_data_errors(self):
+ orig = Variable(('x', 'y'), [[1.5, 2.0], [3.1, 4.3]], {'foo': 'bar'})
+ new_data = [2.5, 5.0]
+ with raises_regex(ValueError, 'must match shape of object'):
+ orig.copy(data=new_data)
+
+ def test_copy_index_with_data(self):
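+ # the same data-replacing copy, but for an IndexVariable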
+ orig = IndexVariable('x', np.arange(5))
+ new_data = np.arange(5, 10)
+ actual = orig.copy(data=new_data)
+ expected = orig.copy()
+ expected.data = new_data
+ assert_identical(expected, actual)
+
+ def test_copy_index_with_data_errors(self):
+ orig = IndexVariable('x', np.arange(5))
+ new_data = np.arange(5, 20)
+ with raises_regex(ValueError, 'must match shape of object'):
+ orig.copy(data=new_data)
+
def test_real_and_imag(self):
v = self.cls('x', np.arange(3) - 1j * np.arange(3), {'foo': 'bar'})
expected_re = self.cls('x', np.arange(3), {'foo': 'bar'})
@@ -787,10 +814,11 @@ def test_rolling_window(self):
v_loaded[0] = 1.0
-class TestVariable(TestCase, VariableSubclassTestCases):
+class TestVariable(VariableSubclassobjects):
cls = staticmethod(Variable)
- def setUp(self):
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.d = np.random.random((10, 3)).astype(np.float64)
def test_data_and_values(self):
@@ -938,21 +966,6 @@ def test_as_variable(self):
assert not isinstance(ds['x'], Variable)
assert isinstance(as_variable(ds['x']), Variable)
- FakeVariable = namedtuple('FakeVariable', 'values dims')
- fake_xarray = FakeVariable(expected.values, expected.dims)
- assert_identical(expected, as_variable(fake_xarray))
-
- FakeVariable = namedtuple('FakeVariable', 'data dims')
- fake_xarray = FakeVariable(expected.data, expected.dims)
- assert_identical(expected, as_variable(fake_xarray))
-
- FakeVariable = namedtuple('FakeVariable',
- 'data values dims attrs encoding')
- fake_xarray = FakeVariable(expected_extra.data, expected_extra.values,
- expected_extra.dims, expected_extra.attrs,
- expected_extra.encoding)
- assert_identical(expected_extra, as_variable(fake_xarray))
-
xarray_tuple = (expected_extra.dims, expected_extra.values,
expected_extra.attrs, expected_extra.encoding)
assert_identical(expected_extra, as_variable(xarray_tuple))
@@ -1503,8 +1516,8 @@ def test_reduce_funcs(self):
assert_identical(v.all(dim='x'), Variable([], False))
v = Variable('t', pd.date_range('2000-01-01', periods=3))
- with pytest.raises(NotImplementedError):
- v.argmax(skipna=True)
+ assert v.argmax(skipna=True) == 2
+
assert_identical(
v.max(), Variable([], pd.Timestamp('2000-01-03')))
@@ -1639,7 +1652,7 @@ def assert_assigned_2d(array, key_x, key_y, values):
@requires_dask
-class TestVariableWithDask(TestCase, VariableSubclassTestCases):
+class TestVariableWithDask(VariableSubclassobjects):
cls = staticmethod(lambda *args: Variable(*args).chunk())
@pytest.mark.xfail
@@ -1667,7 +1680,7 @@ def test_getitem_1d_fancy(self):
def test_equals_all_dtypes(self):
import dask
- if '0.18.2' <= LooseVersion(dask.__version__) < '0.18.3':
+ if '0.18.2' <= LooseVersion(dask.__version__) < '0.19.1':
pytest.xfail('https://github.com/pydata/xarray/issues/2318')
super(TestVariableWithDask, self).test_equals_all_dtypes()
@@ -1679,7 +1692,7 @@ def test_getitem_with_mask_nd_indexer(self):
self.cls(('x', 'y'), [[0, -1], [-1, 2]]))
-class TestIndexVariable(TestCase, VariableSubclassTestCases):
+class TestIndexVariable(VariableSubclassobjects):
cls = staticmethod(IndexVariable)
def test_init(self):
@@ -1792,7 +1805,7 @@ def test_rolling_window(self):
super(TestIndexVariable, self).test_rolling_window()
-class TestAsCompatibleData(TestCase):
+class TestAsCompatibleData(object):
def test_unchanged_types(self):
types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray)
for t in types:
@@ -1933,9 +1946,10 @@ def test_raise_no_warning_for_nan_in_binary_ops():
assert len(record) == 0
-class TestBackendIndexing(TestCase):
+class TestBackendIndexing(object):
""" Make sure all the array wrappers can be indexed. """
+ @pytest.fixture(autouse=True)
def setUp(self):
self.d = np.random.random((10, 3)).astype(np.float64)