Skip to content

Commit

Permalink
cache rasterio example files (pydata#4102)
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis authored Mar 24, 2021
1 parent ec4e8b5 commit 8452120
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 69 deletions.
1 change: 1 addition & 0 deletions ci/requirements/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies:
- numba
- numpy>=1.17
- pandas>=1.0
- pooch
- pip
- pydata-sphinx-theme>=0.4.3
- rasterio>=1.1
Expand Down
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
- pandas
- pint
- pip=20.2
- pooch
- pre-commit
- pseudonetcdf
- pydap
Expand Down
6 changes: 2 additions & 4 deletions doc/examples/visualization_gallery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
"url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n",
"da = xr.open_rasterio(url)\n",
"da = xr.tutorial.open_dataset(\"RGB.byte\").data\n",
"\n",
"# The data is in UTM projection. We have to set it manually until\n",
"# https://github.com/SciTools/cartopy/issues/813 is implemented\n",
Expand Down Expand Up @@ -246,8 +245,7 @@
"from rasterio.warp import transform\n",
"import numpy as np\n",
"\n",
"url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n",
"da = xr.open_rasterio(url)\n",
"da = xr.tutorial.open_dataset(\"RGB.byte\").data\n",
"\n",
"# Compute the lon/lat coordinates with rasterio.warp.transform\n",
"ny, nx = len(da['y']), len(da['x'])\n",
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-pint.*]
ignore_missing_imports = True
[mypy-pooch.*]
ignore_missing_imports = True
[mypy-PseudoNetCDF.*]
ignore_missing_imports = True
[mypy-pydap.*]
Expand All @@ -233,6 +235,7 @@ ignore_missing_imports = True
[mypy-xarray.core.pycompat]
ignore_errors = True


[aliases]
test = pytest

Expand Down
27 changes: 16 additions & 11 deletions xarray/tests/test_tutorial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from contextlib import suppress

import pytest

Expand All @@ -13,20 +12,26 @@ class TestLoadDataset:
@pytest.fixture(autouse=True)
def setUp(self):
self.testfile = "tiny"
self.testfilepath = os.path.expanduser(
os.sep.join(("~", ".xarray_tutorial_data", self.testfile))
)
with suppress(OSError):
os.remove(f"{self.testfilepath}.nc")
with suppress(OSError):
os.remove(f"{self.testfilepath}.md5")

def test_download_from_github(self):

    def test_download_from_github(self, tmp_path, monkeypatch):
        # Point the pooch download cache at a per-test temporary directory so
        # the test never reuses (or pollutes) the user's real cache.
        # NOTE(review): pooch honors XDG_CACHE_HOME (not XDG_CACHE_DIR) on
        # Linux -- confirm this env var actually redirects the cache here.
        monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))

        # Download the "tiny" dataset and check it round-trips to the known
        # fixture contents (a 5-element range named "tiny").
        ds = tutorial.open_dataset(self.testfile).load()
        tiny = DataArray(range(5), name="tiny").to_dataset()
        assert_identical(ds, tiny)

def test_download_from_github_load_without_cache(self):
    def test_download_from_github_load_without_cache(self, tmp_path, monkeypatch):
        # Isolate the download cache in a temporary directory for this test.
        # NOTE(review): pooch honors XDG_CACHE_HOME (not XDG_CACHE_DIR) on
        # Linux -- confirm this env var actually redirects the cache here.
        monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))

        # The dataset must be identical whether or not the local cache is
        # kept (cache=False deletes the downloaded file after loading).
        ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load()
        ds_cache = tutorial.open_dataset(self.testfile).load()
        assert_identical(ds_cache, ds_nocache)

    def test_download_rasterio_from_github_load_without_cache(
        self, tmp_path, monkeypatch
    ):
        # Isolate the download cache in a temporary directory for this test.
        # NOTE(review): pooch honors XDG_CACHE_HOME (not XDG_CACHE_DIR) on
        # Linux -- confirm this env var actually redirects the cache here.
        monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))

        # "RGB.byte" is an externally hosted rasterio example file; verify the
        # cached and uncached code paths produce identical datasets.
        ds_nocache = tutorial.open_dataset("RGB.byte", cache=False).load()
        ds_cache = tutorial.open_dataset("RGB.byte", cache=True).load()
        assert_identical(ds_cache, ds_nocache)
121 changes: 67 additions & 54 deletions xarray/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,45 @@
* building tutorials in the documentation.
"""
import hashlib
import os as _os
from urllib.request import urlretrieve
import os
import pathlib

import numpy as np

from .backends.api import open_dataset as _open_dataset
from .backends.rasterio_ import open_rasterio
from .core.dataarray import DataArray
from .core.dataset import Dataset

_default_cache_dir = _os.sep.join(("~", ".xarray_tutorial_data"))

def _open_rasterio(path, engine=None, **kwargs):
    """Open a rasterio file and wrap the resulting DataArray in a Dataset.

    The ``engine`` parameter is accepted (so this helper matches the
    ``_open_dataset`` call signature) but ignored. Unnamed arrays are stored
    under the variable name ``"data"``.
    """
    array = open_rasterio(path, **kwargs)
    variable_name = "data" if array.name is None else array.name
    return array.to_dataset(name=variable_name)

def file_md5_checksum(fname):
    """Return the MD5 hex digest of the file at *fname*.

    Parameters
    ----------
    fname : str or path-like
        Path of the file to hash.

    Returns
    -------
    str
        Hexadecimal MD5 digest of the file contents.
    """
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        # Read in fixed-size chunks instead of one f.read() so arbitrarily
        # large files do not have to fit in memory at once.
        for chunk in iter(lambda: f.read(65536), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

_default_cache_dir_name = "xarray_tutorial_data"
base_url = "https://github.com/pydata/xarray-data"
version = "master"


external_urls = {
"RGB.byte": (
"rasterio",
"https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif",
),
}
overrides = {
"rasterio": _open_rasterio,
}


# idea borrowed from Seaborn
def open_dataset(
name,
engine=None,
cache=True,
cache_dir=_default_cache_dir,
github_url="https://github.com/pydata/xarray-data",
branch="master",
cache_dir=None,
**kws,
):
"""
Expand All @@ -42,61 +54,62 @@ def open_dataset(
Parameters
----------
name : str
Name of the file containing the dataset. If no suffix is given, assumed
to be netCDF ('.nc' is appended)
Name of the file containing the dataset.
e.g. 'air_temperature'
cache_dir : str, optional
engine : str, optional
The engine to use.
cache_dir : path-like, optional
The directory in which to search for and write cached data.
cache : bool, optional
If True, then cache data locally for use on subsequent calls
github_url : str
Github repository where the data is stored
branch : str
The git branch to download from
kws : dict, optional
Passed to xarray.open_dataset
Notes
-----
Available datasets:
* ``"air_temperature"``
* ``"rasm"``
* ``"ROMS_example"``
* ``"tiny"``
* ``"era5-2mt-2019-03-uk.grib"``
* ``"RGB.byte"``: example rasterio file from https://github.com/mapbox/rasterio
See Also
--------
xarray.open_dataset
"""
root, ext = _os.path.splitext(name)
if not ext:
ext = ".nc"
fullname = root + ext
longdir = _os.path.expanduser(cache_dir)
localfile = _os.sep.join((longdir, fullname))
md5name = fullname + ".md5"
md5file = _os.sep.join((longdir, md5name))

if not _os.path.exists(localfile):

# This will always leave this directory on disk.
# May want to add an option to remove it.
if not _os.path.isdir(longdir):
_os.mkdir(longdir)

url = "/".join((github_url, "raw", branch, fullname))
urlretrieve(url, localfile)
url = "/".join((github_url, "raw", branch, md5name))
urlretrieve(url, md5file)

localmd5 = file_md5_checksum(localfile)
with open(md5file) as f:
remotemd5 = f.read()
if localmd5 != remotemd5:
_os.remove(localfile)
msg = """
MD5 checksum does not match, try downloading dataset again.
"""
raise OSError(msg)

ds = _open_dataset(localfile, **kws)

try:
import pooch
except ImportError:
raise ImportError("using the tutorial data requires pooch")

if isinstance(cache_dir, pathlib.Path):
cache_dir = os.fspath(cache_dir)
elif cache_dir is None:
cache_dir = pooch.os_cache(_default_cache_dir_name)

if name in external_urls:
engine_, url = external_urls[name]
if engine is None:
engine = engine_
else:
# process the name
default_extension = ".nc"
path = pathlib.Path(name)
if not path.suffix:
path = path.with_suffix(default_extension)

url = f"{base_url}/raw/{version}/{path.name}"

_open = overrides.get(engine, _open_dataset)
# retrieve the file
filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir)
ds = _open(filepath, engine=engine, **kws)
if not cache:
ds = ds.load()
_os.remove(localfile)
pathlib.Path(filepath).unlink()

return ds

Expand Down

0 comments on commit 8452120

Please sign in to comment.