-
Notifications
You must be signed in to change notification settings - Fork 170
Refactoring field interpolation and allow custom interpolation methods in Scipy mode #1816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
b4d9c2b
refactor error classes removing custom inits
VeckoTheGecko 681f668
update to use _raise_out_of_bound_error
VeckoTheGecko 6990c02
update to use _raise_out_of_bound_error
VeckoTheGecko 44fc36a
update to use _raise_out_of_bound_surface_error
VeckoTheGecko 08f560c
Rename parse_particletime
VeckoTheGecko a3f978c
update to use _raise_field_sampling_error
VeckoTheGecko 13348b4
Renaming functions
VeckoTheGecko 71d9532
Add exception helper
VeckoTheGecko 912b84f
Remove implementation for deprecated public interpolation methods
VeckoTheGecko 4b7ac56
Move search_indices_vertical_z and search_indices_vertical_s to _inte…
VeckoTheGecko 33f9e31
Lift field name during interpolation error handling into _spatial_int…
VeckoTheGecko 3a75f83
Refactoring
VeckoTheGecko c274704
Rename file
VeckoTheGecko 14b2d9b
Refactor 2D interpolators into separate file
VeckoTheGecko debce8a
Refactor 3D interpolators into separate file
VeckoTheGecko 892c155
Move interpolation call
VeckoTheGecko 43f75a5
Temporarily add interp_method to 3D interpolation context
VeckoTheGecko a89e906
Add test_interpolation.py
VeckoTheGecko 088bedd
Add some tests for interpolation methods
VeckoTheGecko 94c4fa3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8dbb106
Supporting code for 3d interpolation refactor
VeckoTheGecko e908e53
Update fixture use method
VeckoTheGecko 24b3c37
Add test_full_depth_provided_to_interpolators
VeckoTheGecko deb02d7
Fix test pollution
VeckoTheGecko 9a01d4c
Redefine zdim based on data array
VeckoTheGecko 8a16e18
Refactor using unit_square_to_target() function
VeckoTheGecko fded9ce
Change order to eta, xsi
VeckoTheGecko 15f2abc
refactor to get_3d_f0_f1
VeckoTheGecko 42561a7
Split up 3d interpolation functions
VeckoTheGecko fe0866b
Force kwargs
VeckoTheGecko 3347cdf
add z_layer_interp
VeckoTheGecko 10d02a0
Reduce code duplication
VeckoTheGecko bf58f67
Add interpolation testing for each grid cell
VeckoTheGecko 78429eb
Delete old 3d interpolator
VeckoTheGecko a235874
update comment
VeckoTheGecko 46fc3c3
Review feedback
VeckoTheGecko 5ed3310
patch indexerror
VeckoTheGecko bf79d18
Rename file _indexing.py -> _index_search.py
VeckoTheGecko 353d220
Refactor test
VeckoTheGecko d604ae1
Rename test data function
VeckoTheGecko b9d9242
review feedback
VeckoTheGecko 7052b26
Move `calc_cell_edge_sizes`, and `cell_areas` out of field.py
VeckoTheGecko cb0ba68
Move methods _search_indices_curvilinear, _search_indices_rectilinear…
VeckoTheGecko 8a861c3
move tests
VeckoTheGecko 77d8a8e
Move reconnect_bnd_indices to grid.py
VeckoTheGecko 76c6865
Remove casts to float32
VeckoTheGecko d88bce7
Remove msg from TimeExtrapolation constructor
VeckoTheGecko 7d2e762
Fix citations
VeckoTheGecko 3854ea1
review feedback
VeckoTheGecko a0a0336
Adding a unit test to compare JIT and SciPy interpolation/integration
erikvansebille 8ac1593
Review edits
VeckoTheGecko af8ffe8
Review feedback
VeckoTheGecko d61563f
cleanup test_interpolation.py
VeckoTheGecko 46a2853
xfail cgrid_velocity on test_scipy_vs_jit
VeckoTheGecko 09aec5e
Relaxing jit-vs-scipy tolerance
erikvansebille baaf04c
Merge branch 'main' into v/refactor-interp
erikvansebille 10a44f8
Copying #1834 changes into new _index_search functions
erikvansebille cc87fcd
Patch unit test
VeckoTheGecko 0a8d593
Merge remote-tracking branch 'origin/main' into v/refactor-interp
VeckoTheGecko 47c9ae6
Fixing timestep in unit test
erikvansebille d029cf8
Reducing velocity strengths (to avoid out-of-bounds deletions)
erikvansebille File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,337 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
|
|
||
| from parcels._typing import ( | ||
| GridIndexingType, | ||
| InterpMethodOption, | ||
| ) | ||
| from parcels.tools.statuscodes import ( | ||
| FieldOutOfBoundError, | ||
| FieldOutOfBoundSurfaceError, | ||
| _raise_field_out_of_bound_error, | ||
| _raise_field_out_of_bound_surface_error, | ||
| _raise_field_sampling_error, | ||
| ) | ||
|
|
||
| from .grid import GridType | ||
|
|
||
| if TYPE_CHECKING: | ||
| from .field import Field | ||
| from .grid import Grid | ||
|
|
||
|
|
||
| def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float): | ||
| if grid.depth[-1] > grid.depth[0]: | ||
| if z < grid.depth[0]: | ||
| # Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0]) | ||
| if gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]: | ||
| return (-1, z / grid.depth[0]) | ||
| else: | ||
| _raise_field_out_of_bound_surface_error(z, None, None) | ||
| elif z > grid.depth[-1]: | ||
| # In case of CROCO, allow particles in last (uppermost) layer using depth[-1] | ||
| if gridindexingtype in ["croco"] and z < 0: | ||
| return (-2, 1) | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| depth_indices = grid.depth < z | ||
| if z >= grid.depth[-1]: | ||
| zi = len(grid.depth) - 2 | ||
| else: | ||
| zi = depth_indices.argmin() - 1 if z > grid.depth[0] else 0 | ||
| else: | ||
| if z > grid.depth[0]: | ||
| _raise_field_out_of_bound_surface_error(z, None, None) | ||
| elif z < grid.depth[-1]: | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| depth_indices = grid.depth > z | ||
| if z <= grid.depth[-1]: | ||
| zi = len(grid.depth) - 2 | ||
| else: | ||
| zi = depth_indices.argmin() - 1 if z < grid.depth[0] else 0 | ||
| zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) | ||
| while zeta > 1: | ||
| zi += 1 | ||
| zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) | ||
| while zeta < 0: | ||
| zi -= 1 | ||
| zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) | ||
| return (zi, zeta) | ||
|
|
||
|
|
||
| def search_indices_vertical_s( | ||
| grid: Grid, | ||
| interp_method: InterpMethodOption, | ||
| time: float, | ||
| z: float, | ||
| y: float, | ||
| x: float, | ||
| ti: int, | ||
| yi: int, | ||
| xi: int, | ||
| eta: float, | ||
| xsi: float, | ||
| ): | ||
| if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: | ||
| xsi = 1 | ||
| eta = 1 | ||
| if time < grid.time[ti]: | ||
| ti -= 1 | ||
| if grid._z4d: | ||
| if ti == len(grid.time) - 1: | ||
| depth_vector = ( | ||
| (1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi] | ||
| + xsi * (1 - eta) * grid.depth[-1, :, yi, xi + 1] | ||
| + xsi * eta * grid.depth[-1, :, yi + 1, xi + 1] | ||
| + (1 - xsi) * eta * grid.depth[-1, :, yi + 1, xi] | ||
| ) | ||
| else: | ||
| dv2 = ( | ||
| (1 - xsi) * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi] | ||
| + xsi * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi + 1] | ||
| + xsi * eta * grid.depth[ti : ti + 2, :, yi + 1, xi + 1] | ||
| + (1 - xsi) * eta * grid.depth[ti : ti + 2, :, yi + 1, xi] | ||
| ) | ||
| tt = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) | ||
| assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time" | ||
| depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt | ||
| else: | ||
| depth_vector = ( | ||
| (1 - xsi) * (1 - eta) * grid.depth[:, yi, xi] | ||
| + xsi * (1 - eta) * grid.depth[:, yi, xi + 1] | ||
| + xsi * eta * grid.depth[:, yi + 1, xi + 1] | ||
| + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] | ||
| ) | ||
| z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 | ||
VeckoTheGecko marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if depth_vector[-1] > depth_vector[0]: | ||
| if z < depth_vector[0]: | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| elif z > depth_vector[-1]: | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| depth_indices = depth_vector < z | ||
| if z >= depth_vector[-1]: | ||
| zi = len(depth_vector) - 2 | ||
| else: | ||
| zi = depth_indices.argmin() - 1 if z > depth_vector[0] else 0 | ||
| else: | ||
| if z > depth_vector[0]: | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| elif z < depth_vector[-1]: | ||
| _raise_field_out_of_bound_error(z, None, None) | ||
| depth_indices = depth_vector > z | ||
| if z <= depth_vector[-1]: | ||
| zi = len(depth_vector) - 2 | ||
| else: | ||
| zi = depth_indices.argmin() - 1 if z < depth_vector[0] else 0 | ||
| zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) | ||
| while zeta > 1: | ||
| zi += 1 | ||
| zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) | ||
| while zeta < 0: | ||
| zi -= 1 | ||
| zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) | ||
| return (zi, zeta) | ||
|
|
||
|
|
||
| def _search_indices_rectilinear( | ||
| field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False | ||
| ): | ||
| grid = field.grid | ||
|
|
||
| if grid.xdim > 1 and (not grid.zonal_periodic): | ||
| if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
| if grid.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
|
|
||
| if grid.xdim > 1: | ||
| if grid.mesh != "spherical": | ||
| lon_index = grid.lon < x | ||
| if lon_index.all(): | ||
| xi = len(grid.lon) - 2 | ||
| else: | ||
| xi = lon_index.argmin() - 1 if lon_index.any() else 0 | ||
| xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) | ||
| if xsi < 0: | ||
| xi -= 1 | ||
| xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) | ||
| elif xsi > 1: | ||
| xi += 1 | ||
| xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) | ||
| else: | ||
| lon_fixed = grid.lon.copy() | ||
| indices = lon_fixed >= lon_fixed[0] | ||
| if not indices.all(): | ||
| lon_fixed[indices.argmin() :] += 360 | ||
| if x < lon_fixed[0]: | ||
| lon_fixed -= 360 | ||
|
|
||
| lon_index = lon_fixed < x | ||
| if lon_index.all(): | ||
| xi = len(lon_fixed) - 2 | ||
| else: | ||
| xi = lon_index.argmin() - 1 if lon_index.any() else 0 | ||
| xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) | ||
| if xsi < 0: | ||
| xi -= 1 | ||
| xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) | ||
| elif xsi > 1: | ||
| xi += 1 | ||
| xsi = (x - lon_fixed[xi]) / (lon_fixed[xi + 1] - lon_fixed[xi]) | ||
| else: | ||
| xi, xsi = -1, 0 | ||
|
|
||
| if grid.ydim > 1: | ||
| lat_index = grid.lat < y | ||
| if lat_index.all(): | ||
| yi = len(grid.lat) - 2 | ||
| else: | ||
| yi = lat_index.argmin() - 1 if lat_index.any() else 0 | ||
|
|
||
| eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) | ||
| if eta < 0: | ||
| yi -= 1 | ||
| eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) | ||
| elif eta > 1: | ||
| yi += 1 | ||
| eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) | ||
| else: | ||
| yi, eta = -1, 0 | ||
|
|
||
| if grid.zdim > 1 and not search2D: | ||
| if grid._gtype == GridType.RectilinearZGrid: | ||
| try: | ||
| (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) | ||
| except FieldOutOfBoundError: | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
| except FieldOutOfBoundSurfaceError: | ||
| _raise_field_out_of_bound_surface_error(z, y, x) | ||
| elif grid._gtype == GridType.RectilinearSGrid: | ||
| (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) | ||
| else: | ||
| zi, zeta = -1, 0 | ||
|
|
||
| if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): | ||
| _raise_field_sampling_error(z, y, x) | ||
|
|
||
| if particle: | ||
| particle.xi[field.igrid] = xi | ||
| particle.yi[field.igrid] = yi | ||
| particle.zi[field.igrid] = zi | ||
|
|
||
| return (zeta, eta, xsi, zi, yi, xi) | ||
|
|
||
|
|
||
| def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False): | ||
| if particle: | ||
| xi = particle.xi[field.igrid] | ||
| yi = particle.yi[field.igrid] | ||
| else: | ||
| xi = int(field.grid.xdim / 2) - 1 | ||
| yi = int(field.grid.ydim / 2) - 1 | ||
| xsi = eta = -1 | ||
| grid = field.grid | ||
| invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) | ||
| maxIterSearch = 1e6 | ||
| it = 0 | ||
| tol = 1.0e-10 | ||
| if not grid.zonal_periodic: | ||
| if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: | ||
| if grid.lon[0, 0] < grid.lon[0, -1]: | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
| elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
| if y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]: | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
|
|
||
| while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: | ||
| px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) | ||
| if grid.mesh == "spherical": | ||
| px[0] = px[0] + 360 if px[0] < x - 225 else px[0] | ||
| px[0] = px[0] - 360 if px[0] > x + 225 else px[0] | ||
| px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:]) | ||
| px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:]) | ||
| py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) | ||
| a = np.dot(invA, px) | ||
| b = np.dot(invA, py) | ||
|
|
||
| aa = a[3] * b[2] - a[2] * b[3] | ||
| bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] | ||
| cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] | ||
| if abs(aa) < 1e-12: # Rectilinear cell, or quasi | ||
| eta = -cc / bb | ||
| else: | ||
| det2 = bb * bb - 4 * aa * cc | ||
| if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter | ||
| det = np.sqrt(det2) | ||
| eta = (-bb + det) / (2 * aa) | ||
| if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg | ||
| xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 | ||
| else: | ||
| xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) | ||
| if xsi < 0 and eta < 0 and xi == 0 and yi == 0: | ||
| _raise_field_out_of_bound_error(0, y, x) | ||
| if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: | ||
| _raise_field_out_of_bound_error(0, y, x) | ||
| if xsi < -tol: | ||
| xi -= 1 | ||
| elif xsi > 1 + tol: | ||
| xi += 1 | ||
| if eta < -tol: | ||
| yi -= 1 | ||
| elif eta > 1 + tol: | ||
| yi += 1 | ||
| (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) | ||
| it += 1 | ||
| if it > maxIterSearch: | ||
| print(f"Correct cell not found after {maxIterSearch} iterations") | ||
| _raise_field_out_of_bound_error(0, y, x) | ||
| xsi = max(0.0, xsi) | ||
| eta = max(0.0, eta) | ||
| xsi = min(1.0, xsi) | ||
| eta = min(1.0, eta) | ||
|
|
||
| if grid.zdim > 1 and not search2D: | ||
| if grid._gtype == GridType.CurvilinearZGrid: | ||
| try: | ||
| (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) | ||
| except FieldOutOfBoundError: | ||
| _raise_field_out_of_bound_error(z, y, x) | ||
| elif grid._gtype == GridType.CurvilinearSGrid: | ||
| (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) | ||
| else: | ||
| zi = -1 | ||
| zeta = 0 | ||
|
|
||
| if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): | ||
| _raise_field_sampling_error(z, y, x) | ||
|
|
||
| if particle: | ||
| particle.xi[field.igrid] = xi | ||
| particle.yi[field.igrid] = yi | ||
| particle.zi[field.igrid] = zi | ||
|
|
||
| return (zeta, eta, xsi, zi, yi, xi) | ||
|
|
||
|
|
||
| def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool): | ||
| if xi < 0: | ||
| if sphere_mesh: | ||
| xi = xdim - 2 | ||
| else: | ||
| xi = 0 | ||
| if xi > xdim - 2: | ||
| if sphere_mesh: | ||
| xi = 0 | ||
| else: | ||
| xi = xdim - 2 | ||
| if yi < 0: | ||
| yi = 0 | ||
| if yi > ydim - 2: | ||
| yi = ydim - 2 | ||
| if sphere_mesh: | ||
| xi = xdim - xi | ||
| return yi, xi | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.