@@ -132,6 +132,28 @@ def set_interpolation_weights(self, weights):
132132 self .up_to_date = False
133133 self .interpolation_weights [key ] = weights [key ]
134134
135+ def _pre_solve (self ):
136+ """
137+ Pre solve function to be run before solving the interpolation
138+ """
139+ self .c = np .zeros (self .support .n_nodes )
140+ self .c [:] = np .nan
141+ return True
142+
143+ def _post_solve (self ):
144+ """Post solve function(s) to be run after the solver has been called"""
145+ self .clear_constraints ()
146+ return True
147+
148+ def clear_constraints (self ):
149+ """
150+ Clear the constraints from the interpolator, this makes sure we are not storing
151+ the constraints after the solver has been run
152+ """
153+ self .constraints = {}
154+ self .ineq_constraints = {}
155+ self .equal_constraints = {}
156+
135157 def reset (self ):
136158 """
137159 Reset the interpolation constraints
@@ -540,43 +562,37 @@ def solve_system(
540562
541563 """
542564 starttime = time ()
543- self .c = np .zeros (self .support .n_nodes )
544- self .c [:] = np .nan
565+ if not self ._pre_solve ():
566+ raise ValueError ("Pre solve failed" )
567+
545568 A , b = self .build_matrix ()
546569 Q , bounds = self .build_inequality_matrix ()
547570 if callable (solver ):
548571 logger .warning ('Using custom solver' )
549572 self .c = solver (A .tocsr (), b )
550573 self .up_to_date = True
551-
552- return True
553- ## solve with lsmr
554- if isinstance (solver , str ):
574+ elif isinstance (solver , str ) or solver is None :
555575 if solver not in ['cg' , 'lsmr' , 'admm' ]:
556576 logger .warning (
557577 f'Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
558578 )
559579 solver = 'cg'
560- if solver is None :
561- solver = 'cg'
562580 if solver == 'cg' :
563581 logger .info ("Solving using cg" )
564582 if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs :
565583 if tol is not None :
566584 solver_kwargs ['atol' ] = tol
567585
568- ATA = A .T @ A
569- ATB = A .T @ b
570586 logger .info (f"Solver kwargs: { solver_kwargs } " )
571587
572- res = sparse .linalg .cg (ATA , ATB , ** solver_kwargs )
588+ res = sparse .linalg .cg (A . T @ A , A . T @ b , ** solver_kwargs )
573589 if res [1 ] > 0 :
574590 logger .warning (
575591 f'CG reached iteration limit ({ res [1 ]} )and did not converge, check input data. Setting solution to last iteration'
576592 )
577593 self .c = res [0 ]
578594 self .up_to_date = True
579- return True
595+
580596 elif solver == 'lsmr' :
581597 logger .info ("Solving using lsmr" )
582598 if 'atol' not in solver_kwargs :
@@ -600,8 +616,7 @@ def solve_system(
600616 )
601617 self .c = res [0 ]
602618 self .up_to_date = True
603- logger .info ("Interpolation took %f seconds" % (time () - starttime ))
604- return True
619+
605620 elif solver == 'admm' :
606621 logger .info ("Solving using admm" )
607622
@@ -616,29 +631,35 @@ def solve_system(
616631
617632 try :
618633 from loopsolver import admm_solve
634+
635+ try :
636+ linsys_solver = solver_kwargs .pop ('linsys_solver' , 'lsmr' )
637+ res = admm_solve (
638+ A ,
639+ b ,
640+ Q ,
641+ bounds ,
642+ x0 = x0 ,
643+ admm_weight = solver_kwargs .pop ('admm_weight' , 0.01 ),
644+ nmajor = solver_kwargs .pop ('nmajor' , 200 ),
645+ linsys_solver_kwargs = solver_kwargs ,
646+ linsys_solver = linsys_solver ,
647+ )
648+ self .c = res
649+ self .up_to_date = True
650+ except ValueError as e :
651+ logger .error (f"ADMM solver failed: { e } " )
652+ self .up_to_date = False
619653 except ImportError :
620654 logger .warning (
621655 "Cannot import admm solver. Please install loopsolver or use lsmr or cg"
622656 )
623- return False
624- try :
625- res = admm_solve (
626- A ,
627- b ,
628- Q ,
629- bounds ,
630- x0 = x0 ,
631- admm_weight = solver_kwargs .pop ('admm_weight' , 0.01 ),
632- nmajor = solver_kwargs .pop ('nmajor' , 200 ),
633- linsys_solver_kwargs = solver_kwargs ,
634- )
635- self .c = res
636- self .up_to_date = True
637- except ValueError as e :
638- logger .error (f"ADMM solver failed: { e } " )
639- return False
640- logger .info (f"{ solver } not recognised" )
641- return False
657+ self .up_to_date = False
658+ else :
659+ logger .error (f"Unknown solver { solver } " )
660+ self .up_to_date = False
661+ # self._post_solve()
662+ return self .up_to_date
642663
643664 def update (self ) -> bool :
644665 """
0 commit comments