-
-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement xESMF based regridding operation (#127)
- Loading branch information
1 parent
1437f9e
commit 4f8dca8
Showing
6 changed files
with
674 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from holoviews.core import Element | ||
from holoviews.operation.element import contours | ||
from holoviews.operation.stats import bivariate_kde | ||
|
||
from .. import element as gv_element | ||
from ..element import _Element | ||
from .projection import ( # noqa (API import) | ||
project_image, project_path, project_shape, project_points, | ||
project_graph, project_quadmesh, project | ||
) | ||
|
||
geo_ops = [contours, bivariate_kde] | ||
try: | ||
from holoviews.operation.datashader import ( | ||
ResamplingOperation, shade, stack, dynspread) | ||
geo_ops += [ResamplingOperation, shade, stack, dynspread] | ||
except: | ||
pass | ||
|
||
def convert_to_geotype(element, crs=None): | ||
""" | ||
Converts a HoloViews element type to the equivalent GeoViews | ||
element if given a coordinate reference system. | ||
""" | ||
geotype = getattr(gv_element, type(element).__name__, None) | ||
if crs is None or geotype is None or isinstance(element, _Element): | ||
return element | ||
return geotype(element, crs=crs) | ||
|
||
|
||
def find_crs(element): | ||
""" | ||
Traverses the supplied object looking for coordinate reference | ||
systems (crs). If multiple clashing reference systems are found | ||
it will throw an error. | ||
""" | ||
crss = element.traverse(lambda x: x.crs, [_Element]) | ||
crss = [crs for crs in crss if crs is not None] | ||
if any(crss[0] != crs for crs in crss[1:] if crs is not None): | ||
raise ValueError('Cannot datashade Elements in different ' | ||
'coordinate reference systems.') | ||
return {'crs': crss[0] if crss else None} | ||
|
||
|
||
def add_crs(element, **kwargs): | ||
""" | ||
Converts any elements in the input to their equivalent geotypes | ||
if given a coordinate reference system. | ||
""" | ||
return element.map(lambda x: convert_to_geotype(x, kwargs.get('crs')), Element) | ||
|
||
for op in geo_ops: | ||
op._preprocess_hooks = op._preprocess_hooks + [find_crs] | ||
op._postprocess_hooks = op._postprocess_hooks + [add_crs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import os | ||
try: | ||
FileNotFoundError | ||
except NameError: | ||
FileNotFoundError = IOError | ||
|
||
import param | ||
import numpy as np | ||
import xarray as xr | ||
|
||
from holoviews.core.util import get_param_values | ||
from holoviews.core.data import XArrayInterface | ||
from holoviews.element import Image as HvImage, QuadMesh as HvQuadMesh | ||
from holoviews.operation.datashader import regrid | ||
|
||
from ..element import Image, QuadMesh, is_geographic | ||
|
||
|
||
class weighted_regrid(regrid): | ||
""" | ||
Implements weighted regridding of rectilinear and curvilinear | ||
grids using the xESMF library, supporting all the ESMF regridding | ||
algorithms including bilinear, conservative and nearest neighbour | ||
regridding. The operation will always store the sparse weight | ||
matrix to disk and reuse the weights for later aggregations. To | ||
delete the weight files call the clean_weight_files method on the | ||
operation. | ||
""" | ||
|
||
interpolation = param.ObjectSelector(default='bilinear', | ||
objects=['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s'], doc=""" | ||
Interpolation method""") | ||
|
||
reuse_weights = param.Boolean(default=True, doc=""" | ||
Whether the sparse regridding weights should be cached as a local | ||
NetCDF file in the path defined by the file_pattern parameter. | ||
Can provide considerable speedups when exploring a larger dataset.""") | ||
|
||
file_pattern = param.String(default='{method}_{x_range}_{y_range}_{width}x{height}.nc', | ||
doc=""" | ||
The file pattern used to store the regridding weights when the | ||
reuse_weights parameter is disabled. Note the files are not | ||
cleared automatically so make sure you clean up the cached | ||
files when you are done.""") | ||
|
||
_files = [] | ||
|
||
def _get_regridder(self, element): | ||
try: | ||
import xesmf as xe | ||
except: | ||
raise ImportError("xESMF library required for weighted regridding.") | ||
x, y = element.kdims | ||
if self.p.target: | ||
tx, ty = self.p.target.kdims[:2] | ||
if issubclass(self.p.target.interface, XArrayInterface): | ||
ds_out = self.p.target.data | ||
ds_out = ds_out.rename({tx.name: 'lon', ty.name: 'lat'}) | ||
height, width = ds_out[tx.name].shape | ||
else: | ||
xs = self.p.target.dimension_values(0, expanded=False) | ||
ys = self.p.target.dimension_values(1, expanded=False) | ||
ds_out = xr.Dataset({'lat': ys, 'lon': xs}) | ||
height, width = len(ys), len(xs) | ||
x_range = ds_out[tx.name].min(), ds_out[tx.name].max() | ||
y_range = ds_out[ty.name].min(), ds_out[ty.name].max() | ||
xtype, ytype = 'numeric', 'numeric' | ||
else: | ||
info = self._get_sampling(element, x, y) | ||
(x_range, y_range), _, (width, height), (xtype, ytype) = info | ||
if x_range[0] > x_range[1]: | ||
x_range = x_range[::-1] | ||
element = element.select(**{x.name: x_range, y.name: y_range}) | ||
ys = np.linspace(y_range[0], y_range[1], height) | ||
xs = np.linspace(x_range[0], x_range[1], width) | ||
ds_out = xr.Dataset({'lat': ys, 'lon': xs}) | ||
|
||
irregular = any(element.interface.irregular(element, d) | ||
for d in [x, y]) | ||
coord_opts = {'flat': False} if irregular else {'expanded': False} | ||
coords = tuple(element.dimension_values(d, **coord_opts) | ||
for d in [x, y]) | ||
arrays = self._get_xarrays(element, coords, xtype, ytype) | ||
ds = xr.Dataset(arrays) | ||
ds.rename({x.name: 'lon', y.name: 'lat'}, inplace=True) | ||
|
||
x_range = str(tuple('%.3f' % r for r in x_range)).replace("'", '') | ||
y_range = str(tuple('%.3f' % r for r in y_range)).replace("'", '') | ||
filename = self.file_pattern.format(method=self.p.interpolation, | ||
width=width, height=height, | ||
x_range=x_range, y_range=y_range) | ||
self._files.append(os.path.abspath(filename)) | ||
regridder = xe.Regridder(ds, ds_out, self.p.interpolation, | ||
reuse_weights=self.p.reuse_weights, | ||
filename=filename) | ||
return regridder, arrays | ||
|
||
|
||
def _process(self, element, key=None): | ||
regridder, arrays = self._get_regridder(element) | ||
x, y = element.kdims | ||
ds = xr.Dataset({vd: regridder(arr) for vd, arr in arrays.items()}) | ||
ds.rename({'lon': x.name, 'lat': y.name}, inplace=True) | ||
params = get_param_values(element) | ||
if is_geographic(element): | ||
try: | ||
return Image(ds, crs=element.crs, **params) | ||
except: | ||
return QuadMesh(ds, crs=element.crs, **params) | ||
try: | ||
return HvImage(ds, **params) | ||
except: | ||
return HvQuadMesh(ds, **params) | ||
|
||
|
||
@classmethod | ||
def clean_weight_files(cls): | ||
""" | ||
Cleans existing weight files. | ||
""" | ||
deleted = [] | ||
for f in cls._files: | ||
try: | ||
os.remove(f) | ||
deleted.append(f) | ||
except FileNotFoundError: | ||
pass | ||
print('Deleted %d weight files' % len(deleted)) | ||
cls._files = [] |
Oops, something went wrong.