@@ -620,14 +620,7 @@ def _solve_pyamg(self, A, B, tol=1e-12, x0=None, verb=False, **kwargs):
620620 def solve_system (
621621 self ,
622622 solver : Optional [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ]] = None ,
623- btol = 1e-6 ,
624- atol = 1e-6 ,
625- maxiter = None ,
626- conlim = 1e8 ,
627- show = False ,
628- x0 = None ,
629- calc_var = False ,
630- damp = 0.0 ,
623+ solver_kwargs : dict = {},
631624 ** kwargs ,
632625 ):
633626 """
@@ -658,32 +651,41 @@ def solve_system(
658651 self .up_to_date = True
659652
660653 return True
661- res = sparse .linalg .lsmr (
662- A ,
663- b ,
664- damp = damp ,
665- atol = atol ,
666- btol = btol ,
667- maxiter = maxiter ,
668- conlim = conlim ,
669- show = show ,
670- x0 = x0 ,
671- )
672- if res [1 ] == 1 or res [1 ] == 4 or res [1 ] == 2 or res [1 ] == 5 :
673- self .c = res [0 ]
674- elif res [1 ] == 0 :
675- logger .warning ("Solution to least squares problem is all zeros, check input data" )
676- elif res [1 ] == 3 or res [1 ] == 6 :
677- logger .warning ("COND(A) seems to be greater than CONLIM, check input data" )
678- # self.c = res[0]
679- elif res [1 ] == 7 :
680- logger .warning (
681- "LSMR reached iteration limit and did not converge, check input data. Setting solution to last iteration"
682- )
654+ ## solve with lsmr
655+ if isinstance (solver , str ):
656+ if solver not in ['cg' , 'lsmr' ]:
657+ logger .warning (
658+ f'Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
659+ )
660+ solver = 'cg'
661+ if solver == 'cg' :
662+ ATA = A .T .dot (A )
663+ ATB = A .T .dot (b )
664+ res = sparse .linalg .cg (ATA , ATB , ** solver_kwargs )
665+ if res [1 ] > 0 :
666+ logger .warning (
667+ f'CG reached iteration limit ({ res [1 ]} )and did not converge, check input data. Setting solution to last iteration'
668+ )
683669 self .c = res [0 ]
684- self .up_to_date = True
685- logger .info ("Interpolation took %f seconds" % (time () - starttime ))
686- return True
670+ self .up_to_date = True
671+ return True
672+ elif solver == 'lsmr' :
673+ res = sparse .linalg .lsmr (A , b , ** solver_kwargs )
674+ if res [1 ] == 1 or res [1 ] == 4 or res [1 ] == 2 or res [1 ] == 5 :
675+ self .c = res [0 ]
676+ elif res [1 ] == 0 :
677+ logger .warning ("Solution to least squares problem is all zeros, check input data" )
678+ elif res [1 ] == 3 or res [1 ] == 6 :
679+ logger .warning ("COND(A) seems to be greater than CONLIM, check input data" )
680+ # self.c = res[0]
681+ elif res [1 ] == 7 :
682+ logger .warning (
683+ "LSMR reached iteration limit and did not converge, check input data. Setting solution to last iteration"
684+ )
685+ self .c = res [0 ]
686+ self .up_to_date = True
687+ logger .info ("Interpolation took %f seconds" % (time () - starttime ))
688+ return True
687689
688690 def update (self ):
689691 """
0 commit comments