Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _search_time_index(field: Field, time: datetime):
if not field.time_interval.is_all_time_in_interval(time):
_raise_time_extrapolation_error(time, field=None)

# TODO this could be sped up when data has only two timeslices (i.e. when data_full is not None)?
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
return np.atleast_1d(tau), np.atleast_1d(ti)
Expand Down
40 changes: 33 additions & 7 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import uxarray as ux
import xarray as xr
from dask import is_dask_collection

from parcels._core.utils.time import TimeInterval
from parcels._reprs import default_repr
Expand Down Expand Up @@ -132,8 +133,14 @@ def __init__(
data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid)

self.name = name
self.data = data
self.grid = grid
if is_dask_collection(data) and ("time" in data.dims):
self.data = None
self.data_full = data
else:
self.data = data
self.data_full = None
self._nexttime_to_load = None

try:
self.time_interval = _get_time_interval(data)
Expand Down Expand Up @@ -167,8 +174,8 @@ def __init__(
elif self.grid._mesh == "spherical":
self.units = unitconverters_map[self.name]

if self.data.shape[0] > 1:
if "time" not in self.data.coords:
if data.shape[0] > 1:
if "time" not in data.coords:
raise ValueError("Field data is missing a 'time' coordinate.")

@property
Expand All @@ -183,25 +190,29 @@ def units(self, value):

@property
def xdim(self):
if type(self.data) is xr.DataArray:
if hasattr(self.grid, "xdim"):
return self.grid.xdim
else:
raise NotImplementedError("xdim not implemented for unstructured grids")

@property
def ydim(self):
if type(self.data) is xr.DataArray:
if hasattr(self.grid, "ydim"):
return self.grid.ydim
else:
raise NotImplementedError("ydim not implemented for unstructured grids")

@property
def zdim(self):
if type(self.data) is xr.DataArray:
if hasattr(self.grid, "zdim"):
return self.grid.zdim
else:
if "nz1" in self.data.dims:
if "nz1" in self.data_full.dims:
return self.data_full.sizes["nz1"]
elif "nz1" in self.data.dims:
return self.data.sizes["nz1"]
elif "nz" in self.data_full.dims:
return self.data_full.sizes["nz"]
elif "nz" in self.data.dims:
return self.data.sizes["nz"]
else:
Expand All @@ -224,6 +235,21 @@ def _check_velocitysampling(self):
stacklevel=2,
)

def _load_timesteps(self, time):
"""Load the appropriate timesteps of a field."""
if self.data_full is not None:
ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0
if self.data is None:
self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
elif self.data_full.time.data[ti] == self.data.time.data[1]:
self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time")
elif self.data_full.time.data[ti] != self.data.time.data[0]:
self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
assert len(self.data.time) == 2, (
f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}."
)
self._nexttime_to_load = self.data_full.time.data[ti + 1]

def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
"""Interpolate field values in space and time.

Expand Down
12 changes: 12 additions & 0 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def time_interval(self):
return None
return functools.reduce(lambda x, y: x.intersection(y), time_intervals)

def _load_timesteps(self, time):
"""Load the appropriate timesteps of all fields in the fieldset."""
next_times = []
for fldname in self.fields:
field = self.fields[fldname]
if isinstance(field, Field):
field._load_timesteps(time)
if field._nexttime_to_load is not None:
next_times.append(field._nexttime_to_load)

return min(next_times) if next_times else None

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.

Expand Down
13 changes: 9 additions & 4 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,16 @@ def execute(

time = start_time
while sign_dt * (time - end_time) < 0:
# Load the appropriate timesteps of the fieldset
next_load_time = self.fieldset._load_timesteps(time)

possible_next_time = [end_time]
if next_load_time is not None:
possible_next_time.append(next_load_time)
if next_output is not None:
f = min if sign_dt > 0 else max
next_time = f(next_output, end_time)
else:
next_time = end_time
possible_next_time.append(next_output)
f = min if sign_dt > 0 else max
next_time = f(possible_next_time)

self._kernel.execute(self, endtime=next_time, dt=dt)

Expand Down
10 changes: 6 additions & 4 deletions tests/v4/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def test_horizontal_advection_in_3D_flow(npart=10):
"""Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s."""
ds = simple_UV_dataset(mesh="flat")
ds["U"].data[:] = 1.0
ds["U"].data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, interp_method=XLinear)
U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
V = Field("V", ds["V"], grid, interp_method=XLinear)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, UV])
Expand All @@ -128,12 +128,13 @@ def test_horizontal_advection_in_3D_flow(npart=10):
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
ds = simple_UV_dataset(mesh="flat")
ds["U"].data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
ds["W"] = ds["V"].copy() # Use V as W for testing
ds["W"].data[:] = -1.0 if direction == "up" else 1.0
grid = XGrid.from_dataset(ds)
U = Field("U", ds["U"], grid, interp_method=XLinear)
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
V = Field("V", ds["V"], grid, interp_method=XLinear)
W = Field("W", ds["V"], grid, interp_method=XLinear) # Use V as W for testing
W.data[:] = -1.0 if direction == "up" else 1.0
W = Field("W", ds["W"], grid, interp_method=XLinear)
UVW = VectorField("UVW", U, V, W)
UV = VectorField("UV", U, V)
fieldset = FieldSet([U, V, W, UVW, UV])
Expand Down Expand Up @@ -213,6 +214,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
fields = [U, V, VectorField("UV", U, V)]
if w:
W = Field("W", ds["W"], grid, interp_method=XLinear)
fields.append(W)
fields.append(VectorField("UVW", U, V, W))
fieldset = FieldSet(fields)

Expand Down
12 changes: 6 additions & 6 deletions tests/v4/test_uxarray_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField:
def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert (fieldset.U.data == ds_fesom_channel.U).all()
assert (fieldset.V.data == ds_fesom_channel.V).all()
assert (fieldset.U.data_full == ds_fesom_channel.U).all()
assert (fieldset.V.data_full == ds_fesom_channel.V).all()


def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])

# Check that the fieldset has the expected properties
assert (fieldset.U.data == ds_fesom_channel.U).all()
assert (fieldset.V.data == ds_fesom_channel.V).all()
assert (fieldset.U.data_full == ds_fesom_channel.U).all()
assert (fieldset.V.data_full == ds_fesom_channel.V).all()
pset = ParticleSet(fieldset, pclass=Particle)
assert pset.fieldset == fieldset


def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel):
fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
# Check that the fieldset has the expected properties
assert (fieldset.U.data == ds_fesom_channel.U).all()
assert (fieldset.V.data == ds_fesom_channel.V).all()
assert (fieldset.U.data_full == ds_fesom_channel.U).all()
assert (fieldset.V.data_full == ds_fesom_channel.V).all()

# Set the interpolation method for each field
fieldset.U.interp_method = UXPiecewiseConstantFace
Expand Down
Loading