Skip to content

Commit fc2a213

Browse files
Updated Jacobian precond functions; updated cvode step call
1 parent 8c6534b commit fc2a213

File tree

1 file changed

+73
-41
lines changed

1 file changed

+73
-41
lines changed

fidimag/common/sundials/cvode.pyx

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ cdef extern from "cvode/cvode.h":
9292
ctypedef int (*CVRootFn)(realtype t, N_Vector y, realtype *gout, void *user_data)
9393

9494
void *CVodeCreate(int lmm, SUNContext sunctx)
95-
int CVodeStep "CVode"(void *cvode_mem, realtype tout, N_Vector yout, realtype *tret, int itask) nogil
95+
int CVode "CVode"(void *cvode_mem, realtype tout, N_Vector yout, realtype *tret, int itask) nogil
9696
int CVodeSetUserData(void *cvode_mem, void *user_data)
9797
int CVodeSetMaxOrd(void *cvode_mem, int maxord)
9898
int CVodeSetMaxNumSteps(void *cvode_mem, long int mxsteps)
@@ -146,7 +146,7 @@ cdef extern from "cvode/cvode.h":
146146

147147
int CVDlsGetNumJacEvals(void *cvode_mem, long int *njevals)
148148
int CVDlsGetNumRhsEvals(void *cvode_mem, long int *nrevalsLS)
149-
int CVSpilsGetNumJtimesEvals(void *cvode_mem, long int *njevals)
149+
int CVodeGetNumJtimesEvals(void *cvode_mem, long int *njevals)
150150

151151
cdef extern from "sunlinsol/sunlinsol_spgmr.h":
152152
int CVSpgmr(void *cvode_mem, int pretype, int max1)
@@ -177,10 +177,10 @@ cdef extern from "cvode/cvode_ls.h":
177177
# int CVSpilsSetJacTimesVecFn(void *cvode_mem, CVSpilsJacTimesVecFn jtv)
178178

179179
cdef extern from "sundials/sundials_iterative.h":
180-
int PREC_NONE
181-
int PREC_LEFT
182-
int PREC_RIGHT
183-
int PREC_BOTH
180+
int SUN_PREC_NONE
181+
int SUN_PREC_LEFT
182+
int SUN_PREC_RIGHT
183+
int SUN_PREC_BOTH
184184

185185
int MODIFIED_GS
186186
int CLASSICAL_GS
@@ -290,16 +290,15 @@ cdef int cv_jtimes_openmp(N_Vector v, N_Vector Jv, double t, N_Vector y, N_Vecto
290290
return 0
291291

292292

293-
cdef int psolve(double t, N_Vector y, N_Vector fy,
294-
N_Vector r, N_Vector z, double gamma, double delta, int lr,
295-
void * user_data, N_Vector tmp):
293+
# static int PSolve(realtype tn, N_Vector u, N_Vector fu, N_Vector r, N_Vector z,
294+
# realtype gamma, realtype delta, int lr, void *user_data);
295+
cdef int psolve(double t, N_Vector y, N_Vector fy, N_Vector r, N_Vector z,
296+
double gamma, double delta, int lr, void * user_data):
296297
copy_nv2nv(z, r)
297298
return 0
298299

299-
cdef int psolve_openmp(double t, N_Vector y, N_Vector fy,
300-
N_Vector r, N_Vector z, double gamma,
301-
double delta, int lr,
302-
void * user_data, N_Vector tmp):
300+
cdef int psolve_openmp(double t, N_Vector y, N_Vector fy, N_Vector r, N_Vector z,
301+
double gamma, double delta, int lr, void * user_data):
303302
copy_nv2nv_openmp(z, r)
304303
return 0
305304

@@ -401,13 +400,35 @@ cdef class CvodeSolver(object):
401400
self.check_flag(flag, "CVDiag")
402401
elif self.linear_solver == "spgmr":
403402
if self.has_jtimes:
403+
# The Jacobian preconditioner is set here based on the
404+
# cvDiurnal_kry.c example from Sundials 6.1.1
405+
404406
# CVSpgmr(cvode_mem, pretype, maxl) p. 27 of CVODE 2.7 manual
405-
flag = CVSpgmr(self.cvode_mem, PREC_LEFT, 300)
406-
self.check_flag(flag, "CVSpgmr")
407+
# flag = CVSpgmr(self.cvode_mem, PREC_LEFT, 300)
408+
# self.check_flag(flag, "CVSpgmr")
409+
410+
# Call SUNLinSol_SPGMR to specify the linear solver SPGMR
411+
# with left preconditioning and the default Krylov dimension
412+
LS = SUNLinSol_SPGMR(u, SUN_PREC_LEFT, 0, sunctx);
413+
# TODO:
414+
# if(check_retval((void *)LS, "SUNLinSol_SPGMR", 0)) return(1);
415+
416+
# Call CVodeSetLinearSolver to attach the linear sovler to CVode
417+
flag = CVodeSetLinearSolver(cvode_mem, LS, NULL);
418+
self.check_flag(flag, "CVodeSetLinearSolver")
419+
# if (check_retval(&retval, "CVodeSetLinearSolver", 1)) return 1;
420+
421+
# Set the Jacobian-times-vector function */
422+
flag = CVodeSetJacTimes(cvode_mem, NULL, jtv);
423+
self.check_flag(flag, "CVodeSetJacTimes")
424+
407425
# functions below in p. 37 CVODE 2.7 manual
408-
flag = CVodeSetJacTimes(self.cvode_mem, NULL, < CVSpilsJacTimesVecFn > self.jvn_fun)
409-
self.check_flag(flag, "CVSpilsSetJacTimesVecFn")
410-
flag = CVodeSetPreconditioner(self.cvode_mem, < CVSpilsPrecSetupFn > self.psetup, < CVSpilsPrecSolveFn > psolve)
426+
# flag = CVodeSetJacTimes(self.cvode_mem, NULL, < CVSpilsJacTimesVecFn > self.jvn_fun)
427+
# self.check_flag(flag, "CVSpilsSetJacTimesVecFn")
428+
429+
flag = CVodeSetPreconditioner(self.cvode_mem,
430+
< CVLsPrecSetupFn > self.Precond,
431+
< CVLsPrecSolveFn > psolve)
411432
self.check_flag(flag, "CVodeSetPreconditioner")
412433
else:
413434
# this will use the SPGMR without preconditioner and without
@@ -417,8 +438,9 @@ cdef class CvodeSolver(object):
417438
# Actually, it's the same Jacobian approximation as used
418439
# in CVDiag (only difference is CVDiag is a direct linear
419440
# solver).
420-
flag = CVSpgmr(self.cvode_mem, PREC_NONE, 300)
421-
self.check_flag(flag, "CVSpgmr")
441+
flag = SUNLinSol_SPGMR(self.cvode_mem, SUN_PREC_NONE, 300)
442+
# TODO:
443+
# self.check_flag(flag, "SUNLinSol_SPGMR")
422444
else:
423445
raise RuntimeError(
424446
"linear_solver is {}, should be spgmr or diag".format(self.linear_solver))
@@ -446,17 +468,21 @@ cdef class CvodeSolver(object):
446468
flag = CVodeReInit(self.cvode_mem, self.t, self.u_y)
447469
self.check_flag(flag, "CVodeReInit")
448470

471+
# TODO: Instead of using flags we should use Sundials internal flag
472+
# like CV_SUCCESS
449473
cpdef int run_until(self, double t_final) except -1:
450474
cdef int flag
451475
cdef double t_returned
452-
flag = CVodeStep(self.cvode_mem, t_final, self.u_y, & t_returned, CV_NORMAL)
453-
self.check_flag(flag, "CVodeStep")
476+
flag = CVode(self.cvode_mem, t_final, self.u_y, & t_returned, CV_NORMAL)
477+
self.check_flag(flag, "CVode")
454478
self.t = t_returned
455479
return 0
456480

457-
cdef int psetup(self, double t, N_Vector y, N_Vector fy,
458-
booleantype jok, booleantype * jcurPtr, double gamma,
459-
void * user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3):
481+
# From exmaples: cvDiurnal_kry.c in Sundials repo:
482+
# static int Precond(realtype tn, N_Vector u, N_Vector fu, booleantype jok,
483+
# booleantype *jcurPtr, realtype gamma, void *user_data)
484+
cdef int Precond(self, double t, N_Vector y, N_Vector fy, booleantype jok,
485+
booleantype * jcurPtr, double gamma, void * user_data):
460486
if not jok:
461487
copy_nv2arr(y, self.y)
462488
return 0
@@ -472,7 +498,7 @@ cdef class CvodeSolver(object):
472498
def stat(self):
473499
CVodeGetNumSteps(self.cvode_mem, & self.nsteps)
474500
CVodeGetNumRhsEvals(self.cvode_mem, & self.nfevals)
475-
CVSpilsGetNumJtimesEvals(self.cvode_mem, & self.njevals)
501+
CVodeGetNumJtimesEvals(self.cvode_mem, & self.njevals)
476502
return self.nsteps, self.nfevals, self.njevals
477503

478504
def get_current_step(self):
@@ -592,14 +618,20 @@ cdef class CvodeSolver_OpenMP(object):
592618
self.check_flag(flag, "CVDiag")
593619
elif self.linear_solver == "spgmr":
594620
if self.has_jtimes:
595-
# CVSpgmr(cvode_mem, pretype, maxl) p. 27 of CVODE 2.7 manual
596-
flag = CVSpgmr(self.cvode_mem, PREC_LEFT, 300)
597-
self.check_flag(flag, "CVSpgmr")
598-
# functions below in p. 37 CVODE 2.7 manual
599-
flag = CVSpilsSetJacTimesVecFn(self.cvode_mem, < CVSpilsJacTimesVecFn > self.jvn_fun)
600-
self.check_flag(flag, "CVSpilsSetJacTimesVecFn")
601-
flag = CVSpilsSetPreconditioner(self.cvode_mem, < CVSpilsPrecSetupFn > self.psetup, < CVSpilsPrecSolveFn > psolve_openmp)
602-
self.check_flag(flag, "CVSpilsSetPreconditioner")
621+
622+
LS = SUNLinSol_SPGMR(u, SUN_PREC_LEFT, 0, sunctx);
623+
624+
flag = CVodeSetLinearSolver(cvode_mem, LS, NULL);
625+
self.check_flag(flag, "CVodeSetLinearSolver")
626+
627+
flag = CVodeSetJacTimes(cvode_mem, NULL, jtv);
628+
self.check_flag(flag, "CVodeSetJacTimes")
629+
630+
flag = CVodeSetPreconditioner(self.cvode_mem,
631+
< CVLsPrecSetupFn > self.Precond,
632+
< CVLsPrecSolveFn > psolve_openmp)
633+
self.check_flag(flag, "CVodeSetPreconditioner")
634+
603635
else:
604636
# this will use the SPGMR without preconditioner and without
605637
# our computation of the product J * m'. Instead, it uses
@@ -608,8 +640,8 @@ cdef class CvodeSolver_OpenMP(object):
608640
# Actually, it's the same Jacobian approximation as used
609641
# in CVDiag (only difference is CVDiag is a direct linear
610642
# solver).
611-
flag = CVSpgmr(self.cvode_mem, PREC_NONE, 300)
612-
self.check_flag(flag, "CVSpgmr")
643+
flag = SUNLinSol_SPGMR(self.cvode_mem, SUN_PREC_NONE, 300)
644+
# self.check_flag(flag, "CVSpgmr")
613645
else:
614646
raise RuntimeError(
615647
"linear_solver is {}, should be spgmr or diag".format(self.linear_solver))
@@ -640,14 +672,14 @@ cdef class CvodeSolver_OpenMP(object):
640672
cpdef int run_until(self, double t_final) except -1:
641673
cdef int flag
642674
cdef double t_returned
643-
flag = CVodeStep(self.cvode_mem, t_final, self.u_y, & t_returned, CV_NORMAL)
644-
self.check_flag(flag, "CVodeStep")
675+
flag = CVode(self.cvode_mem, t_final, self.u_y, & t_returned, CV_NORMAL)
676+
self.check_flag(flag, "CVode")
645677
self.t = t_returned
646678
return 0
647679

648-
cdef int psetup(self, double t, N_Vector y, N_Vector fy,
649-
booleantype jok, booleantype * jcurPtr, double gamma,
650-
void * user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3):
680+
cdef int Precond(self, double t, N_Vector y, N_Vector fy,
681+
booleantype jok, booleantype * jcurPtr, double gamma,
682+
void * user_data):
651683
if not jok:
652684
copy_nv2arr_openmp(y, self.y)
653685
return 0

0 commit comments

Comments
 (0)