2727if TYPE_CHECKING :
2828 from collections .abc import Callable
2929
30- __all__ = ["Kernel" ]
31-
3230
3331ErrorsToThrow = {
3432 StatusCode .ErrorOutsideTimeInterval : _raise_outside_time_interval_error ,
@@ -45,12 +43,12 @@ class Kernel:
4543
4644 Parameters
4745 ----------
46+ kernels :
47+ list of Kernel functions
4848 fieldset : parcels.Fieldset
4949 FieldSet object providing the field information (possibly None)
5050 ptype :
5151 PType object for the kernel particle
52- pyfunc :
53- (aggregated) Kernel function
5452
5553 Notes
5654 -----
@@ -60,32 +58,35 @@ class Kernel:
6058
6159 def __init__ (
6260 self ,
63- fieldset ,
64- ptype ,
65- pyfuncs : list [types .FunctionType ],
61+ kernels : list [types .FunctionType ],
62+ pset ,
6663 ):
67- for f in pyfuncs :
64+ if not isinstance (kernels , list ):
65+ raise ValueError (f"kernels must be a list. Got { kernels = !r} " )
66+
67+ for f in kernels :
6868 if not isinstance (f , types .FunctionType ):
69- raise TypeError (f"Argument pyfunc should be a function or list of functions. Got { type (f )} " )
69+ raise TypeError (f"Argument `kernels` should be a function or list of functions. Got { type (f )} " )
7070 assert_same_function_signature (f , ref = AdvectionRK4 , context = "Kernel" )
7171
72- if len (pyfuncs ) == 0 :
73- raise ValueError ("List of `pyfuncs ` should have at least one function." )
72+ if len (kernels ) == 0 :
73+ raise ValueError ("List of `kernels ` should have at least one function." )
7474
75- self ._fieldset = fieldset
76- self ._ptype = ptype
75+ self ._fieldset = pset . fieldset
76+ self ._ptype = pset . _ptype
7777
78- self ._positionupdate_kernel_added = False
79-
80- for f in pyfuncs :
78+ for f in kernels :
8179 self .check_fieldsets_in_kernels (f )
8280
83- self ._pyfuncs : list [Callable ] = pyfuncs
81+ self ._kernels : list [Callable ] = kernels
82+
83+ if pset ._requires_prepended_positionupdate_kernel :
84+ self .prepend_positionupdate_kernel ()
8485
8586 @property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
8687 def funcname (self ):
8788 ret = ""
88- for f in self ._pyfuncs :
89+ for f in self ._kernels :
8990 ret += f .__name__
9091 return ret
9192
@@ -107,7 +108,7 @@ def remove_deleted(self, pset):
107108 if len (indices ) > 0 :
108109 pset .remove_indices (indices )
109110
110- def add_positionupdate_kernel (self ):
111+ def prepend_positionupdate_kernel (self ):
111112 # Adding kernels that set and update the coordinate changes
112113 def PositionUpdate (particles , fieldset ): # pragma: no cover
113114 particles .lon += particles .dlon
@@ -123,21 +124,21 @@ def PositionUpdate(particles, fieldset): # pragma: no cover
123124 # Update dt in case it's increased in RK45 kernel
124125 particles .dt = particles .next_dt
125126
126- self ._pyfuncs = ( PositionUpdate + self ). _pyfuncs
127+ self ._kernels = [ PositionUpdate ] + self . _kernels
127128
128- def check_fieldsets_in_kernels (self , pyfunc ): # TODO v4: this can go into another method? assert_is_compatible()?
129+ def check_fieldsets_in_kernels (self , kernel ): # TODO v4: this can go into another method? assert_is_compatible()?
129130 """
130131 Checks the integrity of the fieldset with the kernels.
131132
132- This function is to be called from the derived class when setting up the 'pyfunc '.
133+ This function is to be called from the derived class when setting up the 'kernel '.
133134 """
134135 if self .fieldset is not None :
135- if pyfunc is AdvectionAnalytical :
136+ if kernel is AdvectionAnalytical :
136137 if self ._fieldset .U .interp_method != "cgrid_velocity" :
137138 raise NotImplementedError ("Analytical Advection only works with C-grids" )
138139 if self ._fieldset .U .grid ._gtype not in [GridType .CurvilinearZGrid , GridType .RectilinearZGrid ]:
139140 raise NotImplementedError ("Analytical Advection only works with Z-grids in the vertical" )
140- elif pyfunc is AdvectionRK45 :
141+ elif kernel is AdvectionRK45 :
141142 if "next_dt" not in [v .name for v in self .ptype .variables ]:
142143 raise ValueError ('ParticleClass requires a "next_dt" for AdvectionRK45 Kernel.' )
143144 if not hasattr (self .fieldset , "RK45_tol" ):
@@ -174,48 +175,11 @@ def merge(self, kernel):
174175 assert self .ptype == kernel .ptype , "Cannot merge kernels with different particle types"
175176
176177 return type (self )(
178+ self ._kernels + kernel ._kernels ,
177179 self .fieldset ,
178180 self .ptype ,
179- pyfuncs = self ._pyfuncs + kernel ._pyfuncs ,
180181 )
181182
182- def __add__ (self , kernel ):
183- if isinstance (kernel , types .FunctionType ):
184- kernel = type (self )(self .fieldset , self .ptype , pyfuncs = [kernel ])
185- return self .merge (kernel )
186-
187- def __radd__ (self , kernel ):
188- if isinstance (kernel , types .FunctionType ):
189- kernel = type (self )(self .fieldset , self .ptype , pyfuncs = [kernel ])
190- return kernel .merge (self )
191-
192- @classmethod
193- def from_list (cls , fieldset , ptype , pyfunc_list ):
194- """Create a combined kernel from a list of functions.
195-
196- Takes a list of functions, converts them to kernels, and joins them
197- together.
198-
199- Parameters
200- ----------
201- fieldset : parcels.Fieldset
202- FieldSet object providing the field information (possibly None)
203- ptype :
204- PType object for the kernel particle
205- pyfunc_list : list of functions
206- List of functions to be combined into a single kernel.
207- *args :
208- Additional arguments passed to first kernel during construction.
209- **kwargs :
210- Additional keyword arguments passed to first kernel during construction.
211- """
212- if not isinstance (pyfunc_list , list ):
213- raise TypeError (f"Argument `pyfunc_list` should be a list of functions. Got { type (pyfunc_list )} " )
214- if not all ([isinstance (f , types .FunctionType ) for f in pyfunc_list ]):
215- raise ValueError ("Argument `pyfunc_list` should be a list of functions." )
216-
217- return cls (fieldset , ptype , pyfunc_list )
218-
219183 def execute (self , pset , endtime , dt ):
220184 """Execute this Kernel over a ParticleSet for several timesteps.
221185
@@ -248,7 +212,7 @@ def execute(self, pset, endtime, dt):
248212 pset .dt = np .minimum (np .maximum (pset .dt , - time_to_endtime ), 0 )
249213
250214 # run kernels for all particles that need to be evaluated
251- for f in self ._pyfuncs :
215+ for f in self ._kernels :
252216 f (pset [evaluate_particles ], self ._fieldset )
253217
254218 # check for particles that have to be repeated
@@ -280,9 +244,9 @@ def execute(self, pset, endtime, dt):
280244 else :
281245 error_func (pset [inds ].z , pset [inds ].lat , pset [inds ].lon )
282246
283- # Only add PositionUpdate kernel at the end of the first execute call to avoid adding dt to time too early
284- if not self . _positionupdate_kernel_added :
285- self .add_positionupdate_kernel ()
286- self . _positionupdate_kernel_added = True
247+ # Only prepend PositionUpdate kernel at the end of the first execute call to avoid adding dt to time too early
248+ if not pset . _requires_prepended_positionupdate_kernel :
249+ self .prepend_positionupdate_kernel ()
250+ pset . _requires_prepended_positionupdate_kernel = True
287251
288252 return pset
0 commit comments