Skip to content

Commit 26edd3f

Browse files
committed
fix: adding loopsolver optional depencency + admm solver option
1 parent ab1fa90 commit 26edd3f

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ def build_inequality_matrix(self):
483483
for c in self.ineq_constraints.values():
484484
mats.append(c['matrix'])
485485
bounds.append(c['bounds'])
486+
if len(mats) == 0:
487+
return None, None
486488
Q = sparse.vstack(mats)
487489
bounds = np.hstack(bounds)
488490
return Q, bounds
@@ -514,6 +516,7 @@ def solve_system(
514516
self.c = np.zeros(self.support.n_nodes)
515517
self.c[:] = np.nan
516518
A, b = self.build_matrix()
519+
Q, bounds = self.build_inequality_matrix()
517520
if callable(solver):
518521
logger.warning('Using custom solver')
519522
self.c = solver(A.tocsr(), b)
@@ -522,7 +525,7 @@ def solve_system(
522525
return True
523526
## solve with lsmr
524527
if isinstance(solver, str):
525-
if solver not in ['cg', 'lsmr']:
528+
if solver not in ['cg', 'lsmr', 'admm']:
526529
logger.warning(
527530
f'Unknown solver {solver} using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
528531
)
@@ -557,6 +560,35 @@ def solve_system(
557560
self.up_to_date = True
558561
logger.info("Interpolation took %f seconds" % (time() - starttime))
559562
return True
563+
elif solver == 'admm':
564+
logger.info("Solving using admm")
565+
if Q is None:
566+
logger.warning("No inequality constraints, using lsmr")
567+
return self.solve_system('lsmr', solver_kwargs)
568+
569+
try:
570+
from loopsolver import admm_solve
571+
except ImportError:
572+
logger.warning(
573+
"Cannot import admm solver. Please install loopsolver or use lsmr or cg"
574+
)
575+
return False
576+
try:
577+
res = admm_solve(
578+
A,
579+
b,
580+
Q,
581+
bounds,
582+
x0=solver_kwargs.pop('x0', np.zeros(A.shape[1])),
583+
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
584+
nmajor=solver_kwargs.pop('nmajor', 200),
585+
linsys_solver_kwargs=solver_kwargs,
586+
)
587+
self.c = res
588+
self.up_to_date = True
589+
except ValueError as e:
590+
logger.error(f"ADMM solver failed: {e}")
591+
return False
560592
return False
561593

562594
def update(self) -> bool:

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ dependencies = [
4242
dynamic = ['version']
4343

4444
[project.optional-dependencies]
45-
all = ['loopstructural[visualisation,export]']
45+
all = ['loopstructural[visualisation,inequalities,export]']
4646
visualisation = [
4747
"matplotlib",
4848
"pyvista",
@@ -56,6 +56,8 @@ jupyter = [
5656
"pyvista[all]",
5757
"tqdm"
5858
]
59+
inequalities = [
60+
"loopsolver"]
5961
[project.urls]
6062
Documentation = 'https://Loop3d.org/LoopStructural/'
6163
"Bug Tracker" = 'https://github.com/loop3d/loopstructural/issues'

0 commit comments

Comments
 (0)