@@ -62,7 +62,9 @@ def __init__(self, support, data={}, c=None, up_to_date=False):
6262 self .interpolation_weights = {}
6363 logger .info ("Creating discrete interpolator with {} degrees of freedom" .format (self .dof ))
6464 self .type = InterpolatorType .BASE_DISCRETE
65-
65+ self .apply_scaling_matrix = True
66+ self .add_ridge_regulatisation = True
67+ self .ridge_factor = 1e-8
6668 def set_nelements (self , nelements : int ) -> int :
6769 return self .support .set_nelements (nelements )
6870
@@ -511,6 +513,25 @@ def build_matrix(self):
511513
512514 B = np .hstack (bs )
513515 return A , B
516+ def compute_column_scaling_matrix (self , A : sparse .csr_matrix ) -> sparse .dia_matrix :
517+ """Compute column scaling matrix S for matrix A so that A @ S has columns with unit norm.
518+
519+ Parameters
520+ ----------
521+ A : sparse.csr_matrix
522+ interpolation matrix
523+
524+ Returns
525+ -------
526+ scipy.sparse.dia_matrix
527+ diagonal scaling matrix S
528+ """
529+ col_norms = sparse .linalg .norm (A , axis = 0 )
530+ scaling_factors = np .ones (A .shape [1 ])
531+ mask = col_norms > 0
532+ scaling_factors [mask ] = 1.0 / col_norms [mask ]
533+ S = sparse .diags (scaling_factors )
534+ return S
514535
515536 def add_equality_block (self , A , B ):
516537 if len (self .equal_constraints ) > 0 :
@@ -591,6 +612,15 @@ def solve_system(
591612 raise ValueError ("Pre solve failed" )
592613
593614 A , b = self .build_matrix ()
615+ if self .add_ridge_regulatisation :
616+ ridge = sparse .eye (A .shape [1 ]) * self .ridge_factor
617+ A = sparse .vstack ([A , ridge ])
618+ b = np .hstack ([b , np .zeros (A .shape [1 ])])
619+ logger .info ("Adding ridge regularisation to interpolation matrix" )
620+ if self .apply_scaling_matrix :
621+ S = self .compute_column_scaling_matrix (A )
622+ A = A @ S
623+
594624 Q , bounds = self .build_inequality_matrix ()
595625 if callable (solver ):
596626 logger .warning ('Using custom solver' )
@@ -620,12 +650,14 @@ def solve_system(
620650
621651 elif solver == 'lsmr' :
622652 logger .info ("Solving using lsmr" )
623- if 'atol' not in solver_kwargs :
624- if tol is not None :
625- solver_kwargs ['atol' ] = tol
653+ # if 'atol' not in solver_kwargs:
654+ # if tol is not None:
655+ # solver_kwargs['atol'] = tol
626656 if 'btol' not in solver_kwargs :
627657 if tol is not None :
628658 solver_kwargs ['btol' ] = tol
659+ solver_kwargs ['atol' ] = 0.
660+ logger .info (f"Setting lsmr btol to { tol } " )
629661 logger .info (f"Solver kwargs: { solver_kwargs } " )
630662 res = sparse .linalg .lsmr (A , b , ** solver_kwargs )
631663 if res [1 ] == 1 or res [1 ] == 4 or res [1 ] == 2 or res [1 ] == 5 :
@@ -684,6 +716,9 @@ def solve_system(
684716 logger .error (f"Unknown solver { solver } " )
685717 self .up_to_date = False
686718 # self._post_solve()
719+ # apply scaling matrix to solution
720+ if self .apply_scaling_matrix :
721+ self .c = S @ self .c
687722 return self .up_to_date
688723
689724 def update (self ) -> bool :
0 commit comments