Skip to content

Commit

Permalink
fix: default solver is 'cg' when none specified
Browse files Browse the repository at this point in the history
  • Loading branch information
lachlangrose committed Aug 7, 2024
1 parent 0e75342 commit 5fdd296
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions LoopStructural/interpolators/_discrete_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
self.interpolation_weights = {}
logger.info("Creating discrete interpolator with {} degrees of freedom".format(self.nx))
self.type = InterpolatorType.BASE_DISCRETE
self.c = np.zeros(self.support.n_nodes)

@property
def nx(self) -> int:
Expand Down Expand Up @@ -140,7 +139,7 @@ def reset(self):
"""
self.constraints = {}
self.c_ = 0
logger.debug("Resetting interpolation constraints")
logger.info("Resetting interpolation constraints")

def add_constraints_to_least_squares(self, A, B, idc, w=1.0, name="undefined"):
"""
Expand Down Expand Up @@ -511,14 +510,15 @@ def build_inequality_matrix(self):
mats.append(c['matrix'])
bounds.append(c['bounds'])
if len(mats) == 0:
return None, None
return sparse.csr_matrix((0, self.nx), dtype=float), np.zeros((0, 3))
Q = sparse.vstack(mats)
bounds = np.vstack(bounds)
return Q, bounds

def solve_system(
self,
solver: Optional[Union[Callable[[sparse.csr_matrix, np.ndarray], np.ndarray], str]] = None,
tol: Optional[float] = None,
solver_kwargs: dict = {},
) -> bool:
"""
Expand Down Expand Up @@ -557,10 +557,18 @@ def solve_system(
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
)
solver = 'cg'
if solver is None:
solver = 'cg'
if solver == 'cg':
logger.info("Solving using cg")
ATA = A.T.dot(A)
ATB = A.T.dot(b)
if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs:
if tol is not None:
solver_kwargs['atol'] = tol

ATA = A.T @ A
ATB = A.T @ b
logger.info(f"Solver kwargs: {solver_kwargs}")

res = sparse.linalg.cg(ATA, ATB, **solver_kwargs)
if res[1] > 0:
logger.warning(
Expand All @@ -571,6 +579,13 @@ def solve_system(
return True
elif solver == 'lsmr':
logger.info("Solving using lsmr")
if 'atol' not in solver_kwargs:
if tol is not None:
solver_kwargs['atol'] = tol
if 'btol' not in solver_kwargs:
if tol is not None:
solver_kwargs['btol'] = tol
logger.info(f"Solver kwargs: {solver_kwargs}")
res = sparse.linalg.lsmr(A, b, **solver_kwargs)
if res[1] == 1 or res[1] == 4 or res[1] == 2 or res[1] == 5:
self.c = res[0]
Expand Down Expand Up @@ -622,6 +637,7 @@ def solve_system(
except ValueError as e:
logger.error(f"ADMM solver failed: {e}")
return False
logger.info(f"{solver} not recognised")
return False

def update(self) -> bool:
Expand All @@ -641,7 +657,8 @@ def update(self) -> bool:
return False
if not self.up_to_date:
self.setup_interpolator()
return self.solve_system(self.solver)
self.up_to_date = self.solve_system(self.solver)
return self.up_to_date

def evaluate_value(self, locations: np.ndarray) -> np.ndarray:
"""Evaluate the value of the interpolator at location
Expand Down

0 comments on commit 5fdd296

Please sign in to comment.