Skip to content

Commit 4af60f9

Browse files
committed
fix: adding presolve and postsolve calls for interpolator
1 parent 2b05f69 commit 4af60f9

File tree

1 file changed

+54
-33
lines changed

1 file changed

+54
-33
lines changed

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,28 @@ def set_interpolation_weights(self, weights):
132132
self.up_to_date = False
133133
self.interpolation_weights[key] = weights[key]
134134

135+
def _pre_solve(self):
136+
"""
137+
Pre solve function to be run before solving the interpolation
138+
"""
139+
self.c = np.zeros(self.support.n_nodes)
140+
self.c[:] = np.nan
141+
return True
142+
143+
def _post_solve(self):
144+
"""Post solve function(s) to be run after the solver has been called"""
145+
self.clear_constraints()
146+
return True
147+
148+
def clear_constraints(self):
149+
"""
150+
Clear the constraints from the interpolator, this makes sure we are not storing
151+
the constraints after the solver has been run
152+
"""
153+
self.constraints = {}
154+
self.ineq_constraints = {}
155+
self.equal_constraints = {}
156+
135157
def reset(self):
136158
"""
137159
Reset the interpolation constraints
@@ -540,43 +562,37 @@ def solve_system(
540562
541563
"""
542564
starttime = time()
543-
self.c = np.zeros(self.support.n_nodes)
544-
self.c[:] = np.nan
565+
if not self._pre_solve():
566+
raise ValueError("Pre solve failed")
567+
545568
A, b = self.build_matrix()
546569
Q, bounds = self.build_inequality_matrix()
547570
if callable(solver):
548571
logger.warning('Using custom solver')
549572
self.c = solver(A.tocsr(), b)
550573
self.up_to_date = True
551-
552-
return True
553-
## solve with lsmr
554-
if isinstance(solver, str):
574+
elif isinstance(solver, str) or solver is None:
555575
if solver not in ['cg', 'lsmr', 'admm']:
556576
logger.warning(
557577
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
558578
)
559579
solver = 'cg'
560-
if solver is None:
561-
solver = 'cg'
562580
if solver == 'cg':
563581
logger.info("Solving using cg")
564582
if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs:
565583
if tol is not None:
566584
solver_kwargs['atol'] = tol
567585

568-
ATA = A.T @ A
569-
ATB = A.T @ b
570586
logger.info(f"Solver kwargs: {solver_kwargs}")
571587

572-
res = sparse.linalg.cg(ATA, ATB, **solver_kwargs)
588+
res = sparse.linalg.cg(A.T @ A, A.T @ b, **solver_kwargs)
573589
if res[1] > 0:
574590
logger.warning(
575591
f'CG reached iteration limit ({res[1]})and did not converge, check input data. Setting solution to last iteration'
576592
)
577593
self.c = res[0]
578594
self.up_to_date = True
579-
return True
595+
580596
elif solver == 'lsmr':
581597
logger.info("Solving using lsmr")
582598
if 'atol' not in solver_kwargs:
@@ -600,8 +616,7 @@ def solve_system(
600616
)
601617
self.c = res[0]
602618
self.up_to_date = True
603-
logger.info("Interpolation took %f seconds" % (time() - starttime))
604-
return True
619+
605620
elif solver == 'admm':
606621
logger.info("Solving using admm")
607622

@@ -616,29 +631,35 @@ def solve_system(
616631

617632
try:
618633
from loopsolver import admm_solve
634+
635+
try:
636+
linsys_solver = solver_kwargs.pop('linsys_solver', 'lsmr')
637+
res = admm_solve(
638+
A,
639+
b,
640+
Q,
641+
bounds,
642+
x0=x0,
643+
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
644+
nmajor=solver_kwargs.pop('nmajor', 200),
645+
linsys_solver_kwargs=solver_kwargs,
646+
linsys_solver=linsys_solver,
647+
)
648+
self.c = res
649+
self.up_to_date = True
650+
except ValueError as e:
651+
logger.error(f"ADMM solver failed: {e}")
652+
self.up_to_date = False
619653
except ImportError:
620654
logger.warning(
621655
"Cannot import admm solver. Please install loopsolver or use lsmr or cg"
622656
)
623-
return False
624-
try:
625-
res = admm_solve(
626-
A,
627-
b,
628-
Q,
629-
bounds,
630-
x0=x0,
631-
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
632-
nmajor=solver_kwargs.pop('nmajor', 200),
633-
linsys_solver_kwargs=solver_kwargs,
634-
)
635-
self.c = res
636-
self.up_to_date = True
637-
except ValueError as e:
638-
logger.error(f"ADMM solver failed: {e}")
639-
return False
640-
logger.info(f"{solver} not recognised")
641-
return False
657+
self.up_to_date = False
658+
else:
659+
logger.error(f"Unknown solver {solver}")
660+
self.up_to_date = False
661+
# self._post_solve()
662+
return self.up_to_date
642663

643664
def update(self) -> bool:
644665
"""

0 commit comments

Comments
 (0)