diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index fe6dead5434..b7d9ac69ea5 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -20,6 +20,7 @@ dependencies: - numba - numpy>=1.17 - pandas>=1.0 + - pooch - pip - pydata-sphinx-theme>=0.4.3 - rasterio>=1.1 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 57498fa5700..308dd02080f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -27,6 +27,7 @@ dependencies: - pandas - pint - pip=20.2 + - pooch - pre-commit - pseudonetcdf - pydap diff --git a/doc/examples/visualization_gallery.ipynb b/doc/examples/visualization_gallery.ipynb index f8d5b1ae458..249c1b7ee94 100644 --- a/doc/examples/visualization_gallery.ipynb +++ b/doc/examples/visualization_gallery.ipynb @@ -209,8 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n", - "da = xr.open_rasterio(url)\n", + "da = xr.tutorial.open_dataset(\"RGB.byte\").data\n", "\n", "# The data is in UTM projection. We have to set it manually until\n", "# https://github.com/SciTools/cartopy/issues/813 is implemented\n", @@ -246,8 +245,7 @@ "from rasterio.warp import transform\n", "import numpy as np\n", "\n", - "url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n", - "da = xr.open_rasterio(url)\n", + "da = xr.tutorial.open_dataset(\"RGB.byte\").data\n", "\n", "# Compute the lon/lat coordinates with rasterio.warp.transform\n", "ny, nx = len(da['y']), len(da['x'])\n", diff --git a/setup.cfg b/setup.cfg index 80785d6f108..f6565b45e31 100644 --- a/setup.cfg +++ b/setup.cfg @@ -208,6 +208,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-pint.*] ignore_missing_imports = True +[mypy-pooch.*] +ignore_missing_imports = True [mypy-PseudoNetCDF.*] ignore_missing_imports = True [mypy-pydap.*] @@ -233,6 +235,7 @@ ignore_missing_imports = True [mypy-xarray.core.pycompat] ignore_errors = True + [aliases] test = pytest diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index a2eb159f624..9b9dfe83867 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -1,5 +1,4 @@ import os -from contextlib import suppress import pytest @@ -13,20 +12,26 @@ class TestLoadDataset: @pytest.fixture(autouse=True) def setUp(self): self.testfile = "tiny" - self.testfilepath = os.path.expanduser( - os.sep.join(("~", ".xarray_tutorial_data", self.testfile)) - ) - with suppress(OSError): - os.remove(f"{self.testfilepath}.nc") - with suppress(OSError): - os.remove(f"{self.testfilepath}.md5") - - def test_download_from_github(self): + + def test_download_from_github(self, tmp_path, monkeypatch): + monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path)) + ds = tutorial.open_dataset(self.testfile).load() tiny = DataArray(range(5), name="tiny").to_dataset() assert_identical(ds, tiny) - def test_download_from_github_load_without_cache(self): + def test_download_from_github_load_without_cache(self, tmp_path, monkeypatch): + monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path)) + ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load() ds_cache = tutorial.open_dataset(self.testfile).load() assert_identical(ds_cache, ds_nocache) + + def test_download_rasterio_from_github_load_without_cache( + self, tmp_path, monkeypatch + ): + monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path)) + + ds_nocache = tutorial.open_dataset("RGB.byte", cache=False).load() + ds_cache = tutorial.open_dataset("RGB.byte", cache=True).load() + assert_identical(ds_cache, ds_nocache) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 055be36d80b..351113c31c0 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -5,33 +5,45 @@ * building tutorials in the documentation. """ -import hashlib -import os as _os -from urllib.request import urlretrieve +import os +import pathlib import numpy as np from .backends.api import open_dataset as _open_dataset +from .backends.rasterio_ import open_rasterio from .core.dataarray import DataArray from .core.dataset import Dataset -_default_cache_dir = _os.sep.join(("~", ".xarray_tutorial_data")) +def _open_rasterio(path, engine=None, **kwargs): + data = open_rasterio(path, **kwargs) + name = data.name if data.name is not None else "data" + return data.to_dataset(name=name) -def file_md5_checksum(fname): - hash_md5 = hashlib.md5() - with open(fname, "rb") as f: - hash_md5.update(f.read()) - return hash_md5.hexdigest() + +_default_cache_dir_name = "xarray_tutorial_data" +base_url = "https://github.com/pydata/xarray-data" +version = "master" + + +external_urls = { + "RGB.byte": ( + "rasterio", + "https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif", + ), +} +overrides = { + "rasterio": _open_rasterio, +} # idea borrowed from Seaborn def open_dataset( name, + engine=None, cache=True, - cache_dir=_default_cache_dir, - github_url="https://github.com/pydata/xarray-data", - branch="master", + cache_dir=None, **kws, ): """ @@ -42,61 +54,62 @@ def open_dataset( Parameters ---------- name : str - Name of the file containing the dataset. If no suffix is given, assumed - to be netCDF ('.nc' is appended) + Name of the file containing the dataset. e.g. 'air_temperature' - cache_dir : str, optional + engine : str, optional + The engine to use. + cache_dir : path-like, optional The directory in which to search for and write cached data. cache : bool, optional If True, then cache data locally for use on subsequent calls - github_url : str - Github repository where the data is stored - branch : str - The git branch to download from kws : dict, optional Passed to xarray.open_dataset + Notes + ----- + Available datasets: + + * ``"air_temperature"`` + * ``"rasm"`` + * ``"ROMS_example"`` + * ``"tiny"`` + * ``"era5-2mt-2019-03-uk.grib"`` + * ``"RGB.byte"``: example rasterio file from https://github.com/mapbox/rasterio + See Also -------- xarray.open_dataset - """ - root, ext = _os.path.splitext(name) - if not ext: - ext = ".nc" - fullname = root + ext - longdir = _os.path.expanduser(cache_dir) - localfile = _os.sep.join((longdir, fullname)) - md5name = fullname + ".md5" - md5file = _os.sep.join((longdir, md5name)) - - if not _os.path.exists(localfile): - - # This will always leave this directory on disk. - # May want to add an option to remove it. - if not _os.path.isdir(longdir): - _os.mkdir(longdir) - - url = "/".join((github_url, "raw", branch, fullname)) - urlretrieve(url, localfile) - url = "/".join((github_url, "raw", branch, md5name)) - urlretrieve(url, md5file) - - localmd5 = file_md5_checksum(localfile) - with open(md5file) as f: - remotemd5 = f.read() - if localmd5 != remotemd5: - _os.remove(localfile) - msg = """ - MD5 checksum does not match, try downloading dataset again. - """ - raise OSError(msg) - - ds = _open_dataset(localfile, **kws) - + try: + import pooch + except ImportError: + raise ImportError("using the tutorial data requires pooch") + + if isinstance(cache_dir, pathlib.Path): + cache_dir = os.fspath(cache_dir) + elif cache_dir is None: + cache_dir = pooch.os_cache(_default_cache_dir_name) + + if name in external_urls: + engine_, url = external_urls[name] + if engine is None: + engine = engine_ + else: + # process the name + default_extension = ".nc" + path = pathlib.Path(name) + if not path.suffix: + path = path.with_suffix(default_extension) + + url = f"{base_url}/raw/{version}/{path.name}" + + _open = overrides.get(engine, _open_dataset) + # retrieve the file + filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir) + ds = _open(filepath, engine=engine, **kws) if not cache: ds = ds.load() - _os.remove(localfile) + pathlib.Path(filepath).unlink() return ds