Skip to content

Commit 8452120

Browse files
authored
cache rasterio example files (#4102)
1 parent ec4e8b5 commit 8452120

File tree

6 files changed

+90
-69
lines changed

6 files changed

+90
-69
lines changed

ci/requirements/doc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies:
2020
- numba
2121
- numpy>=1.17
2222
- pandas>=1.0
23+
- pooch
2324
- pip
2425
- pydata-sphinx-theme>=0.4.3
2526
- rasterio>=1.1

ci/requirements/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies:
2727
- pandas
2828
- pint
2929
- pip=20.2
30+
- pooch
3031
- pre-commit
3132
- pseudonetcdf
3233
- pydap

doc/examples/visualization_gallery.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@
209209
"metadata": {},
210210
"outputs": [],
211211
"source": [
212-
"url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n",
213-
"da = xr.open_rasterio(url)\n",
212+
"da = xr.tutorial.open_dataset(\"RGB.byte\").data\n",
214213
"\n",
215214
"# The data is in UTM projection. We have to set it manually until\n",
216215
"# https://github.com/SciTools/cartopy/issues/813 is implemented\n",
@@ -246,8 +245,7 @@
246245
"from rasterio.warp import transform\n",
247246
"import numpy as np\n",
248247
"\n",
249-
"url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n",
250-
"da = xr.open_rasterio(url)\n",
248+
"da = xr.tutorial.open_dataset(\"RGB.byte\").data\n",
251249
"\n",
252250
"# Compute the lon/lat coordinates with rasterio.warp.transform\n",
253251
"ny, nx = len(da['y']), len(da['x'])\n",

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ ignore_missing_imports = True
208208
ignore_missing_imports = True
209209
[mypy-pint.*]
210210
ignore_missing_imports = True
211+
[mypy-pooch.*]
212+
ignore_missing_imports = True
211213
[mypy-PseudoNetCDF.*]
212214
ignore_missing_imports = True
213215
[mypy-pydap.*]
@@ -233,6 +235,7 @@ ignore_missing_imports = True
233235
[mypy-xarray.core.pycompat]
234236
ignore_errors = True
235237

238+
236239
[aliases]
237240
test = pytest
238241

xarray/tests/test_tutorial.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
from contextlib import suppress
32

43
import pytest
54

@@ -13,20 +12,26 @@ class TestLoadDataset:
1312
@pytest.fixture(autouse=True)
1413
def setUp(self):
1514
self.testfile = "tiny"
16-
self.testfilepath = os.path.expanduser(
17-
os.sep.join(("~", ".xarray_tutorial_data", self.testfile))
18-
)
19-
with suppress(OSError):
20-
os.remove(f"{self.testfilepath}.nc")
21-
with suppress(OSError):
22-
os.remove(f"{self.testfilepath}.md5")
23-
24-
def test_download_from_github(self):
15+
16+
def test_download_from_github(self, tmp_path, monkeypatch):
17+
monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))
18+
2519
ds = tutorial.open_dataset(self.testfile).load()
2620
tiny = DataArray(range(5), name="tiny").to_dataset()
2721
assert_identical(ds, tiny)
2822

29-
def test_download_from_github_load_without_cache(self):
23+
def test_download_from_github_load_without_cache(self, tmp_path, monkeypatch):
24+
monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))
25+
3026
ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load()
3127
ds_cache = tutorial.open_dataset(self.testfile).load()
3228
assert_identical(ds_cache, ds_nocache)
29+
30+
def test_download_rasterio_from_github_load_without_cache(
31+
self, tmp_path, monkeypatch
32+
):
33+
monkeypatch.setenv("XDG_CACHE_DIR", os.fspath(tmp_path))
34+
35+
ds_nocache = tutorial.open_dataset("RGB.byte", cache=False).load()
36+
ds_cache = tutorial.open_dataset("RGB.byte", cache=True).load()
37+
assert_identical(ds_cache, ds_nocache)

xarray/tutorial.py

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,45 @@
55
* building tutorials in the documentation.
66
77
"""
8-
import hashlib
9-
import os as _os
10-
from urllib.request import urlretrieve
8+
import os
9+
import pathlib
1110

1211
import numpy as np
1312

1413
from .backends.api import open_dataset as _open_dataset
14+
from .backends.rasterio_ import open_rasterio
1515
from .core.dataarray import DataArray
1616
from .core.dataset import Dataset
1717

18-
_default_cache_dir = _os.sep.join(("~", ".xarray_tutorial_data"))
1918

19+
def _open_rasterio(path, engine=None, **kwargs):
20+
data = open_rasterio(path, **kwargs)
21+
name = data.name if data.name is not None else "data"
22+
return data.to_dataset(name=name)
2023

21-
def file_md5_checksum(fname):
22-
hash_md5 = hashlib.md5()
23-
with open(fname, "rb") as f:
24-
hash_md5.update(f.read())
25-
return hash_md5.hexdigest()
24+
25+
_default_cache_dir_name = "xarray_tutorial_data"
26+
base_url = "https://github.com/pydata/xarray-data"
27+
version = "master"
28+
29+
30+
external_urls = {
31+
"RGB.byte": (
32+
"rasterio",
33+
"https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif",
34+
),
35+
}
36+
overrides = {
37+
"rasterio": _open_rasterio,
38+
}
2639

2740

2841
# idea borrowed from Seaborn
2942
def open_dataset(
3043
name,
44+
engine=None,
3145
cache=True,
32-
cache_dir=_default_cache_dir,
33-
github_url="https://github.com/pydata/xarray-data",
34-
branch="master",
46+
cache_dir=None,
3547
**kws,
3648
):
3749
"""
@@ -42,61 +54,62 @@ def open_dataset(
4254
Parameters
4355
----------
4456
name : str
45-
Name of the file containing the dataset. If no suffix is given, assumed
46-
to be netCDF ('.nc' is appended)
57+
Name of the file containing the dataset.
4758
e.g. 'air_temperature'
48-
cache_dir : str, optional
59+
engine : str, optional
60+
The engine to use.
61+
cache_dir : path-like, optional
4962
The directory in which to search for and write cached data.
5063
cache : bool, optional
5164
If True, then cache data locally for use on subsequent calls
52-
github_url : str
53-
Github repository where the data is stored
54-
branch : str
55-
The git branch to download from
5665
kws : dict, optional
5766
Passed to xarray.open_dataset
5867
68+
Notes
69+
-----
70+
Available datasets:
71+
72+
* ``"air_temperature"``
73+
* ``"rasm"``
74+
* ``"ROMS_example"``
75+
* ``"tiny"``
76+
* ``"era5-2mt-2019-03-uk.grib"``
77+
* ``"RGB.byte"``: example rasterio file from https://github.com/mapbox/rasterio
78+
5979
See Also
6080
--------
6181
xarray.open_dataset
62-
6382
"""
64-
root, ext = _os.path.splitext(name)
65-
if not ext:
66-
ext = ".nc"
67-
fullname = root + ext
68-
longdir = _os.path.expanduser(cache_dir)
69-
localfile = _os.sep.join((longdir, fullname))
70-
md5name = fullname + ".md5"
71-
md5file = _os.sep.join((longdir, md5name))
72-
73-
if not _os.path.exists(localfile):
74-
75-
# This will always leave this directory on disk.
76-
# May want to add an option to remove it.
77-
if not _os.path.isdir(longdir):
78-
_os.mkdir(longdir)
79-
80-
url = "/".join((github_url, "raw", branch, fullname))
81-
urlretrieve(url, localfile)
82-
url = "/".join((github_url, "raw", branch, md5name))
83-
urlretrieve(url, md5file)
84-
85-
localmd5 = file_md5_checksum(localfile)
86-
with open(md5file) as f:
87-
remotemd5 = f.read()
88-
if localmd5 != remotemd5:
89-
_os.remove(localfile)
90-
msg = """
91-
MD5 checksum does not match, try downloading dataset again.
92-
"""
93-
raise OSError(msg)
94-
95-
ds = _open_dataset(localfile, **kws)
96-
83+
try:
84+
import pooch
85+
except ImportError:
86+
raise ImportError("using the tutorial data requires pooch")
87+
88+
if isinstance(cache_dir, pathlib.Path):
89+
cache_dir = os.fspath(cache_dir)
90+
elif cache_dir is None:
91+
cache_dir = pooch.os_cache(_default_cache_dir_name)
92+
93+
if name in external_urls:
94+
engine_, url = external_urls[name]
95+
if engine is None:
96+
engine = engine_
97+
else:
98+
# process the name
99+
default_extension = ".nc"
100+
path = pathlib.Path(name)
101+
if not path.suffix:
102+
path = path.with_suffix(default_extension)
103+
104+
url = f"{base_url}/raw/{version}/{path.name}"
105+
106+
_open = overrides.get(engine, _open_dataset)
107+
# retrieve the file
108+
filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir)
109+
ds = _open(filepath, engine=engine, **kws)
97110
if not cache:
98111
ds = ds.load()
99-
_os.remove(localfile)
112+
pathlib.Path(filepath).unlink()
100113

101114
return ds
102115

0 commit comments

Comments
 (0)