@@ -62,7 +62,6 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
6262 self .interpolation_weights = {}
6363 logger .info ("Creating discrete interpolator with {} degrees of freedom" .format (self .nx ))
6464 self .type = InterpolatorType .BASE_DISCRETE
65- self .c = np .zeros (self .support .n_nodes )
6665
6766 @property
6867 def nx (self ) -> int :
@@ -140,7 +139,7 @@ def reset(self):
140139 """
141140 self .constraints = {}
142141 self .c_ = 0
143- logger .debug ("Resetting interpolation constraints" )
142+ logger .info ("Resetting interpolation constraints" )
144143
145144 def add_constraints_to_least_squares (self , A , B , idc , w = 1.0 , name = "undefined" ):
146145 """
@@ -511,14 +510,15 @@ def build_inequality_matrix(self):
511510 mats .append (c ['matrix' ])
512511 bounds .append (c ['bounds' ])
513512 if len (mats ) == 0 :
514- return None , None
513+ return sparse . csr_matrix (( 0 , self . nx ), dtype = float ), np . zeros (( 0 , 3 ))
515514 Q = sparse .vstack (mats )
516515 bounds = np .vstack (bounds )
517516 return Q , bounds
518517
519518 def solve_system (
520519 self ,
521520 solver : Optional [Union [Callable [[sparse .csr_matrix , np .ndarray ], np .ndarray ], str ]] = None ,
521+ tol : Optional [float ] = None ,
522522 solver_kwargs : dict = {},
523523 ) -> bool :
524524 """
@@ -557,10 +557,18 @@ def solve_system(
557557 f'Unknown solver { solver } using cg. \n Available solvers are cg and lsmr or a custom solver as a callable function'
558558 )
559559 solver = 'cg'
560+ if solver is None :
561+ solver = 'cg'
560562 if solver == 'cg' :
561563 logger .info ("Solving using cg" )
562- ATA = A .T .dot (A )
563- ATB = A .T .dot (b )
564+ if 'atol' not in solver_kwargs or 'rtol' not in solver_kwargs :
565+ if tol is not None :
566+ solver_kwargs ['atol' ] = tol
567+
568+ ATA = A .T @ A
569+ ATB = A .T @ b
570+ logger .info (f"Solver kwargs: { solver_kwargs } " )
571+
564572 res = sparse .linalg .cg (ATA , ATB , ** solver_kwargs )
565573 if res [1 ] > 0 :
566574 logger .warning (
@@ -571,6 +579,13 @@ def solve_system(
571579 return True
572580 elif solver == 'lsmr' :
573581 logger .info ("Solving using lsmr" )
582+ if 'atol' not in solver_kwargs :
583+ if tol is not None :
584+ solver_kwargs ['atol' ] = tol
585+ if 'btol' not in solver_kwargs :
586+ if tol is not None :
587+ solver_kwargs ['btol' ] = tol
588+ logger .info (f"Solver kwargs: { solver_kwargs } " )
574589 res = sparse .linalg .lsmr (A , b , ** solver_kwargs )
575590 if res [1 ] == 1 or res [1 ] == 4 or res [1 ] == 2 or res [1 ] == 5 :
576591 self .c = res [0 ]
@@ -622,6 +637,7 @@ def solve_system(
622637 except ValueError as e :
623638 logger .error (f"ADMM solver failed: { e } " )
624639 return False
640+ logger .info (f"{ solver } not recognised" )
625641 return False
626642
627643 def update (self ) -> bool :
@@ -641,7 +657,8 @@ def update(self) -> bool:
641657 return False
642658 if not self .up_to_date :
643659 self .setup_interpolator ()
644- return self .solve_system (self .solver )
660+ self .up_to_date = self .solve_system (self .solver )
661+ return self .up_to_date
645662
646663 def evaluate_value (self , locations : np .ndarray ) -> np .ndarray :
647664 """Evaluate the value of the interpolator at location
0 commit comments