Skip to content

Commit 5fdd296

Browse files
committed
fix: default solver is 'cg' when none specified
1 parent 0e75342 commit 5fdd296

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
6262
self.interpolation_weights = {}
6363
logger.info("Creating discrete interpolator with {} degrees of freedom".format(self.nx))
6464
self.type = InterpolatorType.BASE_DISCRETE
65-
self.c = np.zeros(self.support.n_nodes)
6665

6766
@property
6867
def nx(self) -> int:
@@ -140,7 +139,7 @@ def reset(self):
140139
"""
141140
self.constraints = {}
142141
self.c_ = 0
143-
logger.debug("Resetting interpolation constraints")
142+
logger.info("Resetting interpolation constraints")
144143

145144
def add_constraints_to_least_squares(self, A, B, idc, w=1.0, name="undefined"):
146145
"""
@@ -511,14 +510,15 @@ def build_inequality_matrix(self):
511510
mats.append(c['matrix'])
512511
bounds.append(c['bounds'])
513512
if len(mats) == 0:
514-
return None, None
513+
return sparse.csr_matrix((0, self.nx), dtype=float), np.zeros((0, 3))
515514
Q = sparse.vstack(mats)
516515
bounds = np.vstack(bounds)
517516
return Q, bounds
518517

519518
def solve_system(
520519
self,
521520
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
521+
tol: Optional[float] = None,
522522
solver_kwargs: dict = {},
523523
) -> bool:
524524
"""
@@ -557,10 +557,18 @@ def solve_system(
557557
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
558558
)
559559
solver = 'cg'
560+
if solver is None:
561+
solver = 'cg'
560562
if solver == 'cg':
561563
logger.info("Solving using cg")
562-
ATA = A.T.dot(A)
563-
ATB = A.T.dot(b)
564+
if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs:
565+
if tol is not None:
566+
solver_kwargs['atol'] = tol
567+
568+
ATA = A.T @ A
569+
ATB = A.T @ b
570+
logger.info(f"Solver kwargs: {solver_kwargs}")
571+
564572
res = sparse.linalg.cg(ATA, ATB, **solver_kwargs)
565573
if res[1] > 0:
566574
logger.warning(
@@ -571,6 +579,13 @@ def solve_system(
571579
return True
572580
elif solver == 'lsmr':
573581
logger.info("Solving using lsmr")
582+
if 'atol' not in solver_kwargs:
583+
if tol is not None:
584+
solver_kwargs['atol'] = tol
585+
if 'btol' not in solver_kwargs:
586+
if tol is not None:
587+
solver_kwargs['btol'] = tol
588+
logger.info(f"Solver kwargs: {solver_kwargs}")
574589
res = sparse.linalg.lsmr(A, b, **solver_kwargs)
575590
if res[1] == 1 or res[1] == 4 or res[1] == 2 or res[1] == 5:
576591
self.c = res[0]
@@ -622,6 +637,7 @@ def solve_system(
622637
except ValueError as e:
623638
logger.error(f"ADMM solver failed: {e}")
624639
return False
640+
logger.info(f"{solver} not recognised")
625641
return False
626642

627643
def update(self) -> bool:
@@ -641,7 +657,8 @@ def update(self) -> bool:
641657
return False
642658
if not self.up_to_date:
643659
self.setup_interpolator()
644-
return self.solve_system(self.solver)
660+
self.up_to_date = self.solve_system(self.solver)
661+
return self.up_to_date
645662

646663
def evaluate_value(self, locations: np.ndarray) -> np.ndarray:
647664
"""Evaluate the value of the interpolator at location

0 commit comments

Comments
 (0)