Skip to content

Commit a5fa459

Browse files
Merge branch 'v4-dev' into structured-grid-interpolators
2 parents ae58c0e + 488e3fb commit a5fa459

File tree

13 files changed

+808
-1231
lines changed

13 files changed

+808
-1231
lines changed

docs/v4/api.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,11 @@ Here, important things to note are:
4343
- The `grid` object, in the case of unstructured grids, will be the `Grid` class from UXarray. For structured `Grid`s, it will be an object similar to that of `xgcm.Grid` (note that it will be very different from the v3 `Grid` object hierarchy).
4444

4545
- The `Field.eval` method takes as input the t,z,y,x spatio-temporal position as required arguments; the `particle` is optional and defaults to `None` and the `applyConversion` argument is optional and defaults to `True`. Initially, we will calculate the element index for a particle. As a future optimization, we could pass via the `particle` object a "cached" index value that could be used to bypass an index search. This will effectively provide `(ti,zi,yi,xi)` on a structured grid and `(ti,zi,fi)` on an unstructured grid (where `fi` is the lateral face id); within `eval` these indices will be `ravel`'ed to a single index that can be `unravel`'ed in the `interpolate` method. The `ravel`'ed index is referred to as `rid` in the `Field.Interpolator.interpolate` method. In the `interpolate` method, we envision that a user will benefit from knowing the nearest cell/index from the `ravel`'ed index (which can be `unravel`'ed) in addition the exact coordinate that we want to interpolate onto. This can permit calculation of interpolation weights using points in the neighborhood of `(t,z,y,x)`.
46+
47+
## Changes in API
48+
49+
Below a list of changes in the API that are relevant to users:
50+
51+
- `starttime`, `endtime` and `dt` in `ParticleSet.execute()` are now `numpy.timedelta64` or `numpy.datetime64` objects. This allows for more precise time handling and is consistent with the `numpy` time handling.
52+
53+
- `pid_orig` in `ParticleSet` is removed. Instead, `trajectory_ids` is used to provide a list of "trajectory" values (integers) for the particle IDs.

parcels/application_kernels/advection.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,73 +16,76 @@
1616

1717
def AdvectionRK4(particle, fieldset, time): # pragma: no cover
1818
"""Advection of particles using fourth-order Runge-Kutta integration."""
19+
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
1920
(u1, v1) = fieldset.UV[particle]
20-
lon1, lat1 = (particle.lon + u1 * 0.5 * particle.dt, particle.lat + v1 * 0.5 * particle.dt)
21+
lon1, lat1 = (particle.lon + u1 * 0.5 * dt, particle.lat + v1 * 0.5 * dt)
2122
(u2, v2) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat1, lon1, particle]
22-
lon2, lat2 = (particle.lon + u2 * 0.5 * particle.dt, particle.lat + v2 * 0.5 * particle.dt)
23+
lon2, lat2 = (particle.lon + u2 * 0.5 * dt, particle.lat + v2 * 0.5 * dt)
2324
(u3, v3) = fieldset.UV[time + 0.5 * particle.dt, particle.depth, lat2, lon2, particle]
24-
lon3, lat3 = (particle.lon + u3 * particle.dt, particle.lat + v3 * particle.dt)
25+
lon3, lat3 = (particle.lon + u3 * dt, particle.lat + v3 * dt)
2526
(u4, v4) = fieldset.UV[time + particle.dt, particle.depth, lat3, lon3, particle]
26-
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6.0 * particle.dt # noqa
27-
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6.0 * particle.dt # noqa
27+
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6.0 * dt # noqa
28+
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6.0 * dt # noqa
2829

2930

3031
def AdvectionRK4_3D(particle, fieldset, time): # pragma: no cover
3132
"""Advection of particles using fourth-order Runge-Kutta integration including vertical velocity."""
33+
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
3234
(u1, v1, w1) = fieldset.UVW[particle]
33-
lon1 = particle.lon + u1 * 0.5 * particle.dt
34-
lat1 = particle.lat + v1 * 0.5 * particle.dt
35-
dep1 = particle.depth + w1 * 0.5 * particle.dt
35+
lon1 = particle.lon + u1 * 0.5 * dt
36+
lat1 = particle.lat + v1 * 0.5 * dt
37+
dep1 = particle.depth + w1 * 0.5 * dt
3638
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
37-
lon2 = particle.lon + u2 * 0.5 * particle.dt
38-
lat2 = particle.lat + v2 * 0.5 * particle.dt
39-
dep2 = particle.depth + w2 * 0.5 * particle.dt
39+
lon2 = particle.lon + u2 * 0.5 * dt
40+
lat2 = particle.lat + v2 * 0.5 * dt
41+
dep2 = particle.depth + w2 * 0.5 * dt
4042
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
41-
lon3 = particle.lon + u3 * particle.dt
42-
lat3 = particle.lat + v3 * particle.dt
43-
dep3 = particle.depth + w3 * particle.dt
43+
lon3 = particle.lon + u3 * dt
44+
lat3 = particle.lat + v3 * dt
45+
dep3 = particle.depth + w3 * dt
4446
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
45-
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * particle.dt # noqa
46-
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * particle.dt # noqa
47-
particle_ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * particle.dt # noqa
47+
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt # noqa
48+
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt # noqa
49+
particle_ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * dt # noqa
4850

4951

5052
def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
5153
"""Advection of particles using fourth-order Runge-Kutta integration including vertical velocity.
5254
This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.
5355
"""
56+
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
5457
sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]
5558

5659
(u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
5760
w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
58-
lon1 = particle.lon + u1 * 0.5 * particle.dt
59-
lat1 = particle.lat + v1 * 0.5 * particle.dt
60-
sig_dep1 = sig_dep + w1 * 0.5 * particle.dt
61+
lon1 = particle.lon + u1 * 0.5 * dt
62+
lat1 = particle.lat + v1 * 0.5 * dt
63+
sig_dep1 = sig_dep + w1 * 0.5 * dt
6164
dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]
6265

6366
(u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
6467
w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
65-
lon2 = particle.lon + u2 * 0.5 * particle.dt
66-
lat2 = particle.lat + v2 * 0.5 * particle.dt
67-
sig_dep2 = sig_dep + w2 * 0.5 * particle.dt
68+
lon2 = particle.lon + u2 * 0.5 * dt
69+
lat2 = particle.lat + v2 * 0.5 * dt
70+
sig_dep2 = sig_dep + w2 * 0.5 * dt
6871
dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]
6972

7073
(u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
7174
w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
72-
lon3 = particle.lon + u3 * particle.dt
73-
lat3 = particle.lat + v3 * particle.dt
74-
sig_dep3 = sig_dep + w3 * particle.dt
75+
lon3 = particle.lon + u3 * dt
76+
lat3 = particle.lat + v3 * dt
77+
sig_dep3 = sig_dep + w3 * dt
7578
dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]
7679

7780
(u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
7881
w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
79-
lon4 = particle.lon + u4 * particle.dt
80-
lat4 = particle.lat + v4 * particle.dt
81-
sig_dep4 = sig_dep + w4 * particle.dt
82+
lon4 = particle.lon + u4 * dt
83+
lat4 = particle.lat + v4 * dt
84+
sig_dep4 = sig_dep + w4 * dt
8285
dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]
8386

84-
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * particle.dt # noqa
85-
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * particle.dt # noqa
87+
particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * dt # noqa
88+
particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * dt # noqa
8689
particle_ddepth += ( # noqa
8790
(dep1 - particle.depth) * 2
8891
+ 2 * (dep2 - particle.depth) * 2
@@ -94,9 +97,10 @@ def AdvectionRK4_3D_CROCO(particle, fieldset, time): # pragma: no cover
9497

9598
def AdvectionEE(particle, fieldset, time): # pragma: no cover
9699
"""Advection of particles using Explicit Euler (aka Euler Forward) integration."""
100+
dt = particle.dt / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
97101
(u1, v1) = fieldset.UV[particle]
98-
particle_dlon += u1 * particle.dt # noqa
99-
particle_dlat += v1 * particle.dt # noqa
102+
particle_dlon += u1 * dt # noqa
103+
particle_dlat += v1 * dt # noqa
100104

101105

102106
def AdvectionRK45(particle, fieldset, time): # pragma: no cover
@@ -109,7 +113,7 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
109113
Time-step dt is halved if error is larger than fieldset.RK45_tol,
110114
and doubled if error is smaller than 1/10th of tolerance.
111115
"""
112-
particle.dt = min(particle.next_dt, fieldset.RK45_max_dt)
116+
dt = min(particle.next_dt, fieldset.RK45_max_dt) / np.timedelta64(1, "s") # noqa TODO improve API for converting dt to seconds
113117
c = [1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]
114118
A = [
115119
[1.0 / 4.0, 0.0, 0.0, 0.0, 0.0],
@@ -122,39 +126,39 @@ def AdvectionRK45(particle, fieldset, time): # pragma: no cover
122126
b5 = [16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0]
123127

124128
(u1, v1) = fieldset.UV[particle]
125-
lon1, lat1 = (particle.lon + u1 * A[0][0] * particle.dt, particle.lat + v1 * A[0][0] * particle.dt)
129+
lon1, lat1 = (particle.lon + u1 * A[0][0] * dt, particle.lat + v1 * A[0][0] * dt)
126130
(u2, v2) = fieldset.UV[time + c[0] * particle.dt, particle.depth, lat1, lon1, particle]
127131
lon2, lat2 = (
128-
particle.lon + (u1 * A[1][0] + u2 * A[1][1]) * particle.dt,
129-
particle.lat + (v1 * A[1][0] + v2 * A[1][1]) * particle.dt,
132+
particle.lon + (u1 * A[1][0] + u2 * A[1][1]) * dt,
133+
particle.lat + (v1 * A[1][0] + v2 * A[1][1]) * dt,
130134
)
131135
(u3, v3) = fieldset.UV[time + c[1] * particle.dt, particle.depth, lat2, lon2, particle]
132136
lon3, lat3 = (
133-
particle.lon + (u1 * A[2][0] + u2 * A[2][1] + u3 * A[2][2]) * particle.dt,
134-
particle.lat + (v1 * A[2][0] + v2 * A[2][1] + v3 * A[2][2]) * particle.dt,
137+
particle.lon + (u1 * A[2][0] + u2 * A[2][1] + u3 * A[2][2]) * dt,
138+
particle.lat + (v1 * A[2][0] + v2 * A[2][1] + v3 * A[2][2]) * dt,
135139
)
136140
(u4, v4) = fieldset.UV[time + c[2] * particle.dt, particle.depth, lat3, lon3, particle]
137141
lon4, lat4 = (
138-
particle.lon + (u1 * A[3][0] + u2 * A[3][1] + u3 * A[3][2] + u4 * A[3][3]) * particle.dt,
139-
particle.lat + (v1 * A[3][0] + v2 * A[3][1] + v3 * A[3][2] + v4 * A[3][3]) * particle.dt,
142+
particle.lon + (u1 * A[3][0] + u2 * A[3][1] + u3 * A[3][2] + u4 * A[3][3]) * dt,
143+
particle.lat + (v1 * A[3][0] + v2 * A[3][1] + v3 * A[3][2] + v4 * A[3][3]) * dt,
140144
)
141145
(u5, v5) = fieldset.UV[time + c[3] * particle.dt, particle.depth, lat4, lon4, particle]
142146
lon5, lat5 = (
143-
particle.lon + (u1 * A[4][0] + u2 * A[4][1] + u3 * A[4][2] + u4 * A[4][3] + u5 * A[4][4]) * particle.dt,
144-
particle.lat + (v1 * A[4][0] + v2 * A[4][1] + v3 * A[4][2] + v4 * A[4][3] + v5 * A[4][4]) * particle.dt,
147+
particle.lon + (u1 * A[4][0] + u2 * A[4][1] + u3 * A[4][2] + u4 * A[4][3] + u5 * A[4][4]) * dt,
148+
particle.lat + (v1 * A[4][0] + v2 * A[4][1] + v3 * A[4][2] + v4 * A[4][3] + v5 * A[4][4]) * dt,
145149
)
146150
(u6, v6) = fieldset.UV[time + c[4] * particle.dt, particle.depth, lat5, lon5, particle]
147151

148-
lon_4th = (u1 * b4[0] + u2 * b4[1] + u3 * b4[2] + u4 * b4[3] + u5 * b4[4]) * particle.dt
149-
lat_4th = (v1 * b4[0] + v2 * b4[1] + v3 * b4[2] + v4 * b4[3] + v5 * b4[4]) * particle.dt
150-
lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * particle.dt
151-
lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * particle.dt
152+
lon_4th = (u1 * b4[0] + u2 * b4[1] + u3 * b4[2] + u4 * b4[3] + u5 * b4[4]) * dt
153+
lat_4th = (v1 * b4[0] + v2 * b4[1] + v3 * b4[2] + v4 * b4[3] + v5 * b4[4]) * dt
154+
lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * dt
155+
lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * dt
152156

153157
kappa = math.sqrt(math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2))
154-
if (kappa <= fieldset.RK45_tol) or (math.fabs(particle.dt) < math.fabs(fieldset.RK45_min_dt)):
158+
if (kappa <= fieldset.RK45_tol) or (math.fabs(dt) < math.fabs(fieldset.RK45_min_dt)):
155159
particle_dlon += lon_4th # noqa
156160
particle_dlat += lat_4th # noqa
157-
if (kappa <= fieldset.RK45_tol) / 10 and (math.fabs(particle.dt * 2) <= math.fabs(fieldset.RK45_max_dt)):
161+
if (kappa <= fieldset.RK45_tol) / 10 and (math.fabs(dt * 2) <= math.fabs(fieldset.RK45_max_dt)):
158162
particle.next_dt *= 2
159163
else:
160164
particle.next_dt /= 2
@@ -174,13 +178,14 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
174178

175179
tol = 1e-10
176180
I_s = 10 # number of intermediate time steps
177-
direction = 1.0 if particle.dt > 0 else -1.0
181+
dt = particle.dt / np.timedelta64(1, "s") # TODO improve API for converting dt to seconds
182+
direction = 1.0 if dt > 0 else -1.0
178183
withW = True if "W" in [f.name for f in fieldset.fields.values()] else False
179184
withTime = True if len(fieldset.U.grid.time) > 1 else False
180185
tau, zeta, eta, xsi, ti, zi, yi, xi = fieldset.U._search_indices(
181186
time, particle.depth, particle.lat, particle.lon, particle=particle
182187
)
183-
ds_t = particle.dt
188+
ds_t = dt
184189
if withTime:
185190
time_i = np.linspace(0, fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti], I_s)
186191
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])
@@ -329,6 +334,6 @@ def compute_rs(r, B, delta, s_min):
329334
particle_ddepth += (1.0 - rs_z) * pz[0] + rs_z * pz[1] - particle.depth # noqa
330335

331336
if particle.dt > 0:
332-
particle.dt = max(direction * s_min * (dxdy * dz), 1e-7)
337+
particle.dt = max(direction * s_min * (dxdy * dz), 1e-7).astype("timedelta64[s]")
333338
else:
334-
particle.dt = min(direction * s_min * (dxdy * dz), -1e-7)
339+
particle.dt = min(direction * s_min * (dxdy * dz), -1e-7).astype("timedelta64[s]")

parcels/field.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
VectorType,
1717
assert_valid_mesh,
1818
)
19+
from parcels.particle import Particle
1920
from parcels.tools.converters import (
2021
UnitConverter,
2122
unitconverters_map,
@@ -35,17 +36,10 @@
3536
__all__ = ["Field", "VectorField"]
3637

3738

38-
def _isParticle(key):
39-
if hasattr(key, "obs_written"):
40-
return True
41-
else:
42-
return False
43-
44-
4539
def _deal_with_errors(error, key, vector_type: VectorType):
46-
if _isParticle(key):
40+
if isinstance(key, Particle):
4741
key.state = AllParcelsErrorCodes[type(error)]
48-
elif _isParticle(key[-1]):
42+
elif isinstance(key[-1], Particle):
4943
key[-1].state = AllParcelsErrorCodes[type(error)]
5044
else:
5145
raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.")
@@ -283,7 +277,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
283277
def __getitem__(self, key):
284278
self._check_velocitysampling()
285279
try:
286-
if _isParticle(key):
280+
if isinstance(key, Particle):
287281
return self.eval(key.time, key.depth, key.lat, key.lon, key)
288282
else:
289283
return self.eval(*key)
@@ -379,7 +373,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
379373

380374
def __getitem__(self, key):
381375
try:
382-
if _isParticle(key):
376+
if isinstance(key, Particle):
383377
return self.eval(key.time, key.depth, key.lat, key.lon, key)
384378
else:
385379
return self.eval(*key)

parcels/kernel.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ def fieldset(self):
7070

7171
def remove_deleted(self, pset):
7272
"""Utility to remove all particles that signalled deletion."""
73-
bool_indices = pset.particledata.state == StatusCode.Delete
73+
bool_indices = pset._data["state"] == StatusCode.Delete
7474
indices = np.where(bool_indices)[0]
75-
if len(indices) > 0 and self.fieldset.particlefile is not None:
76-
self.fieldset.particlefile.write(pset, None, indices=indices)
75+
# TODO v4: need to implement ParticleFile writing of deleted particles
76+
# if len(indices) > 0 and self.fieldset.particlefile is not None:
77+
# self.fieldset.particlefile.write(pset, None, indices=indices)
7778
pset.remove_indices(indices)
7879

7980

@@ -183,6 +184,8 @@ def fieldset(self):
183184
def add_positionupdate_kernels(self):
184185
# Adding kernels that set and update the coordinate changes
185186
def Setcoords(particle, fieldset, time): # pragma: no cover
187+
import numpy as np # noqa
188+
186189
particle_dlon = 0 # noqa
187190
particle_dlat = 0 # noqa
188191
particle_ddepth = 0 # noqa
@@ -303,9 +306,9 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
303306

304307
def execute(self, pset, endtime, dt):
305308
"""Execute this Kernel over a ParticleSet for several timesteps."""
306-
pset.particledata.state[:] = StatusCode.Evaluate
309+
pset._data["state"][:] = StatusCode.Evaluate
307310

308-
if abs(dt) < 1e-6:
311+
if abs(dt) < np.timedelta64(1000, "ns"): # TODO still needed?
309312
warnings.warn(
310313
"'dt' is too small, causing numerical accuracy limit problems. Please chose a higher 'dt' and rather scale the 'time' axis of the field accordingly. (related issue #762)",
311314
RuntimeWarning,
@@ -328,9 +331,8 @@ def execute(self, pset, endtime, dt):
328331
n_error = pset._num_error_particles
329332

330333
while n_error > 0:
331-
error_pset = pset._error_particles
332-
# Check for StatusCodes
333-
for p in error_pset:
334+
for i in pset._error_particles:
335+
p = pset[i]
334336
if p.state == StatusCode.StopExecution:
335337
return
336338
if p.state == StatusCode.StopAllExecution:
@@ -379,21 +381,23 @@ def evaluate_particle(self, p, endtime):
379381
while p.state in [StatusCode.Evaluate, StatusCode.Repeat]:
380382
pre_dt = p.dt
381383

382-
sign_dt = np.sign(p.dt)
383-
if sign_dt * p.time_nextloop >= sign_dt * endtime:
384+
sign_dt = np.sign(p.dt).astype(int)
385+
if sign_dt * (endtime - p.time_nextloop) <= np.timedelta64(0, "ns"):
384386
return p
385387

386-
try: # Use next_dt from AdvectionRK45 if it is set
387-
if abs(endtime - p.time_nextloop) < abs(p.next_dt) - 1e-6:
388-
p.next_dt = abs(endtime - p.time_nextloop) * sign_dt
389-
except KeyError:
390-
if abs(endtime - p.time_nextloop) < abs(p.dt) - 1e-6:
391-
p.dt = abs(endtime - p.time_nextloop) * sign_dt
388+
# TODO implement below later again
389+
# try: # Use next_dt from AdvectionRK45 if it is set
390+
# if abs(endtime - p.time_nextloop) < abs(p.next_dt) - 1e-6:
391+
# p.next_dt = abs(endtime - p.time_nextloop) * sign_dt
392+
# except AttributeError:
393+
if abs(endtime - p.time_nextloop) <= abs(p.dt):
394+
p.dt = abs(endtime - p.time_nextloop) * sign_dt
392395
res = self._pyfunc(p, self._fieldset, p.time_nextloop)
393396

394397
if res is None:
395-
if sign_dt * p.time < sign_dt * endtime and p.state == StatusCode.Success:
396-
p.state = StatusCode.Evaluate
398+
if p.state == StatusCode.Success:
399+
if sign_dt * (p.time - endtime) > np.timedelta64(0, "ns"):
400+
p.state = StatusCode.Evaluate
397401
else:
398402
p.state = res
399403

0 commit comments

Comments
 (0)