@@ -285,39 +285,66 @@ def add_value_inequality_constraints(self, w: float = 1.0):
285285 cols = self .support .elements [element [inside ]]
286286 self .add_inequality_constraints_to_matrix (a , points [:, 3 :5 ], cols , 'inequality_value' )
287287
288- def add_inequality_pairs_constraints (self , w : float = 1.0 ):
289- pairs = self .get_inequality_pairs_constraints ()
290- if pairs ['upper' ].shape [0 ] == 0 or pairs ['lower' ].shape [0 ] == 0 :
291- return
292- upper_interpolation = self .support .get_element_for_location (pairs ['upper' ])
293- lower_interpolation = self .support .get_element_for_location (pairs ['lower' ])
294- ij = np .array (
295- [
296- * np .meshgrid (
297- np .arange (0 , int (upper_interpolation [3 ].sum ()), dtype = int ),
298- np .arange (0 , int (lower_interpolation [3 ].sum ()), dtype = int ),
288+ def add_inequality_pairs_constraints (
289+ self , w : float = 1.0 , upper_bound = np .finfo (float ).eps , lower_bound = - np .inf
290+ ):
291+
292+ points = self .get_inequality_pairs_constraints ()
293+ if points .shape [0 ] > 0 :
294+
295+ # assemble a list of pairs in the model
296+ # this will make pairs even across stratigraphic boundaries
297+ # TODO add option to only add stratigraphic pairs
298+ pairs = {}
299+ k = 0
300+ for i in np .unique (points [:, self .support .dimension ]):
301+ for j in np .unique (points [:, self .support .dimension ]):
302+ if i == j :
303+ continue
304+ if tuple (sorted ([i , j ])) not in pairs :
305+ pairs [tuple (sorted ([i , j ]))] = k
306+ k += 1
307+ pairs = list (pairs .keys ())
308+ for pair in pairs :
309+ upper_points = points [points [:, self .support .dimension ] == pair [0 ]]
310+ lower_points = points [points [:, self .support .dimension ] == pair [1 ]]
311+
312+ upper_interpolation = self .support .get_element_for_location (upper_points )
313+ lower_interpolation = self .support .get_element_for_location (lower_points )
314+ ij = np .array (
315+ [
316+ * np .meshgrid (
317+ np .arange (0 , int (upper_interpolation [3 ].sum ()), dtype = int ),
318+ np .arange (0 , int (lower_interpolation [3 ].sum ()), dtype = int ),
319+ )
320+ ],
321+ dtype = int ,
299322 )
300- ],
301- dtype = int ,
302- )
303323
304- ij = ij .reshape (2 , - 1 ).T
305- rows = np .arange (0 , ij .shape [0 ], dtype = int )
306- rows = np .tile (rows , (upper_interpolation [1 ].shape [- 1 ], 1 )).T
307- rows = np .hstack ([rows , rows ])
308- a = upper_interpolation [1 ][upper_interpolation [3 ]][ij [:, 0 ]] # np.ones(ij.shape[0])
309- a = np .hstack ([a , - lower_interpolation [1 ][lower_interpolation [3 ]][ij [:, 1 ]]])
310- cols = np .hstack (
311- [
312- self .support .elements [upper_interpolation [2 ][upper_interpolation [3 ]][ij [:, 0 ]]],
313- self .support .elements [lower_interpolation [2 ][lower_interpolation [3 ]][ij [:, 1 ]]],
314- ]
315- )
324+ ij = ij .reshape (2 , - 1 ).T
325+ rows = np .arange (0 , ij .shape [0 ], dtype = int )
326+ rows = np .tile (rows , (upper_interpolation [1 ].shape [- 1 ], 1 )).T
327+ rows = np .hstack ([rows , rows ])
328+ a = upper_interpolation [1 ][upper_interpolation [3 ]][ij [:, 0 ]] # np.ones(ij.shape[0])
329+ a = np .hstack ([a , - lower_interpolation [1 ][lower_interpolation [3 ]][ij [:, 1 ]]])
330+ cols = np .hstack (
331+ [
332+ self .support .elements [
333+ upper_interpolation [2 ][upper_interpolation [3 ]][ij [:, 0 ]]
334+ ],
335+ self .support .elements [
336+ lower_interpolation [2 ][lower_interpolation [3 ]][ij [:, 1 ]]
337+ ],
338+ ]
339+ )
316340
317- bounds = np .zeros ((ij .shape [0 ], 2 ))
318- bounds [:, 0 ] = np .finfo ('float' ).eps
319- bounds [:, 1 ] = 1e10
320- self .add_inequality_constraints_to_matrix (a , bounds , cols , 'inequality_pairs' )
341+ bounds = np .zeros ((ij .shape [0 ], 2 ))
342+ bounds [:, 0 ] = lower_bound
343+ bounds [:, 1 ] = upper_bound
344+
345+ self .add_inequality_constraints_to_matrix (
346+ a , bounds , cols , f'inequality_pairs_{ pair [0 ]} _{ pair [1 ]} '
347+ )
321348
322349 def add_inequality_feature (
323350 self ,
@@ -486,7 +513,7 @@ def build_inequality_matrix(self):
486513 if len (mats ) == 0 :
487514 return None , None
488515 Q = sparse .vstack (mats )
489- bounds = np .hstack (bounds )
516+ bounds = np .vstack (bounds )
490517 return Q , bounds
491518
492519 def solve_system (
@@ -562,6 +589,12 @@ def solve_system(
562589 return True
563590 elif solver == 'admm' :
564591 logger .info ("Solving using admm" )
592+
593+ if 'x0' in solver_kwargs :
594+ x0 = solver_kwargs ['x0' ](self .support )
595+ else :
596+ x0 = np .zeros (A .shape [1 ])
597+ solver_kwargs .pop ('x0' , None )
565598 if Q is None :
566599 logger .warning ("No inequality constraints, using lsmr" )
567600 return self .solve_system ('lsmr' , solver_kwargs )
@@ -579,7 +612,7 @@ def solve_system(
579612 b ,
580613 Q ,
581614 bounds ,
582- x0 = solver_kwargs . pop ( 'x0' , np . zeros ( A . shape [ 1 ])) ,
615+ x0 = x0 ,
583616 admm_weight = solver_kwargs .pop ('admm_weight' , 0.01 ),
584617 nmajor = solver_kwargs .pop ('nmajor' , 200 ),
585618 linsys_solver_kwargs = solver_kwargs ,
0 commit comments