@@ -483,6 +483,8 @@ def build_inequality_matrix(self):
483483 for c in self .ineq_constraints .values ():
484484 mats .append (c ['matrix' ])
485485 bounds .append (c ['bounds' ])
486+ if len (mats ) == 0 :
487+ return None , None
486488 Q = sparse .vstack (mats )
487489 bounds = np .hstack (bounds )
488490 return Q , bounds
@@ -514,6 +516,7 @@ def solve_system(
514516 self .c = np .zeros (self .support .n_nodes )
515517 self .c [:] = np .nan
516518 A , b = self .build_matrix ()
519+ Q , bounds = self .build_inequality_matrix ()
517520 if callable (solver ):
518521 logger .warning ('Using custom solver' )
519522 self .c = solver (A .tocsr (), b )
@@ -522,7 +525,7 @@ def solve_system(
522525 return True
523526 ## solve with lsmr
524527 if isinstance (solver , str ):
525- if solver not in ['cg' , 'lsmr' ]:
528+ if solver not in ['cg' , 'lsmr' , 'admm' ]:
526529 logger .warning (
527530 f'Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
528531 )
@@ -557,6 +560,35 @@ def solve_system(
557560 self .up_to_date = True
558561 logger .info ("Interpolation took %f seconds" % (time () - starttime ))
559562 return True
563+ elif solver == 'admm' :
564+ logger .info ("Solving using admm" )
565+ if Q is None :
566+ logger .warning ("No inequality constraints, using lsmr" )
567+ return self .solve_system ('lsmr' , solver_kwargs )
568+
569+ try :
570+ from loopsolver import admm_solve
571+ except ImportError :
572+ logger .warning (
573+ "Cannot import admm solver. Please install loopsolver or use lsmr or cg"
574+ )
575+ return False
576+ try :
577+ res = admm_solve (
578+ A ,
579+ b ,
580+ Q ,
581+ bounds ,
582+ x0 = solver_kwargs .pop ('x0' , np .zeros (A .shape [1 ])),
583+ admm_weight = solver_kwargs .pop ('admm_weight' , 0.01 ),
584+ nmajor = solver_kwargs .pop ('nmajor' , 200 ),
585+ linsys_solver_kwargs = solver_kwargs ,
586+ )
587+ self .c = res
588+ self .up_to_date = True
589+ except ValueError as e :
590+ logger .error (f"ADMM solver failed: { e } " )
591+ return False
560592 return False
561593
562594 def update (self ) -> bool :
0 commit comments