Skip to content

Commit

Permalink
Refactor python callback handling (#2703)
Browse files Browse the repository at this point in the history
* added support to uninstall an external Poisson solver and return to using the default MLMG solver; also updated some callbacks.py calls to Python3

* refactor callback handling - use a map to handle all the different callbacks

* warpx_callback_py_map does not need to link to C

* Apply suggestions from code review

Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja>

* further suggested changes from code review

* added function ExecutePythonCallback to reduce code duplication

* moved ExecutePythonCallback to WarpX_py

* added function IsPythonCallbackInstalled

Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja>
  • Loading branch information
roelof-groenewald and ax3l authored Jan 19, 2022
1 parent 7385857 commit 4e9d98c
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 170 deletions.
65 changes: 38 additions & 27 deletions Python/pywarpx/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,36 @@ def __init__(self,name=None,lcallonce=0):
self.lcallonce = lcallonce

def __call__(self,*args,**kw):
"Call all of the functions in the list"
"""Call all of the functions in the list"""
tt = self.callfuncsinlist(*args,**kw)
self.time = self.time + tt
if self.lcallonce: self.funcs = []

def clearlist(self):
"""Unregister/clear out all registered C callbacks"""
self.funcs = []
libwarpx.libwarpx_so.warpx_clear_callback_py(
ctypes.c_char_p(self.name.encode('utf-8'))
)

def __nonzero__(self):
"Returns True if functions are installed, otherwise False"
def __bool__(self):
"""Returns True if functions are installed, otherwise False"""
return self.hasfuncsinstalled()

def __len__(self):
"Returns number of functions installed"
"""Returns number of functions installed"""
return len(self.funcs)

def hasfuncsinstalled(self):
"Checks if there are any functions installed"
"""Checks if there are any functions installed"""
return len(self.funcs) > 0

def _getmethodobject(self,func):
"For call backs that are methods, returns the method's instance"
"""For call backs that are methods, returns the method's instance"""
return func[0]

def callbackfunclist(self):
"Generator returning callable functions from the list"
"""Generator returning callable functions from the list"""
funclistcopy = copy.copy(self.funcs)
for f in funclistcopy:
if isinstance(f,list):
Expand Down Expand Up @@ -147,15 +151,16 @@ def callbackfunclist(self):
yield result

def installfuncinlist(self,f):
"Check if the specified function is installed"
"""Check if the specified function is installed"""
if len(self.funcs) == 0:
# If this is the first function installed, set the callback in the C++
# to call this class instance.
# Note that the _c_func must be saved.
_CALLBACK_FUNC_0 = ctypes.CFUNCTYPE(None)
self._c_func = _CALLBACK_FUNC_0(self)
callback_setter = getattr(libwarpx.libwarpx_so, f'warpx_set_callback_py_{self.name}')
callback_setter(self._c_func)
libwarpx.libwarpx_so.warpx_set_callback_py(
ctypes.c_char_p(self.name.encode('utf-8')), self._c_func
)
if isinstance(f,types.MethodType):
# --- If the function is a method of a class instance, then save a full
# --- reference to that instance and the method name.
Expand All @@ -177,51 +182,58 @@ def installfuncinlist(self,f):
self.funcs.append(f)

def uninstallfuncinlist(self,f):
"Uninstall the specified function"
"""Uninstall the specified function"""
# --- An element by element search is needed
# --- f can be a function or method object, or a name (string).
# --- Note that method objects can not be removed by name.
funclistcopy = copy.copy(self.funcs)
for func in funclistcopy:
if f == func:
self.funcs.remove(f)
return
break
elif isinstance(func,list) and isinstance(f,types.MethodType):
object = self._getmethodobject(func)
if f.im_self is object and f.__name__ == func[1]:
if f.__self__ is object and f.__name__ == func[1]:
self.funcs.remove(func)
return
break
elif isinstance(func,str):
if f.__name__ == func:
self.funcs.remove(func)
return
break
elif isinstance(f,str):
if isinstance(func,str): funcname = func
elif isinstance(func,list): funcname = None
else: funcname = func.__name__
if f == funcname:
self.funcs.remove(func)
return
raise Exception('Warning: no such function had been installed')
break

# check that a function was removed
if len(self.funcs) == len(funclistcopy):
raise Exception(f'Warning: no function, {f}, had been installed')

# if there are no functions left, remove the C callback
if not self.hasfuncsinstalled():
self.clearlist()

def isinstalledfuncinlist(self,f):
"Checks if the specified function is installed"
"""Checks if the specified function is installed"""
# --- An element by element search is needed
funclistcopy = copy.copy(self.funcs)
for func in funclistcopy:
if f == func:
return 1
elif isinstance(func,list) and isinstance(f,types.MethodType):
object = self._getmethodobject(func)
if f.im_self is object and f.__name__ == func[1]:
if f.__self__ is object and f.__name__ == func[1]:
return 1
elif isinstance(func,str):
if f.__name__ == func:
return 1
return 0

def callfuncsinlist(self,*args,**kw):
"Call the functions in the list"
"""Call the functions in the list"""
bb = time.time()
for f in self.callbackfunclist():
#barrier()
Expand Down Expand Up @@ -320,16 +332,15 @@ def callfrompoissonsolver(f):
installpoissonsolver(f)
return f
def installpoissonsolver(f):
"""Adds a function to solve Poisson's equation. Note that the C++ object
warpx_py_poissonsolver is declared as a nullptr but once the call to set it
to _c_poissonsolver below is executed it is no longer a nullptr, and therefore
if (warpx_py_poissonsolver) evaluates to True. For this reason a poissonsolver
cannot be uninstalled with the uninstallfuncinlist functionality at present."""
"""Installs an external function to solve Poisson's equation"""
if _poissonsolver.hasfuncsinstalled():
raise RuntimeError('Only one field solver can be installed.')
raise RuntimeError("Only one external Poisson solver can be installed.")
_poissonsolver.installfuncinlist(f)
def uninstallpoissonsolver(f):
"""Removes the external function to solve Poisson's equation"""
_poissonsolver.uninstallfuncinlist(f)
def isinstalledpoissonsolver(f):
"""Checks if the function is called for a field solve"""
"""Checks if the function is called to solve Poisson's equation"""
return _poissonsolver.isinstalledfuncinlist(f)

# ----------------------------------------------------------------------------
Expand Down
54 changes: 13 additions & 41 deletions Source/Evolve/WarpXEvolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ WarpX::Evolve (int numsteps)
if (verbose) {
amrex::Print() << "\nSTEP " << step+1 << " starts ...\n";
}
if (warpx_py_beforestep) {
WARPX_PROFILE("warpx_py_beforestep");
warpx_py_beforestep();
}
ExecutePythonCallback("beforestep");

amrex::LayoutData<amrex::Real>* cost = WarpX::getCosts(0);
if (cost) {
Expand Down Expand Up @@ -168,10 +165,7 @@ WarpX::Evolve (int numsteps)
// Main PIC operation:
// gather fields, push particles, deposit sources, update fields

if (warpx_py_particleinjection) {
WARPX_PROFILE("warpx_py_particleinjection");
warpx_py_particleinjection();
}
ExecutePythonCallback("particleinjection");
// Electrostatic case: only gather fields and push particles,
// deposition and calculation of fields done further below
if (do_electrostatic != ElectrostaticSolverAlgo::None)
Expand Down Expand Up @@ -307,10 +301,7 @@ WarpX::Evolve (int numsteps)
}

if( do_electrostatic != ElectrostaticSolverAlgo::None ) {
if (warpx_py_beforeEsolve) {
WARPX_PROFILE("warpx_py_beforeEsolve");
warpx_py_beforeEsolve();
}
ExecutePythonCallback("beforeEsolve");
// Electrostatic solver:
// For each species: deposit charge and add the associated space-charge
// E and B field to the grid ; this is done at the end of the PIC
Expand All @@ -319,18 +310,12 @@ WarpX::Evolve (int numsteps)
// and so that the fields are at the correct time in the output.
bool const reset_fields = true;
ComputeSpaceChargeField( reset_fields );
if (warpx_py_afterEsolve) {
WARPX_PROFILE("warpx_py_afterEsolve");
warpx_py_afterEsolve();
}
ExecutePythonCallback("afterEsolve");
}

// warpx_py_afterstep runs with the updated global time. It is included
// afterstep callback runs with the updated global time. It is included
// in the evolve timing.
if (warpx_py_afterstep) {
WARPX_PROFILE("warpx_py_afterstep");
warpx_py_afterstep();
}
ExecutePythonCallback("afterstep");

/// reduced diags
if (reduced_diags->m_plot_rd != 0)
Expand Down Expand Up @@ -387,20 +372,13 @@ WarpX::OneStep_nosub (Real cur_time)
// from p^{n-1/2} to p^{n+1/2}
// Deposit current j^{n+1/2}
// Deposit charge density rho^{n}
if (warpx_py_particlescraper) {
WARPX_PROFILE("warpx_py_particlescraper");
warpx_py_particlescraper();
}
if (warpx_py_beforedeposition) {
WARPX_PROFILE("warpx_py_beforedeposition");
warpx_py_beforedeposition();
}

ExecutePythonCallback("particlescraper");
ExecutePythonCallback("beforedeposition");

PushParticlesandDepose(cur_time);

if (warpx_py_afterdeposition) {
WARPX_PROFILE("warpx_py_afterdeposition");
warpx_py_afterdeposition();
}
ExecutePythonCallback("afterdeposition");

// Synchronize J and rho
SyncCurrent();
Expand All @@ -422,10 +400,7 @@ WarpX::OneStep_nosub (Real cur_time)
if (do_pml && pml_has_particles) CopyJPML();
if (do_pml && do_pml_j_damping) DampJPML();

if (warpx_py_beforeEsolve) {
WARPX_PROFILE("warpx_py_beforeEsolve");
warpx_py_beforeEsolve();
}
ExecutePythonCallback("beforeEsolve");

// Push E and B from {n} to {n+1}
// (And update guard cells immediately afterwards)
Expand Down Expand Up @@ -507,10 +482,7 @@ WarpX::OneStep_nosub (Real cur_time)
FillBoundaryB(guard_cells.ng_alloc_EB);
} // !PSATD

if (warpx_py_afterEsolve) {
WARPX_PROFILE("warpx_py_afterEsolve");
warpx_py_afterEsolve();
}
ExecutePythonCallback("afterEsolve");
}

void
Expand Down
7 changes: 2 additions & 5 deletions Source/FieldSolver/ElectrostaticSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,7 @@ WarpX::AddSpaceChargeFieldLabFrame ()
std::array<Real, 3> beta = {0._rt};

// Compute the potential phi, by solving the Poisson equation
if (warpx_py_poissonsolver) {
WARPX_PROFILE("warpx_py_poissonsolver");
warpx_py_poissonsolver();
}
if ( IsPythonCallBackInstalled("poissonsolver") ) ExecutePythonCallback("poissonsolver");
else computePhi( rho_fp, phi_fp, beta, self_fields_required_precision,
self_fields_absolute_tolerance, self_fields_max_iters,
self_fields_verbosity );
Expand All @@ -185,7 +182,7 @@ WarpX::AddSpaceChargeFieldLabFrame ()
#ifndef AMREX_USE_EB
computeE( Efield_fp, phi_fp, beta );
#else
if (warpx_py_poissonsolver) computeE( Efield_fp, phi_fp, beta );
if ( IsPythonCallBackInstalled("poissonsolver") ) computeE( Efield_fp, phi_fp, beta );
#endif

// Compute the magnetic field
Expand Down
16 changes: 3 additions & 13 deletions Source/Python/WarpXWrappers.H
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,9 @@ extern "C" {

typedef void(*WARPX_CALLBACK_PY_FUNC_0)();

void warpx_set_callback_py_afterinit (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_beforeEsolve (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_poissonsolver (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_afterEsolve (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_beforedeposition (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_afterdeposition (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_particlescraper (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_particleloader (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_beforestep (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_afterstep (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_afterrestart (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_particleinjection (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py_appliedfields (WARPX_CALLBACK_PY_FUNC_0);
void warpx_set_callback_py (const char* char_callback_name,
WARPX_CALLBACK_PY_FUNC_0 callback);
void warpx_clear_callback_py (const char* char_callback_name);

void warpx_evolve (int numsteps); // -1 means the inputs parameter will be used.

Expand Down
66 changes: 10 additions & 56 deletions Source/Python/WarpXWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,72 +164,26 @@ namespace
{
WarpX& warpx = WarpX::GetInstance();
warpx.InitData();
if (warpx_py_afterinit) {
WARPX_PROFILE("warpx_py_afterinit");
warpx_py_afterinit();
}
if (warpx_py_particleloader) {
WARPX_PROFILE("warpx_py_particleloader");
warpx_py_particleloader();
}
ExecutePythonCallback("afterinit");
ExecutePythonCallback("particleloader");
}

void warpx_finalize ()
{
WarpX::ResetInstance();
}

void warpx_set_callback_py_afterinit (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_afterinit = callback;
}
void warpx_set_callback_py_beforeEsolve (WARPX_CALLBACK_PY_FUNC_0 callback)
void warpx_set_callback_py (
const char* char_callback_name, WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_beforeEsolve = callback;
const std::string callback_name(char_callback_name);
warpx_callback_py_map[callback_name] = callback;
}
void warpx_set_callback_py_poissonsolver (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_poissonsolver = callback;
}
void warpx_set_callback_py_afterEsolve (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_afterEsolve = callback;
}
void warpx_set_callback_py_beforedeposition (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_beforedeposition = callback;
}
void warpx_set_callback_py_afterdeposition (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_afterdeposition = callback;
}
void warpx_set_callback_py_particlescraper (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_particlescraper = callback;
}
void warpx_set_callback_py_particleloader (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_particleloader = callback;
}
void warpx_set_callback_py_beforestep (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_beforestep = callback;
}
void warpx_set_callback_py_afterstep (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_afterstep = callback;
}
void warpx_set_callback_py_afterrestart (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_afterrestart = callback;
}
void warpx_set_callback_py_particleinjection (WARPX_CALLBACK_PY_FUNC_0 callback)
{
warpx_py_particleinjection = callback;
}
void warpx_set_callback_py_appliedfields (WARPX_CALLBACK_PY_FUNC_0 callback)

void warpx_clear_callback_py (const char* char_callback_name)
{
warpx_py_appliedfields = callback;
const std::string callback_name(char_callback_name);
warpx_callback_py_map.erase(callback_name);
}

void warpx_evolve (int numsteps)
Expand Down
Loading

0 comments on commit 4e9d98c

Please sign in to comment.