Skip to content

Commit fba928b

Browse files
committed
fix: adding cg as a solver and dict for solver params
1 parent 69406e7 commit fba928b

File tree

1 file changed

+35
-33
lines changed

1 file changed

+35
-33
lines changed

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -620,14 +620,7 @@ def _solve_pyamg(self, A, B, tol=1e-12, x0=None, verb=False, **kwargs):
620620
def solve_system(
621621
self,
622622
solver: Optional[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray]] = None,
623-
btol=1e-6,
624-
atol=1e-6,
625-
maxiter=None,
626-
conlim=1e8,
627-
show=False,
628-
x0=None,
629-
calc_var=False,
630-
damp=0.0,
623+
solver_kwargs: dict = {},
631624
**kwargs,
632625
):
633626
"""
@@ -658,32 +651,41 @@ def solve_system(
658651
self.up_to_date = True
659652

660653
return True
661-
res = sparse.linalg.lsmr(
662-
A,
663-
b,
664-
damp=damp,
665-
atol=atol,
666-
btol=btol,
667-
maxiter=maxiter,
668-
conlim=conlim,
669-
show=show,
670-
x0=x0,
671-
)
672-
if res[1] == 1 or res[1] == 4 or res[1] == 2 or res[1] == 5:
673-
self.c = res[0]
674-
elif res[1] == 0:
675-
logger.warning("Solution to least squares problem is all zeros, check input data")
676-
elif res[1] == 3 or res[1] == 6:
677-
logger.warning("COND(A) seems to be greater than CONLIM, check input data")
678-
# self.c = res[0]
679-
elif res[1] == 7:
680-
logger.warning(
681-
"LSMR reached iteration limit and did not converge, check input data. Setting solution to last iteration"
682-
)
654+
## solve with lsmr
655+
if isinstance(solver, str):
656+
if solver not in ['cg', 'lsmr']:
657+
logger.warning(
658+
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
659+
)
660+
solver = 'cg'
661+
if solver == 'cg':
662+
ATA = A.T.dot(A)
663+
ATB = A.T.dot(b)
664+
res = sparse.linalg.cg(ATA, ATB, **solver_kwargs)
665+
if res[1] > 0:
666+
logger.warning(
667+
f'CG reached iteration limit ({res[1]})and did not converge, check input data. Setting solution to last iteration'
668+
)
683669
self.c = res[0]
684-
self.up_to_date = True
685-
logger.info("Interpolation took %f seconds" % (time() - starttime))
686-
return True
670+
self.up_to_date = True
671+
return True
672+
elif solver == 'lsmr':
673+
res = sparse.linalg.lsmr(A, b, **solver_kwargs)
674+
if res[1] == 1 or res[1] == 4 or res[1] == 2 or res[1] == 5:
675+
self.c = res[0]
676+
elif res[1] == 0:
677+
logger.warning("Solution to least squares problem is all zeros, check input data")
678+
elif res[1] == 3 or res[1] == 6:
679+
logger.warning("COND(A) seems to be greater than CONLIM, check input data")
680+
# self.c = res[0]
681+
elif res[1] == 7:
682+
logger.warning(
683+
"LSMR reached iteration limit and did not converge, check input data. Setting solution to last iteration"
684+
)
685+
self.c = res[0]
686+
self.up_to_date = True
687+
logger.info("Interpolation took %f seconds" % (time() - starttime))
688+
return True
687689

688690
def update(self):
689691
"""

0 commit comments

Comments
 (0)