Skip to content

Commit

Permalink
Implement xESMF based regridding operation (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Feb 2, 2018
1 parent 1437f9e commit 4f8dca8
Show file tree
Hide file tree
Showing 6 changed files with 674 additions and 59 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- source activate test-environment
install:
- conda install -c conda-forge nose numpy matplotlib bokeh pandas scipy jupyter ipython param flake8 mock filelock iris cartopy xarray geopandas numpy shapely=1.6.3 gdal=2.2.3 libgdal=2.2.3 glib=2.55.0 gstreamer=1.8.0 --quiet
- conda install -c conda-forge nose numpy matplotlib bokeh pandas scipy jupyter ipython param flake8 mock filelock iris cartopy xarray geopandas numpy shapely=1.6.3 gdal=2.2.3 libgdal=2.2.3 glib=2.55.0 gstreamer=1.8.0 datashader --quiet
- pip install coveralls
- pip install git+https://github.com/ioam/holoviews.git
- python setup.py install
Expand All @@ -47,6 +47,8 @@ jobs:
- python -c "import geoviews as gv; gv.sample_data('notebooks/user_guide/sample-data')"
- python -c "import geoviews as gv; gv.sample_data('doc/sample-data')"
- conda install -c conda-forge sphinx beautifulsoup4 graphviz
- conda install -c nesii/label/dev-esmf -c conda-forge esmpy
- pip install xesmf
- pip install nbsite
- pip install sphinx_ioam_theme
# TODO: should make this content available too rather than deleting it!
Expand Down
54 changes: 54 additions & 0 deletions geoviews/operation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from holoviews.core import Element
from holoviews.operation.element import contours
from holoviews.operation.stats import bivariate_kde

from .. import element as gv_element
from ..element import _Element
from .projection import ( # noqa (API import)
project_image, project_path, project_shape, project_points,
project_graph, project_quadmesh, project
)

geo_ops = [contours, bivariate_kde]
try:
from holoviews.operation.datashader import (
ResamplingOperation, shade, stack, dynspread)
geo_ops += [ResamplingOperation, shade, stack, dynspread]
except:
pass

def convert_to_geotype(element, crs=None):
"""
Converts a HoloViews element type to the equivalent GeoViews
element if given a coordinate reference system.
"""
geotype = getattr(gv_element, type(element).__name__, None)
if crs is None or geotype is None or isinstance(element, _Element):
return element
return geotype(element, crs=crs)


def find_crs(element):
"""
Traverses the supplied object looking for coordinate reference
systems (crs). If multiple clashing reference systems are found
it will throw an error.
"""
crss = element.traverse(lambda x: x.crs, [_Element])
crss = [crs for crs in crss if crs is not None]
if any(crss[0] != crs for crs in crss[1:] if crs is not None):
raise ValueError('Cannot datashade Elements in different '
'coordinate reference systems.')
return {'crs': crss[0] if crss else None}


def add_crs(element, **kwargs):
"""
Converts any elements in the input to their equivalent geotypes
if given a coordinate reference system.
"""
return element.map(lambda x: convert_to_geotype(x, kwargs.get('crs')), Element)

for op in geo_ops:
op._preprocess_hooks = op._preprocess_hooks + [find_crs]
op._postprocess_hooks = op._postprocess_hooks + [add_crs]
63 changes: 5 additions & 58 deletions geoviews/operation.py → geoviews/operation/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,66 +3,13 @@

from cartopy import crs as ccrs
from cartopy.img_transform import warp_array, _determine_bounds
from holoviews.core import Element
from holoviews.core.util import cartesian_product, get_param_values
from holoviews.operation import Operation
from shapely.geometry import Polygon, LineString

from . import element as gv_element
from .element import (Image, Shape, Polygons, Path, Points, Contours,
RGB, Graph, Nodes, EdgePaths, QuadMesh, VectorField,
_Element)
from .util import project_extents, geom_to_array

geo_ops = []
try:
from holoviews.operation.datashader import (
ResamplingOperation, shade, stack, dynspread)
geo_ops += [ResamplingOperation, shade, stack, dynspread]
except:
pass

from holoviews.operation.element import contours
from holoviews.operation.stats import bivariate_kde

geo_ops += [contours, bivariate_kde]

def convert_to_geotype(element, crs=None):
"""
Converts a HoloViews element type to the equivalent GeoViews
element if given a coordinate reference system.
"""
geotype = getattr(gv_element, type(element).__name__, None)
if crs is None or geotype is None or isinstance(element, _Element):
return element
return geotype(element, crs=crs)


def find_crs(element):
"""
Traverses the supplied object looking for coordinate reference
systems (crs). If multiple clashing reference systems are found
it will throw an error.
"""
crss = element.traverse(lambda x: x.crs, [_Element])
crss = [crs for crs in crss if crs is not None]
if any(crss[0] != crs for crs in crss[1:] if crs is not None):
raise ValueError('Cannot datashade Elements in different '
'coordinate reference systems.')
return {'crs': crss[0] if crss else None}


def add_crs(element, **kwargs):
"""
Converts any elements in the input to their equivalent geotypes
if given a coordinate reference system.
"""
return element.map(lambda x: convert_to_geotype(x, kwargs.get('crs')), Element)


for op in geo_ops:
op._preprocess_hooks = op._preprocess_hooks + [find_crs]
op._postprocess_hooks = op._postprocess_hooks + [add_crs]
from ..element import (Image, Shape, Polygons, Path, Points, Contours,
RGB, Graph, Nodes, EdgePaths, QuadMesh, VectorField)
from ..util import project_extents, geom_to_array


class _project_operation(Operation):
Expand Down Expand Up @@ -254,8 +201,8 @@ def _process(self, img, key=None):
data = np.flipud(projected)
bounds = (extents[0], extents[2], extents[1], extents[3])
return img.clone(data, bounds=bounds, kdims=img.kdims,
vdims=img.vdims, crs=proj)

vdims=img.vdims, crs=proj, xdensity=None,
ydensity=None)

def _fast_process(self, element, key=None):
# Project coordinates
Expand Down
129 changes: 129 additions & 0 deletions geoviews/operation/regrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os
try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError

import param
import numpy as np
import xarray as xr

from holoviews.core.util import get_param_values
from holoviews.core.data import XArrayInterface
from holoviews.element import Image as HvImage, QuadMesh as HvQuadMesh
from holoviews.operation.datashader import regrid

from ..element import Image, QuadMesh, is_geographic


class weighted_regrid(regrid):
"""
Implements weighted regridding of rectilinear and curvilinear
grids using the xESMF library, supporting all the ESMF regridding
algorithms including bilinear, conservative and nearest neighbour
regridding. The operation will always store the sparse weight
matrix to disk and reuse the weights for later aggregations. To
delete the weight files call the clean_weight_files method on the
operation.
"""

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s'], doc="""
Interpolation method""")

reuse_weights = param.Boolean(default=True, doc="""
Whether the sparse regridding weights should be cached as a local
NetCDF file in the path defined by the file_pattern parameter.
Can provide considerable speedups when exploring a larger dataset.""")

file_pattern = param.String(default='{method}_{x_range}_{y_range}_{width}x{height}.nc',
doc="""
The file pattern used to store the regridding weights when the
reuse_weights parameter is disabled. Note the files are not
cleared automatically so make sure you clean up the cached
files when you are done.""")

_files = []

def _get_regridder(self, element):
try:
import xesmf as xe
except:
raise ImportError("xESMF library required for weighted regridding.")
x, y = element.kdims
if self.p.target:
tx, ty = self.p.target.kdims[:2]
if issubclass(self.p.target.interface, XArrayInterface):
ds_out = self.p.target.data
ds_out = ds_out.rename({tx.name: 'lon', ty.name: 'lat'})
height, width = ds_out[tx.name].shape
else:
xs = self.p.target.dimension_values(0, expanded=False)
ys = self.p.target.dimension_values(1, expanded=False)
ds_out = xr.Dataset({'lat': ys, 'lon': xs})
height, width = len(ys), len(xs)
x_range = ds_out[tx.name].min(), ds_out[tx.name].max()
y_range = ds_out[ty.name].min(), ds_out[ty.name].max()
xtype, ytype = 'numeric', 'numeric'
else:
info = self._get_sampling(element, x, y)
(x_range, y_range), _, (width, height), (xtype, ytype) = info
if x_range[0] > x_range[1]:
x_range = x_range[::-1]
element = element.select(**{x.name: x_range, y.name: y_range})
ys = np.linspace(y_range[0], y_range[1], height)
xs = np.linspace(x_range[0], x_range[1], width)
ds_out = xr.Dataset({'lat': ys, 'lon': xs})

irregular = any(element.interface.irregular(element, d)
for d in [x, y])
coord_opts = {'flat': False} if irregular else {'expanded': False}
coords = tuple(element.dimension_values(d, **coord_opts)
for d in [x, y])
arrays = self._get_xarrays(element, coords, xtype, ytype)
ds = xr.Dataset(arrays)
ds.rename({x.name: 'lon', y.name: 'lat'}, inplace=True)

x_range = str(tuple('%.3f' % r for r in x_range)).replace("'", '')
y_range = str(tuple('%.3f' % r for r in y_range)).replace("'", '')
filename = self.file_pattern.format(method=self.p.interpolation,
width=width, height=height,
x_range=x_range, y_range=y_range)
self._files.append(os.path.abspath(filename))
regridder = xe.Regridder(ds, ds_out, self.p.interpolation,
reuse_weights=self.p.reuse_weights,
filename=filename)
return regridder, arrays


def _process(self, element, key=None):
regridder, arrays = self._get_regridder(element)
x, y = element.kdims
ds = xr.Dataset({vd: regridder(arr) for vd, arr in arrays.items()})
ds.rename({'lon': x.name, 'lat': y.name}, inplace=True)
params = get_param_values(element)
if is_geographic(element):
try:
return Image(ds, crs=element.crs, **params)
except:
return QuadMesh(ds, crs=element.crs, **params)
try:
return HvImage(ds, **params)
except:
return HvQuadMesh(ds, **params)


@classmethod
def clean_weight_files(cls):
"""
Cleans existing weight files.
"""
deleted = []
for f in cls._files:
try:
os.remove(f)
deleted.append(f)
except FileNotFoundError:
pass
print('Deleted %d weight files' % len(deleted))
cls._files = []
Loading

0 comments on commit 4f8dca8

Please sign in to comment.