From 4e9d98c528dfebf05aa9d07edf451eff29f37f65 Mon Sep 17 00:00:00 2001 From: Roelof Groenewald <40245517+roelof-groenewald@users.noreply.github.com> Date: Wed, 19 Jan 2022 09:01:47 -0800 Subject: [PATCH] Refactor python callback handling (#2703) * added support to uninstall an external Poisson solver and return to using the default MLMG solver; also updated some callbacks.py calls to Python3 * refactor callback handling - use a map to handle all the different callbacks * warpx_callback_py_map does not need to link to C * Apply suggestions from code review Co-authored-by: Axel Huebl * further suggested changes from code review * added function ExecutePythonCallback to reduce code duplication * moved ExecutePythonCallback to WarpX_py * added function IsPythonCallbackInstalled Co-authored-by: Axel Huebl --- Python/pywarpx/callbacks.py | 65 ++++++++++++--------- Source/Evolve/WarpXEvolve.cpp | 54 +++++------------- Source/FieldSolver/ElectrostaticSolver.cpp | 7 +-- Source/Python/WarpXWrappers.H | 16 +----- Source/Python/WarpXWrappers.cpp | 66 ++++------------------ Source/Python/WarpX_py.H | 36 +++++++----- Source/Python/WarpX_py.cpp | 28 ++++----- 7 files changed, 102 insertions(+), 170 deletions(-) diff --git a/Python/pywarpx/callbacks.py b/Python/pywarpx/callbacks.py index 35cc5189467..9ae36266e5b 100644 --- a/Python/pywarpx/callbacks.py +++ b/Python/pywarpx/callbacks.py @@ -86,32 +86,36 @@ def __init__(self,name=None,lcallonce=0): self.lcallonce = lcallonce def __call__(self,*args,**kw): - "Call all of the functions in the list" + """Call all of the functions in the list""" tt = self.callfuncsinlist(*args,**kw) self.time = self.time + tt if self.lcallonce: self.funcs = [] def clearlist(self): + """Unregister/clear out all registered C callbacks""" self.funcs = [] + libwarpx.libwarpx_so.warpx_clear_callback_py( + ctypes.c_char_p(self.name.encode('utf-8')) + ) - def __nonzero__(self): - "Returns True if functions are installed, otherwise False" + def __bool__(self): + """Returns True if functions are installed, otherwise False""" return self.hasfuncsinstalled() def __len__(self): - "Returns number of functions installed" + """Returns number of functions installed""" return len(self.funcs) def hasfuncsinstalled(self): - "Checks if there are any functions installed" + """Checks if there are any functions installed""" return len(self.funcs) > 0 def _getmethodobject(self,func): - "For call backs that are methods, returns the method's instance" + """For call backs that are methods, returns the method's instance""" return func[0] def callbackfunclist(self): - "Generator returning callable functions from the list" + """Generator returning callable functions from the list""" funclistcopy = copy.copy(self.funcs) for f in funclistcopy: if isinstance(f,list): @@ -147,15 +151,16 @@ def callbackfunclist(self): yield result def installfuncinlist(self,f): - "Check if the specified function is installed" + """Check if the specified function is installed""" if len(self.funcs) == 0: # If this is the first function installed, set the callback in the C++ # to call this class instance. # Note that the _c_func must be saved. _CALLBACK_FUNC_0 = ctypes.CFUNCTYPE(None) self._c_func = _CALLBACK_FUNC_0(self) - callback_setter = getattr(libwarpx.libwarpx_so, f'warpx_set_callback_py_{self.name}') - callback_setter(self._c_func) + libwarpx.libwarpx_so.warpx_set_callback_py( + ctypes.c_char_p(self.name.encode('utf-8')), self._c_func + ) if isinstance(f,types.MethodType): # --- If the function is a method of a class instance, then save a full # --- reference to that instance and the method name. @@ -177,7 +182,7 @@ def installfuncinlist(self,f): self.funcs.append(f) def uninstallfuncinlist(self,f): - "Uninstall the specified function" + """Uninstall the specified function""" # --- An element by element search is needed # --- f can be a function or method object, or a name (string). # --- Note that method objects can not be removed by name. @@ -185,27 +190,34 @@ def uninstallfuncinlist(self,f): for func in funclistcopy: if f == func: self.funcs.remove(f) - return + break elif isinstance(func,list) and isinstance(f,types.MethodType): object = self._getmethodobject(func) - if f.im_self is object and f.__name__ == func[1]: + if f.__self__ is object and f.__name__ == func[1]: self.funcs.remove(func) - return + break elif isinstance(func,str): if f.__name__ == func: self.funcs.remove(func) - return + break elif isinstance(f,str): if isinstance(func,str): funcname = func elif isinstance(func,list): funcname = None else: funcname = func.__name__ if f == funcname: self.funcs.remove(func) - return - raise Exception('Warning: no such function had been installed') + break + + # check that a function was removed + if len(self.funcs) == len(funclistcopy): + raise Exception(f'Warning: no function, {f}, had been installed') + + # if there are no functions left, remove the C callback + if not self.hasfuncsinstalled(): + self.clearlist() def isinstalledfuncinlist(self,f): - "Checks if the specified function is installed" + """Checks if the specified function is installed""" # --- An element by element search is needed funclistcopy = copy.copy(self.funcs) for func in funclistcopy: @@ -213,7 +225,7 @@ def isinstalledfuncinlist(self,f): return 1 elif isinstance(func,list) and isinstance(f,types.MethodType): object = self._getmethodobject(func) - if f.im_self is object and f.__name__ == func[1]: + if f.__self__ is object and f.__name__ == func[1]: return 1 elif isinstance(func,str): if f.__name__ == func: @@ -221,7 +233,7 @@ def isinstalledfuncinlist(self,f): return 0 def callfuncsinlist(self,*args,**kw): - "Call the functions in the list" + """Call the functions in the list""" bb = time.time() for f in self.callbackfunclist(): #barrier() @@ -320,16 +332,15 @@ def callfrompoissonsolver(f): installpoissonsolver(f) return f def installpoissonsolver(f): - """Adds a function to solve Poisson's equation. Note that the C++ object - warpx_py_poissonsolver is declared as a nullptr but once the call to set it - to _c_poissonsolver below is executed it is no longer a nullptr, and therefore - if (warpx_py_poissonsolver) evaluates to True. For this reason a poissonsolver - cannot be uninstalled with the uninstallfuncinlist functionality at present.""" + """Installs an external function to solve Poisson's equation""" if _poissonsolver.hasfuncsinstalled(): - raise RuntimeError('Only one field solver can be installed.') + raise RuntimeError("Only one external Poisson solver can be installed.") _poissonsolver.installfuncinlist(f) +def uninstallpoissonsolver(f): + """Removes the external function to solve Poisson's equation""" + _poissonsolver.uninstallfuncinlist(f) def isinstalledpoissonsolver(f): - """Checks if the function is called for a field solve""" + """Checks if the function is called to solve Poisson's equation""" return _poissonsolver.isinstalledfuncinlist(f) # ---------------------------------------------------------------------------- diff --git a/Source/Evolve/WarpXEvolve.cpp b/Source/Evolve/WarpXEvolve.cpp index c3c297b56c3..86bcc7fb753 100644 --- a/Source/Evolve/WarpXEvolve.cpp +++ b/Source/Evolve/WarpXEvolve.cpp @@ -83,10 +83,7 @@ WarpX::Evolve (int numsteps) if (verbose) { amrex::Print() << "\nSTEP " << step+1 << " starts ...\n"; } - if (warpx_py_beforestep) { - WARPX_PROFILE("warpx_py_beforestep"); - warpx_py_beforestep(); - } + ExecutePythonCallback("beforestep"); amrex::LayoutData* cost = WarpX::getCosts(0); if (cost) { @@ -168,10 +165,7 @@ WarpX::Evolve (int numsteps) // Main PIC operation: // gather fields, push particles, deposit sources, update fields - if (warpx_py_particleinjection) { - WARPX_PROFILE("warpx_py_particleinjection"); - warpx_py_particleinjection(); - } + ExecutePythonCallback("particleinjection"); // Electrostatic case: only gather fields and push particles, // deposition and calculation of fields done further below if (do_electrostatic != ElectrostaticSolverAlgo::None) @@ -307,10 +301,7 @@ WarpX::Evolve (int numsteps) } if( do_electrostatic != ElectrostaticSolverAlgo::None ) { - if (warpx_py_beforeEsolve) { - WARPX_PROFILE("warpx_py_beforeEsolve"); - warpx_py_beforeEsolve(); - } + ExecutePythonCallback("beforeEsolve"); // Electrostatic solver: // For each species: deposit charge and add the associated space-charge // E and B field to the grid ; this is done at the end of the PIC @@ -319,18 +310,12 @@ WarpX::Evolve (int numsteps) // and so that the fields are at the correct time in the output. bool const reset_fields = true; ComputeSpaceChargeField( reset_fields ); - if (warpx_py_afterEsolve) { - WARPX_PROFILE("warpx_py_afterEsolve"); - warpx_py_afterEsolve(); - } + ExecutePythonCallback("afterEsolve"); } - // warpx_py_afterstep runs with the updated global time. It is included + // afterstep callback runs with the updated global time. It is included // in the evolve timing. - if (warpx_py_afterstep) { - WARPX_PROFILE("warpx_py_afterstep"); - warpx_py_afterstep(); - } + ExecutePythonCallback("afterstep"); /// reduced diags if (reduced_diags->m_plot_rd != 0) @@ -387,20 +372,13 @@ WarpX::OneStep_nosub (Real cur_time) // from p^{n-1/2} to p^{n+1/2} // Deposit current j^{n+1/2} // Deposit charge density rho^{n} - if (warpx_py_particlescraper) { - WARPX_PROFILE("warpx_py_particlescraper"); - warpx_py_particlescraper(); - } - if (warpx_py_beforedeposition) { - WARPX_PROFILE("warpx_py_beforedeposition"); - warpx_py_beforedeposition(); - } + + ExecutePythonCallback("particlescraper"); + ExecutePythonCallback("beforedeposition"); + PushParticlesandDepose(cur_time); - if (warpx_py_afterdeposition) { - WARPX_PROFILE("warpx_py_afterdeposition"); - warpx_py_afterdeposition(); - } + ExecutePythonCallback("afterdeposition"); // Synchronize J and rho SyncCurrent(); @@ -422,10 +400,7 @@ WarpX::OneStep_nosub (Real cur_time) if (do_pml && pml_has_particles) CopyJPML(); if (do_pml && do_pml_j_damping) DampJPML(); - if (warpx_py_beforeEsolve) { - WARPX_PROFILE("warpx_py_beforeEsolve"); - warpx_py_beforeEsolve(); - } + ExecutePythonCallback("beforeEsolve"); // Push E and B from {n} to {n+1} // (And update guard cells immediately afterwards) @@ -507,10 +482,7 @@ WarpX::OneStep_nosub (Real cur_time) FillBoundaryB(guard_cells.ng_alloc_EB); } // !PSATD - if (warpx_py_afterEsolve) { - WARPX_PROFILE("warpx_py_afterEsolve"); - warpx_py_afterEsolve(); - } + ExecutePythonCallback("afterEsolve"); } void diff --git a/Source/FieldSolver/ElectrostaticSolver.cpp b/Source/FieldSolver/ElectrostaticSolver.cpp index 18f90431fff..e1d70be164d 100644 --- a/Source/FieldSolver/ElectrostaticSolver.cpp +++ b/Source/FieldSolver/ElectrostaticSolver.cpp @@ -172,10 +172,7 @@ WarpX::AddSpaceChargeFieldLabFrame () std::array beta = {0._rt}; // Compute the potential phi, by solving the Poisson equation - if (warpx_py_poissonsolver) { - WARPX_PROFILE("warpx_py_poissonsolver"); - warpx_py_poissonsolver(); - } + if ( IsPythonCallBackInstalled("poissonsolver") ) ExecutePythonCallback("poissonsolver"); else computePhi( rho_fp, phi_fp, beta, self_fields_required_precision, self_fields_absolute_tolerance, self_fields_max_iters, self_fields_verbosity ); @@ -185,7 +182,7 @@ WarpX::AddSpaceChargeFieldLabFrame () #ifndef AMREX_USE_EB computeE( Efield_fp, phi_fp, beta ); #else - if (warpx_py_poissonsolver) computeE( Efield_fp, phi_fp, beta ); + if ( IsPythonCallBackInstalled("poissonsolver") ) computeE( Efield_fp, phi_fp, beta ); #endif // Compute the magnetic field diff --git a/Source/Python/WarpXWrappers.H b/Source/Python/WarpXWrappers.H index 42066b6729d..8d2ee43648a 100644 --- a/Source/Python/WarpXWrappers.H +++ b/Source/Python/WarpXWrappers.H @@ -48,19 +48,9 @@ extern "C" { typedef void(*WARPX_CALLBACK_PY_FUNC_0)(); - void warpx_set_callback_py_afterinit (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_beforeEsolve (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_poissonsolver (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_afterEsolve (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_beforedeposition (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_afterdeposition (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_particlescraper (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_particleloader (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_beforestep (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_afterstep (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_afterrestart (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_particleinjection (WARPX_CALLBACK_PY_FUNC_0); - void warpx_set_callback_py_appliedfields (WARPX_CALLBACK_PY_FUNC_0); + void warpx_set_callback_py (const char* char_callback_name, + WARPX_CALLBACK_PY_FUNC_0 callback); + void warpx_clear_callback_py (const char* char_callback_name); void warpx_evolve (int numsteps); // -1 means the inputs parameter will be used. diff --git a/Source/Python/WarpXWrappers.cpp b/Source/Python/WarpXWrappers.cpp index d4589ec8a1a..7350bbbb6a9 100644 --- a/Source/Python/WarpXWrappers.cpp +++ b/Source/Python/WarpXWrappers.cpp @@ -164,14 +164,8 @@ namespace { WarpX& warpx = WarpX::GetInstance(); warpx.InitData(); - if (warpx_py_afterinit) { - WARPX_PROFILE("warpx_py_afterinit"); - warpx_py_afterinit(); - } - if (warpx_py_particleloader) { - WARPX_PROFILE("warpx_py_particleloader"); - warpx_py_particleloader(); - } + ExecutePythonCallback("afterinit"); + ExecutePythonCallback("particleloader"); } void warpx_finalize () @@ -179,57 +173,17 @@ namespace WarpX::ResetInstance(); } - void warpx_set_callback_py_afterinit (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_afterinit = callback; - } - void warpx_set_callback_py_beforeEsolve (WARPX_CALLBACK_PY_FUNC_0 callback) + void warpx_set_callback_py ( + const char* char_callback_name, WARPX_CALLBACK_PY_FUNC_0 callback) { - warpx_py_beforeEsolve = callback; + const std::string callback_name(char_callback_name); + warpx_callback_py_map[callback_name] = callback; } - void warpx_set_callback_py_poissonsolver (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_poissonsolver = callback; - } - void warpx_set_callback_py_afterEsolve (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_afterEsolve = callback; - } - void warpx_set_callback_py_beforedeposition (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_beforedeposition = callback; - } - void warpx_set_callback_py_afterdeposition (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_afterdeposition = callback; - } - void warpx_set_callback_py_particlescraper (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_particlescraper = callback; - } - void warpx_set_callback_py_particleloader (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_particleloader = callback; - } - void warpx_set_callback_py_beforestep (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_beforestep = callback; - } - void warpx_set_callback_py_afterstep (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_afterstep = callback; - } - void warpx_set_callback_py_afterrestart (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_afterrestart = callback; - } - void warpx_set_callback_py_particleinjection (WARPX_CALLBACK_PY_FUNC_0 callback) - { - warpx_py_particleinjection = callback; - } - void warpx_set_callback_py_appliedfields (WARPX_CALLBACK_PY_FUNC_0 callback) + + void warpx_clear_callback_py (const char* char_callback_name) { - warpx_py_appliedfields = callback; + const std::string callback_name(char_callback_name); + warpx_callback_py_map.erase(callback_name); } void warpx_evolve (int numsteps) diff --git a/Source/Python/WarpX_py.H b/Source/Python/WarpX_py.H index 42697fbac17..efc248591b1 100644 --- a/Source/Python/WarpX_py.H +++ b/Source/Python/WarpX_py.H @@ -9,24 +9,30 @@ #define WARPX_PY_H_ #include "WarpXWrappers.H" +#include "Utils/WarpXProfilerWrapper.H" +#include +#include -extern "C" { +/** + * Declare global map to hold python callback functions. + * + * The keys of the map describe at what point in the simulation the python + * functions will be called. Currently supported keys (callback points) are + * afterinit, beforeEsolve, poissonsolver, afterEsolve, beforedeposition, + * afterdeposition, particlescraper, particleloader, beforestep, afterstep, + * afterrestart, particleinjection and appliedfields. +*/ +extern std::map< std::string, WARPX_CALLBACK_PY_FUNC_0 > warpx_callback_py_map; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterinit; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforeEsolve; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_poissonsolver; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterEsolve; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforedeposition; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterdeposition; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_particlescraper; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_particleloader; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforestep; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterstep; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterrestart; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_particleinjection; - extern WARPX_CALLBACK_PY_FUNC_0 warpx_py_appliedfields; +/** + * \brief Function to check if the given name is a key in warpx_callback_py_map + */ +bool IsPythonCallBackInstalled ( std::string name ); -} +/** + * \brief Function to look for and execute Python callbacks + */ +void ExecutePythonCallback ( std::string name ); #endif diff --git a/Source/Python/WarpX_py.cpp b/Source/Python/WarpX_py.cpp index 1d69ea8f321..7e0e491f8a2 100644 --- a/Source/Python/WarpX_py.cpp +++ b/Source/Python/WarpX_py.cpp @@ -7,16 +7,18 @@ */ #include "WarpX_py.H" -WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterinit = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforeEsolve = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_poissonsolver = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterEsolve = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforedeposition = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterdeposition = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_particlescraper = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_particleloader = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_beforestep = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterstep = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_afterrestart = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_particleinjection = nullptr; -WARPX_CALLBACK_PY_FUNC_0 warpx_py_appliedfields = nullptr; +std::map< std::string, WARPX_CALLBACK_PY_FUNC_0 > warpx_callback_py_map; + +bool IsPythonCallBackInstalled ( std::string name ) +{ + return (warpx_callback_py_map.count(name) == 1u); +} + +// Execute Python callbacks of the type given by the input string +void ExecutePythonCallback ( std::string name ) +{ + if ( IsPythonCallBackInstalled(name) ) { + WARPX_PROFILE("warpx_py_"+name); + warpx_callback_py_map[name](); + } +}