Skip to content
218 changes: 190 additions & 28 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
"CGrid_Velocity",
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
"XFreeslip",
"XLinear",
"XNearest",
"XPartialslip",
"ZeroInterpolator",
"ZeroInterpolator_Vector",
]
Expand All @@ -30,7 +32,7 @@
def ZeroInterpolator(
field: Field,
ti: int,
position: dict[str, tuple[int, float | np.ndarray]],
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
Expand All @@ -44,7 +46,7 @@ def ZeroInterpolator(
def ZeroInterpolator_Vector(
vectorfield: VectorField,
ti: int,
position: dict[str, tuple[int, float | np.ndarray]],
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
Expand All @@ -56,48 +58,38 @@ def ZeroInterpolator_Vector(
return 0.0


def XLinear(
field: Field,
def _get_corner_data_Agrid(
data: np.ndarray | xr.DataArray,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Trilinear interpolation on a regular grid."""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data
tdim, zdim, ydim, xdim = data.shape[0], data.shape[1], data.shape[2], data.shape[3]

lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1

zi: int,
yi: int,
xi: int,
lenT: int,
lenZ: int,
npart: int,
axis_dim: dict[str, str],
) -> np.ndarray:
"""Helper function to get the corner data for a given A-grid field and position."""
# Time coordinates: 8 points at ti, then 8 points at ti+1
if lenT == 1:
ti = np.repeat(ti, lenZ * 4)
else:
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])

# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
if lenZ == 1:
zi = np.repeat(zi, lenT * 4)
else:
zi_1 = np.clip(zi + 1, 0, zdim - 1)
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)

# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))

# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))

# Create DataArrays for indexing
Expand All @@ -110,7 +102,31 @@ def XLinear(
if "time" in data.dims:
selection_dict["time"] = xr.DataArray(ti, dims=("points"))

corner_data = data.isel(selection_dict).data.reshape(lenT, lenZ, len(xsi), 4)
return data.isel(selection_dict).data.reshape(lenT, lenZ, npart, 4)


def XLinear(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Trilinear interpolation on a regular grid."""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data

lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1

corner_data = _get_corner_data_Agrid(data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)

if lenT == 2:
tau = tau[np.newaxis, :, np.newaxis]
Expand Down Expand Up @@ -392,6 +408,152 @@ def CGrid_Tracer(
return value.compute() if is_dask_collection(value) else value


def _Spatialslip(
vectorfield: VectorField,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
a: np.float32,
b: np.float32,
):
"""Helper function for spatial boundary condition interpolation for velocity fields."""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

axis_dim = vectorfield.U.grid.get_axis_dim_mapping(vectorfield.U.data.dims)
lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1
npart = len(xsi)

u = XLinear(vectorfield.U, ti, position, tau, t, z, y, x)
v = XLinear(vectorfield.V, ti, position, tau, t, z, y, x)
if vectorfield.W:
w = XLinear(vectorfield.W, ti, position, tau, t, z, y, x)

corner_dataU = _get_corner_data_Agrid(vectorfield.U.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
corner_dataV = _get_corner_data_Agrid(vectorfield.V.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)

def is_land(ti: int, zi: int, yi: int, xi: int):
uval = corner_dataU[ti, zi, :, xi + 2 * yi]
vval = corner_dataV[ti, zi, :, xi + 2 * yi]
return np.where(np.isclose(uval, 0.0) & np.isclose(vval, 0.0), True, False)

f_u = np.ones_like(xsi)
f_v = np.ones_like(eta)

if lenZ == 1:
f_u = np.where(is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & (eta > 0), f_u * (a + b * eta) / eta, f_u)
f_u = np.where(is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & (eta < 1), f_u * (1 - b * eta) / (1 - eta), f_u)
f_v = np.where(is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & (xsi > 0), f_v * (a + b * xsi) / xsi, f_v)
f_v = np.where(is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & (xsi < 1), f_v * (1 - b * xsi) / (1 - xsi), f_v)
else:
f_u = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & (eta > 0),
f_u * (a + b * eta) / eta,
f_u,
)
f_u = np.where(
is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & is_land(0, 1, 1, 0) & is_land(0, 1, 1, 1) & (eta < 1),
f_u * (1 - b * eta) / (1 - eta),
f_u,
)
f_v = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & is_land(0, 1, 0, 0) & is_land(0, 1, 1, 0) & (xsi > 0),
f_v * (a + b * xsi) / xsi,
f_v,
)
f_v = np.where(
is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 1) & (xsi < 1),
f_v * (1 - b * xsi) / (1 - xsi),
f_v,
)
f_u = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 0, 1, 0 & is_land(0, 0, 1, 1) & (zeta > 0)),
f_u * (a + b * zeta) / zeta,
f_u,
)
f_u = np.where(
is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 0 & is_land(0, 1, 1, 1) & (zeta < 1)),
f_u * (1 - b * zeta) / (1 - zeta),
f_u,
)
f_v = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 0, 1, 0 & is_land(0, 0, 1, 1) & (zeta > 0)),
f_v * (a + b * zeta) / zeta,
f_v,
)
f_v = np.where(
is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 0 & is_land(0, 1, 1, 1) & (zeta < 1)),
f_v * (1 - b * zeta) / (1 - zeta),
f_v,
)

u *= f_u
v *= f_v
if vectorfield.W:
f_w = np.ones_like(zeta)
f_w = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & (eta > 0),
f_w * (a + b * eta) / eta,
f_w,
)
f_w = np.where(
is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & is_land(0, 1, 1, 0) & is_land(0, 1, 1, 1) & (eta < 1),
f_w * (a - b * eta) / (1 - eta),
f_w,
)
f_w = np.where(
is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & is_land(0, 1, 0, 0) & is_land(0, 1, 1, 0) & (xsi > 0),
f_w * (a + b * xsi) / xsi,
f_w,
)
f_w = np.where(
is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 1) & (xsi < 1),
f_w * (a - b * xsi) / (1 - xsi),
f_w,
)

w *= f_w
else:
w = None
return u, v, w


def XFreeslip(
vectorfield: VectorField,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
applyConversion: bool,
):
"""Free-slip boundary condition interpolation for velocity fields."""
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=1.0, b=0.0)


def XPartialslip(
vectorfield: VectorField,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
applyConversion: bool,
):
"""Partial-slip boundary condition interpolation for velocity fields."""
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=0.5, b=0.5)


def XNearest(
field: Field,
ti: int,
Expand Down
39 changes: 38 additions & 1 deletion tests/v4/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels._index_search import _search_time_index
from parcels.application_kernels.advection import AdvectionRK4_3D
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator
from parcels.application_kernels.interpolation import (
UXPiecewiseLinearNode,
XFreeslip,
XLinear,
XNearest,
XPartialslip,
ZeroInterpolator,
)
from parcels.field import Field, VectorField
from parcels.fieldset import FieldSet
from parcels.particle import Particle, Variable
Expand Down Expand Up @@ -80,6 +87,36 @@ def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
np.testing.assert_equal(value, expected)


@pytest.mark.parametrize(
"func, t, z, y, x, expected",
[
(XPartialslip, np.timedelta64(1, "s"), 0, 0, 0.0, [[1], [1]]),
(XFreeslip, np.timedelta64(1, "s"), 0, 0.5, 1.5, [[1], [0.5]]),
(XPartialslip, np.timedelta64(1, "s"), 0, 2.5, 1.5, [[0.75], [0.5]]),
(XFreeslip, np.timedelta64(1, "s"), 0, 2.5, 1.5, [[1], [0.5]]),
(XPartialslip, np.timedelta64(1, "s"), 0, 1.5, 0.5, [[0.5], [0.75]]),
(XFreeslip, np.timedelta64(1, "s"), 0, 1.5, 0.5, [[0.5], [1]]),
(
XFreeslip,
[np.timedelta64(1, "s"), np.timedelta64(0, "s")],
[0, 2],
[1.5, 1.5],
[2.5, 0.5],
[[0.5, 0.5], [1, 1]],
),
],
)
def test_spatial_slip_interpolation(field, func, t, z, y, x, expected):
field.data[:] = 1.0
field.data[:, :, 1:3, 1:3] = 0.0 # Set zero land value to test spatial slip
U = field
V = field
UV = VectorField("UV", U, V, vector_interp_method=func)

velocities = UV[t, z, y, x]
np.testing.assert_array_almost_equal(velocities, expected)


@pytest.mark.parametrize("mesh", ["spherical", "flat"])
def test_interpolation_mesh_type(mesh, npart=10):
ds = simple_UV_dataset(mesh=mesh)
Expand Down
Loading