Skip to content

Commit 9d500d0

Browse files
Merge pull request #2490 from erikvansebille/renaming_pyfunc_to_kernels
Renaming pyfunc to kernels
2 parents bc1f7c5 + 1c92b9a commit 9d500d0

File tree

12 files changed

+81
-142
lines changed

12 files changed

+81
-142
lines changed

docs/getting_started/explanation_concepts.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ dt = np.timedelta64(5, "m")
181181
runtime = np.timedelta64(1, "D")
182182

183183
# Run the simulation
184-
pset.execute(pyfunc=kernels, dt=dt, runtime=runtime)
184+
pset.execute(kernels=kernels, dt=dt, runtime=runtime)
185185
```
186186

187187
### Output

docs/user_guide/examples/tutorial_interaction.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
"]\n",
140140
"\n",
141141
"pset.execute(\n",
142-
" pyfunc=kernels,\n",
142+
" kernels=kernels,\n",
143143
" runtime=np.timedelta64(60, \"s\"),\n",
144144
" dt=np.timedelta64(1, \"s\"),\n",
145145
" output_file=output_file,\n",
@@ -331,7 +331,7 @@
331331
"]\n",
332332
"\n",
333333
"pset.execute(\n",
334-
" pyfunc=kernels,\n",
334+
" kernels=kernels,\n",
335335
" runtime=np.timedelta64(60, \"s\"),\n",
336336
" dt=np.timedelta64(1, \"s\"),\n",
337337
" output_file=output_file,\n",

docs/user_guide/examples_v3/tutorial_stommel_uxarray.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@
334334
" pset.execute(\n",
335335
" endtime=endtime,\n",
336336
" dt=timedelta(seconds=60),\n",
337-
" pyfunc=AdvectionEE,\n",
337+
" kernels=AdvectionEE,\n",
338338
" verbose_progress=False,\n",
339339
" )\n",
340340
" except FieldOutOfBoundError:\n",

docs/user_guide/v4-migration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Version 4 of Parcels is unreleased at the moment. The information in this migrat
1717
- The `InteractionKernel` class has been removed. Since normal Kernels now have access to _all_ particles, particle-particle interaction can be performed within normal Kernels.
1818
- Users need to explicitly use `convert_z_to_sigma_croco` in sampling kernels (such as the `AdvectionRK4_3D_CROCO` or `SampleOMegaCroco` kernels) when working with CROCO data, as the automatic conversion from depth to sigma grids under the hood has been removed.
1919
- We added a new AdvectionRK2 Kernel. The AdvectionRK4 kernel is still available, but RK2 is now the recommended default advection scheme as it is faster while the accuracy is comparable for most applications. See also the Choosing an integration method tutorial.
20+
- Functions shouldn't be converted to Kernels before adding to a pset.execute() call. Instead, simply pass the function(s) as a list to pset.execute().
2021

2122
## FieldSet
2223

src/parcels/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from parcels._core.fieldset import FieldSet
1313
from parcels._core.particleset import ParticleSet
14-
from parcels._core.kernel import Kernel
1514
from parcels._core.particlefile import ParticleFile
1615
from parcels._core.particle import (
1716
Variable,
@@ -45,7 +44,6 @@
4544
# Core classes
4645
"FieldSet",
4746
"ParticleSet",
48-
"Kernel",
4947
"ParticleFile",
5048
"Variable",
5149
"Particle",

src/parcels/_core/kernel.py

Lines changed: 31 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
if TYPE_CHECKING:
2828
from collections.abc import Callable
2929

30-
__all__ = ["Kernel"]
31-
3230

3331
ErrorsToThrow = {
3432
StatusCode.ErrorOutsideTimeInterval: _raise_outside_time_interval_error,
@@ -45,12 +43,12 @@ class Kernel:
4543
4644
Parameters
4745
----------
46+
kernels :
47+
list of Kernel functions
4848
fieldset : parcels.Fieldset
4949
FieldSet object providing the field information (possibly None)
5050
ptype :
5151
PType object for the kernel particle
52-
pyfunc :
53-
(aggregated) Kernel function
5452
5553
Notes
5654
-----
@@ -60,32 +58,35 @@ class Kernel:
6058

6159
def __init__(
6260
self,
63-
fieldset,
64-
ptype,
65-
pyfuncs: list[types.FunctionType],
61+
kernels: list[types.FunctionType],
62+
pset,
6663
):
67-
for f in pyfuncs:
64+
if not isinstance(kernels, list):
65+
raise ValueError(f"kernels must be a list. Got {kernels=!r}")
66+
67+
for f in kernels:
6868
if not isinstance(f, types.FunctionType):
69-
raise TypeError(f"Argument pyfunc should be a function or list of functions. Got {type(f)}")
69+
raise TypeError(f"Argument `kernels` should be a function or list of functions. Got {type(f)}")
7070
assert_same_function_signature(f, ref=AdvectionRK4, context="Kernel")
7171

72-
if len(pyfuncs) == 0:
73-
raise ValueError("List of `pyfuncs` should have at least one function.")
72+
if len(kernels) == 0:
73+
raise ValueError("List of `kernels` should have at least one function.")
7474

75-
self._fieldset = fieldset
76-
self._ptype = ptype
75+
self._fieldset = pset.fieldset
76+
self._ptype = pset._ptype
7777

78-
self._positionupdate_kernel_added = False
79-
80-
for f in pyfuncs:
78+
for f in kernels:
8179
self.check_fieldsets_in_kernels(f)
8280

83-
self._pyfuncs: list[Callable] = pyfuncs
81+
self._kernels: list[Callable] = kernels
82+
83+
if pset._requires_prepended_positionupdate_kernel:
84+
self.prepend_positionupdate_kernel()
8485

8586
@property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
8687
def funcname(self):
8788
ret = ""
88-
for f in self._pyfuncs:
89+
for f in self._kernels:
8990
ret += f.__name__
9091
return ret
9192

@@ -107,7 +108,7 @@ def remove_deleted(self, pset):
107108
if len(indices) > 0:
108109
pset.remove_indices(indices)
109110

110-
def add_positionupdate_kernel(self):
111+
def prepend_positionupdate_kernel(self):
111112
# Adding kernels that set and update the coordinate changes
112113
def PositionUpdate(particles, fieldset): # pragma: no cover
113114
particles.lon += particles.dlon
@@ -123,21 +124,21 @@ def PositionUpdate(particles, fieldset): # pragma: no cover
123124
# Update dt in case it's increased in RK45 kernel
124125
particles.dt = particles.next_dt
125126

126-
self._pyfuncs = (PositionUpdate + self)._pyfuncs
127+
self._kernels = [PositionUpdate] + self._kernels
127128

128-
def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into another method? assert_is_compatible()?
129+
def check_fieldsets_in_kernels(self, kernel): # TODO v4: this can go into another method? assert_is_compatible()?
129130
"""
130131
Checks the integrity of the fieldset with the kernels.
131132
132-
This function is to be called from the derived class when setting up the 'pyfunc'.
133+
This function is to be called from the derived class when setting up the 'kernel'.
133134
"""
134135
if self.fieldset is not None:
135-
if pyfunc is AdvectionAnalytical:
136+
if kernel is AdvectionAnalytical:
136137
if self._fieldset.U.interp_method != "cgrid_velocity":
137138
raise NotImplementedError("Analytical Advection only works with C-grids")
138139
if self._fieldset.U.grid._gtype not in [GridType.CurvilinearZGrid, GridType.RectilinearZGrid]:
139140
raise NotImplementedError("Analytical Advection only works with Z-grids in the vertical")
140-
elif pyfunc is AdvectionRK45:
141+
elif kernel is AdvectionRK45:
141142
if "next_dt" not in [v.name for v in self.ptype.variables]:
142143
raise ValueError('ParticleClass requires a "next_dt" for AdvectionRK45 Kernel.')
143144
if not hasattr(self.fieldset, "RK45_tol"):
@@ -174,48 +175,11 @@ def merge(self, kernel):
174175
assert self.ptype == kernel.ptype, "Cannot merge kernels with different particle types"
175176

176177
return type(self)(
178+
self._kernels + kernel._kernels,
177179
self.fieldset,
178180
self.ptype,
179-
pyfuncs=self._pyfuncs + kernel._pyfuncs,
180181
)
181182

182-
def __add__(self, kernel):
183-
if isinstance(kernel, types.FunctionType):
184-
kernel = type(self)(self.fieldset, self.ptype, pyfuncs=[kernel])
185-
return self.merge(kernel)
186-
187-
def __radd__(self, kernel):
188-
if isinstance(kernel, types.FunctionType):
189-
kernel = type(self)(self.fieldset, self.ptype, pyfuncs=[kernel])
190-
return kernel.merge(self)
191-
192-
@classmethod
193-
def from_list(cls, fieldset, ptype, pyfunc_list):
194-
"""Create a combined kernel from a list of functions.
195-
196-
Takes a list of functions, converts them to kernels, and joins them
197-
together.
198-
199-
Parameters
200-
----------
201-
fieldset : parcels.Fieldset
202-
FieldSet object providing the field information (possibly None)
203-
ptype :
204-
PType object for the kernel particle
205-
pyfunc_list : list of functions
206-
List of functions to be combined into a single kernel.
207-
*args :
208-
Additional arguments passed to first kernel during construction.
209-
**kwargs :
210-
Additional keyword arguments passed to first kernel during construction.
211-
"""
212-
if not isinstance(pyfunc_list, list):
213-
raise TypeError(f"Argument `pyfunc_list` should be a list of functions. Got {type(pyfunc_list)}")
214-
if not all([isinstance(f, types.FunctionType) for f in pyfunc_list]):
215-
raise ValueError("Argument `pyfunc_list` should be a list of functions.")
216-
217-
return cls(fieldset, ptype, pyfunc_list)
218-
219183
def execute(self, pset, endtime, dt):
220184
"""Execute this Kernel over a ParticleSet for several timesteps.
221185
@@ -248,7 +212,7 @@ def execute(self, pset, endtime, dt):
248212
pset.dt = np.minimum(np.maximum(pset.dt, -time_to_endtime), 0)
249213

250214
# run kernels for all particles that need to be evaluated
251-
for f in self._pyfuncs:
215+
for f in self._kernels:
252216
f(pset[evaluate_particles], self._fieldset)
253217

254218
# check for particles that have to be repeated
@@ -280,9 +244,9 @@ def execute(self, pset, endtime, dt):
280244
else:
281245
error_func(pset[inds].z, pset[inds].lat, pset[inds].lon)
282246

283-
# Only add PositionUpdate kernel at the end of the first execute call to avoid adding dt to time too early
284-
if not self._positionupdate_kernel_added:
285-
self.add_positionupdate_kernel()
286-
self._positionupdate_kernel_added = True
247+
# Only prepend PositionUpdate kernel at the end of the first execute call to avoid adding dt to time too early
248+
if not pset._requires_prepended_positionupdate_kernel:
249+
self.prepend_positionupdate_kernel()
250+
pset._requires_prepended_positionupdate_kernel = True
287251

288252
return pset

src/parcels/_core/particleset.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import sys
3+
import types
34
import warnings
45
from collections.abc import Iterable
56
from typing import Literal
@@ -134,6 +135,7 @@ def __init__(
134135
self._data[kwvar][:] = kwval
135136

136137
self._kernel = None
138+
self._requires_prepended_positionupdate_kernel = False
137139

138140
def __del__(self):
139141
if self._data is not None and isinstance(self._data, xr.Dataset):
@@ -290,29 +292,6 @@ def from_particlefile(cls, fieldset, pclass, filename, restart=True, restarttime
290292
"ParticleSet.from_particlefile is not yet implemented in v4."
291293
) # TODO implement this when ParticleFile is implemented in v4
292294

293-
def Kernel(self, pyfunc):
294-
"""Wrapper method to convert a `pyfunc` into a :class:`parcels.kernel.Kernel` object.
295-
296-
Conversion is based on `fieldset` and `ptype` of the ParticleSet.
297-
298-
Parameters
299-
----------
300-
pyfunc : function or list of functions
301-
Python function to convert into kernel. If a list of functions is provided,
302-
the functions will be converted to kernels and combined into a single kernel.
303-
"""
304-
if isinstance(pyfunc, list):
305-
return Kernel.from_list(
306-
self.fieldset,
307-
self._ptype,
308-
pyfunc,
309-
)
310-
return Kernel(
311-
self.fieldset,
312-
self._ptype,
313-
pyfuncs=[pyfunc],
314-
)
315-
316295
def data_indices(self, variable_name, compare_values, invert=False):
317296
"""Get the indices of all particles where the value of `variable_name` equals (one of) `compare_values`.
318297
@@ -376,7 +355,7 @@ def set_variable_write_status(self, var, write_status):
376355

377356
def execute(
378357
self,
379-
pyfunc,
358+
kernels,
380359
dt: datetime.timedelta | np.timedelta64 | float,
381360
endtime: np.timedelta64 | np.datetime64 | None = None,
382361
runtime: datetime.timedelta | np.timedelta64 | float | None = None,
@@ -390,10 +369,9 @@ def execute(
390369
391370
Parameters
392371
----------
393-
pyfunc :
394-
Kernel function to execute. This can be the name of a
372+
kernels :
373+
List of Kernel functions to execute. This can be the name of a
395374
defined Python function or a :class:`parcels.kernel.Kernel` object.
396-
Kernels can be concatenated using the + operator.
397375
dt (np.timedelta64 or float):
398376
Timestep interval (as a np.timedelta64 object of float in seconds) to be passed to the kernel.
399377
Use a negative value for a backward-in-time simulation.
@@ -417,10 +395,9 @@ def execute(
417395
if len(self) == 0:
418396
return
419397

420-
if not isinstance(pyfunc, Kernel):
421-
pyfunc = self.Kernel(pyfunc)
422-
423-
self._kernel = pyfunc
398+
if isinstance(kernels, types.FunctionType):
399+
kernels = [kernels]
400+
self._kernel = Kernel(kernels, self)
424401

425402
if output_file is not None:
426403
output_file.set_metadata(self.fieldset.gridset[0]._mesh)

tests-v3/test_kernel_language.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
def expr_kernel(name, pset, expr):
1919
pycode = (f"def {name}(particle, fieldset, time):\n"
2020
f" particle.p = {expr}") # fmt: skip
21-
return Kernel(pset.fieldset, pset.particledata.ptype, pyfunc=None, funccode=pycode, funcname=name)
21+
return Kernel(kernels=None, fieldset=pset.fieldset, ptype=pset._ptype, funccode=pycode, funcname=name)
2222

2323

2424
@pytest.fixture

tests/test_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_fieldKh_Brownian(mesh):
3434

3535
np.random.seed(1234)
3636
pset = ParticleSet(fieldset=fieldset, lon=np.zeros(npart), lat=np.zeros(npart))
37-
pset.execute(pset.Kernel(DiffusionUniformKh), runtime=runtime, dt=np.timedelta64(1, "h"))
37+
pset.execute(DiffusionUniformKh, runtime=runtime, dt=np.timedelta64(1, "h"))
3838

3939
expected_std_lon = np.sqrt(2 * kh_zonal * mesh_conversion**2 * timedelta_to_float(runtime))
4040
expected_std_lat = np.sqrt(2 * kh_meridional * mesh_conversion**2 * timedelta_to_float(runtime))
@@ -70,7 +70,7 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel):
7070

7171
np.random.seed(1636)
7272
pset = ParticleSet(fieldset=fieldset, lon=np.zeros(npart), lat=np.zeros(npart))
73-
pset.execute(pset.Kernel(kernel), runtime=np.timedelta64(3, "h"), dt=np.timedelta64(1, "h"))
73+
pset.execute(kernel, runtime=np.timedelta64(3, "h"), dt=np.timedelta64(1, "h"))
7474

7575
tol = 2000 * mesh_conversion # effectively 2000 m errors (because of low numbers of particles)
7676
assert np.allclose(np.mean(pset.lon), 0, atol=tol)

0 commit comments

Comments
 (0)