Skip to content

Commit ce33ac9

Browse files
committed
fix: adding inequality pairs
any data with 'pair_id' is interpreted as an inequality pair data point. Inequality pairs are added with a step threshold
1 parent b75df73 commit ce33ac9

File tree

4 files changed

+83
-47
lines changed

4 files changed

+83
-47
lines changed

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -285,39 +285,66 @@ def add_value_inequality_constraints(self, w: float = 1.0):
285285
cols = self.support.elements[element[inside]]
286286
self.add_inequality_constraints_to_matrix(a, points[:, 3:5], cols, 'inequality_value')
287287

288-
def add_inequality_pairs_constraints(self, w: float = 1.0):
289-
pairs = self.get_inequality_pairs_constraints()
290-
if pairs['upper'].shape[0] == 0 or pairs['lower'].shape[0] == 0:
291-
return
292-
upper_interpolation = self.support.get_element_for_location(pairs['upper'])
293-
lower_interpolation = self.support.get_element_for_location(pairs['lower'])
294-
ij = np.array(
295-
[
296-
*np.meshgrid(
297-
np.arange(0, int(upper_interpolation[3].sum()), dtype=int),
298-
np.arange(0, int(lower_interpolation[3].sum()), dtype=int),
288+
def add_inequality_pairs_constraints(
289+
self, w: float = 1.0, upper_bound=np.finfo(float).eps, lower_bound=-np.inf
290+
):
291+
292+
points = self.get_inequality_pairs_constraints()
293+
if points.shape[0] > 0:
294+
295+
# assemble a list of pairs in the model
296+
# this will make pairs even across stratigraphic boundaries
297+
# TODO add option to only add stratigraphic pairs
298+
pairs = {}
299+
k = 0
300+
for i in np.unique(points[:, self.support.dimension]):
301+
for j in np.unique(points[:, self.support.dimension]):
302+
if i == j:
303+
continue
304+
if tuple(sorted([i, j])) not in pairs:
305+
pairs[tuple(sorted([i, j]))] = k
306+
k += 1
307+
pairs = list(pairs.keys())
308+
for pair in pairs:
309+
upper_points = points[points[:, self.support.dimension] == pair[0]]
310+
lower_points = points[points[:, self.support.dimension] == pair[1]]
311+
312+
upper_interpolation = self.support.get_element_for_location(upper_points)
313+
lower_interpolation = self.support.get_element_for_location(lower_points)
314+
ij = np.array(
315+
[
316+
*np.meshgrid(
317+
np.arange(0, int(upper_interpolation[3].sum()), dtype=int),
318+
np.arange(0, int(lower_interpolation[3].sum()), dtype=int),
319+
)
320+
],
321+
dtype=int,
299322
)
300-
],
301-
dtype=int,
302-
)
303323

304-
ij = ij.reshape(2, -1).T
305-
rows = np.arange(0, ij.shape[0], dtype=int)
306-
rows = np.tile(rows, (upper_interpolation[1].shape[-1], 1)).T
307-
rows = np.hstack([rows, rows])
308-
a = upper_interpolation[1][upper_interpolation[3]][ij[:, 0]] # np.ones(ij.shape[0])
309-
a = np.hstack([a, -lower_interpolation[1][lower_interpolation[3]][ij[:, 1]]])
310-
cols = np.hstack(
311-
[
312-
self.support.elements[upper_interpolation[2][upper_interpolation[3]][ij[:, 0]]],
313-
self.support.elements[lower_interpolation[2][lower_interpolation[3]][ij[:, 1]]],
314-
]
315-
)
324+
ij = ij.reshape(2, -1).T
325+
rows = np.arange(0, ij.shape[0], dtype=int)
326+
rows = np.tile(rows, (upper_interpolation[1].shape[-1], 1)).T
327+
rows = np.hstack([rows, rows])
328+
a = upper_interpolation[1][upper_interpolation[3]][ij[:, 0]] # np.ones(ij.shape[0])
329+
a = np.hstack([a, -lower_interpolation[1][lower_interpolation[3]][ij[:, 1]]])
330+
cols = np.hstack(
331+
[
332+
self.support.elements[
333+
upper_interpolation[2][upper_interpolation[3]][ij[:, 0]]
334+
],
335+
self.support.elements[
336+
lower_interpolation[2][lower_interpolation[3]][ij[:, 1]]
337+
],
338+
]
339+
)
316340

317-
bounds = np.zeros((ij.shape[0], 2))
318-
bounds[:, 0] = np.finfo('float').eps
319-
bounds[:, 1] = 1e10
320-
self.add_inequality_constraints_to_matrix(a, bounds, cols, 'inequality_pairs')
341+
bounds = np.zeros((ij.shape[0], 2))
342+
bounds[:, 0] = lower_bound
343+
bounds[:, 1] = upper_bound
344+
345+
self.add_inequality_constraints_to_matrix(
346+
a, bounds, cols, f'inequality_pairs_{pair[0]}_{pair[1]}'
347+
)
321348

322349
def add_inequality_feature(
323350
self,
@@ -486,7 +513,7 @@ def build_inequality_matrix(self):
486513
if len(mats) == 0:
487514
return None, None
488515
Q = sparse.vstack(mats)
489-
bounds = np.hstack(bounds)
516+
bounds = np.vstack(bounds)
490517
return Q, bounds
491518

492519
def solve_system(
@@ -562,6 +589,12 @@ def solve_system(
562589
return True
563590
elif solver == 'admm':
564591
logger.info("Solving using admm")
592+
593+
if 'x0' in solver_kwargs:
594+
x0 = solver_kwargs['x0'](self.support)
595+
else:
596+
x0 = np.zeros(A.shape[1])
597+
solver_kwargs.pop('x0', None)
565598
if Q is None:
566599
logger.warning("No inequality constraints, using lsmr")
567600
return self.solve_system('lsmr', solver_kwargs)
@@ -579,7 +612,7 @@ def solve_system(
579612
b,
580613
Q,
581614
bounds,
582-
x0=solver_kwargs.pop('x0', np.zeros(A.shape[1])),
615+
x0=x0,
583616
admm_weight=solver_kwargs.pop('admm_weight', 0.01),
584617
nmajor=solver_kwargs.pop('nmajor', 200),
585618
linsys_solver_kwargs=solver_kwargs,

LoopStructural/interpolators/_geological_interpolator.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,11 @@ def set_value_inequality_constraints(self, points: np.ndarray):
186186
self.data["inequality"] = points
187187
self.up_to_date = False
188188

189-
def set_inequality_pairs_constraints(self, lower_points: np.ndarray, upper_points: np.ndarray):
190-
if lower_points.shape[1] < 3:
191-
raise ValueError("Inequality pairs constraints must at least have X,Y,Z")
192-
if upper_points.shape[1] < 3:
193-
raise ValueError("Inequality pairs constraints must at least have X,Y,Z")
194-
self.data["inequality_pairs_upper"] = upper_points
195-
self.data['inequality_pairs_lower'] = lower_points
189+
def set_inequality_pairs_constraints(self, points: np.ndarray):
190+
if points.shape[1] < 4:
191+
raise ValueError("Inequality pairs constraints must at least have X,Y,Z,rock_id")
192+
193+
self.data["inequality_pairs"] = points
196194
self.up_to_date = False
197195

198196
def get_value_constraints(self):
@@ -256,10 +254,7 @@ def get_inequality_value_constraints(self):
256254
return self.data["inequality"]
257255

258256
def get_inequality_pairs_constraints(self):
259-
return {
260-
'lower': self.data["inequality_pairs_lower"],
261-
'upper': self.data["inequality_pairs_upper"],
262-
}
257+
return self.data["inequality_pairs"]
263258

264259
# @abstractmethod
265260
def setup(self, **kwargs):
@@ -341,8 +336,7 @@ def clean(self):
341336
"tangent": np.zeros((0, 7)),
342337
"interface": np.zeros((0, 5)),
343338
"inequality": np.zeros((0, 6)),
344-
"inequality_pairs_lower": np.zeros((0, 3)),
345-
"inequality_pairs_upper": np.zeros((0, 3)),
339+
"inequality_pairs": np.zeros((0, 4)),
346340
}
347341
self.up_to_date = False
348342
self.n_g = 0

LoopStructural/modelling/features/builders/_geological_feature_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
tangent_vec_names,
1919
interface_name,
2020
inequality_name,
21+
pairs_name,
2122
)
2223
from ....modelling.features import GeologicalFeature
2324
from ....modelling.features.builders import BaseBuilder
@@ -283,8 +284,11 @@ def add_data_to_interpolator(self, constrained=False, force_constrained=False, *
283284
mask = np.all(~np.isnan(data.loc[:, inequality_name()].to_numpy(float)), axis=1)
284285
if mask.sum() > 0:
285286
inequality_data = data.loc[mask, xyz_names() + inequality_name()].to_numpy(float)
286-
self.interpolator.set_inequality_constraints(inequality_data)
287-
287+
self.interpolator.set_value_inequality_constraints(inequality_data)
288+
mask = np.all(~np.isnan(data.loc[:, pairs_name()].to_numpy(float)), axis=1)
289+
if mask.sum() > 0:
290+
pairs_data = data.loc[mask, xyz_names() + pairs_name()].to_numpy(float)
291+
self.interpolator.set_inequality_pairs_constraints(pairs_data)
288292
self.data_added = True
289293
self._up_to_date = False
290294

LoopStructural/utils/helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def polarity_name():
267267
return ["polarity"]
268268

269269

270+
def pairs_name():
271+
return ["pair_id"]
272+
273+
270274
def all_heading():
271275
return (
272276
xyz_names()
@@ -280,6 +284,7 @@ def all_heading():
280284
+ interface_name()
281285
+ polarity_name()
282286
+ inequality_name()
287+
+ pairs_name()
283288
)
284289

285290

0 commit comments

Comments
 (0)