Skip to content

Commit 245517b

Browse files
committed
fix: ✨ adding interpolator factory
1 parent c3ec7a8 commit 245517b

File tree

8 files changed

+117
-71
lines changed

8 files changed

+117
-71
lines changed

LoopStructural/interpolators/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,14 @@ class InterpolatorType(IntEnum):
7878
from ._p1interpolator import P1Interpolator
7979
from ._p2interpolator import P2Interpolator
8080
from ._builders import get_interpolator
81+
82+
interpolator_map = {
83+
InterpolatorType.BASE: GeologicalInterpolator,
84+
InterpolatorType.BASE_DISCRETE: DiscreteInterpolator,
85+
InterpolatorType.FINITE_DIFFERENCE: FiniteDifferenceInterpolator,
86+
InterpolatorType.DISCRETE_FOLD: DiscreteFoldInterpolator,
87+
InterpolatorType.PIECEWISE_LINEAR: PiecewiseLinearInterpolator,
88+
InterpolatorType.PIECEWISE_QUADRATIC: PiecewiseLinearInterpolator,
89+
InterpolatorType.BASE_DATA_SUPPORTED: GeologicalInterpolator,
90+
# InterpolatorType.SURFE: SurfeRBFInterpolator,
91+
}

LoopStructural/interpolators/_discrete_fold_interpolator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ def add_fold_constraints(
144144
"""
145145
np.random.shuffle(element_idx)
146146

147-
logger.info(
148-
f"Adding fold orientation constraint to {self.propertyname} w = {fold_orientation}"
149-
)
147+
logger.info(f"Adding fold orientation constraint to w = {fold_orientation}")
150148
A = np.einsum(
151149
"ij,ijk->ik",
152150
deformed_orientation[element_idx[::step], :],
@@ -165,9 +163,7 @@ def add_fold_constraints(
165163
"""
166164
np.random.shuffle(element_idx)
167165

168-
logger.info(
169-
f"Adding fold axis constraint to {self.propertyname} w = {fold_axis_w}"
170-
)
166+
logger.info(f"Adding fold axis constraint to w = {fold_axis_w}")
171167
A = np.einsum(
172168
"ij,ijk->ik",
173169
fold_axis[element_idx[::step], :],
@@ -188,7 +184,7 @@ def add_fold_constraints(
188184
np.random.shuffle(element_idx)
189185

190186
logger.info(
191-
f"Adding fold normalisation constraint to {self.propertyname} w = {fold_normalisation}"
187+
f"Adding fold normalisation constraint to w = {fold_normalisation}"
192188
)
193189
A = np.einsum(
194190
"ij,ijk->ik", dgz[element_idx[::step], :], eg[element_idx[::step], :, :]
@@ -212,7 +208,7 @@ def add_fold_constraints(
212208
fold constant gradient
213209
"""
214210
logger.info(
215-
f"Adding fold regularisation constraint to {self.propertyname} w = {fold_regularisation[0]} {fold_regularisation[1]} {fold_regularisation[2]}"
211+
f"Adding fold regularisation constraint to w = {fold_regularisation[0]} {fold_regularisation[1]} {fold_regularisation[2]}"
216212
)
217213

218214
idc, c, ncons = fold_cg(

LoopStructural/interpolators/_discrete_interpolator.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
class DiscreteInterpolator(GeologicalInterpolator):
2222
""" """
2323

24-
def __init__(self, support):
24+
def __init__(self, support, data={}, c=None, up_to_date=False):
2525
"""
2626
Base class for a discrete interpolator e.g. piecewise linear or finite difference which is
2727
any interpolator that solves the system using least squares approximation
@@ -31,9 +31,14 @@ def __init__(self, support):
3131
support
3232
A discrete mesh with, nodes, elements, etc
3333
"""
34-
GeologicalInterpolator.__init__(self)
34+
GeologicalInterpolator.__init__(self, data=data, up_to_date=up_to_date)
3535
self.B = []
3636
self.support = support
37+
self.c = (
38+
np.array(c)
39+
if c is not None and np.array(c).shape[0] == self.support.n_nodes
40+
else np.zeros(self.support.n_nodes)
41+
)
3742
self.region_function = lambda xyz: np.ones(xyz.shape[0], dtype=bool)
3843
# self.region_map[self.region] = np.array(range(0,
3944
# len(self.region_map[self.region])))
@@ -95,21 +100,6 @@ def region_map(self):
95100
region_map[self.region] = np.array(range(0, len(region_map[self.region])))
96101
return region_map
97102

98-
def set_property_name(self, propertyname):
99-
"""
100-
Set the property name attribute, this is usually used to
101-
save the property on the support
102-
103-
Parameters
104-
----------
105-
propertyname
106-
107-
Returns
108-
-------
109-
110-
"""
111-
self.propertyname = propertyname
112-
113103
def set_region(self, region=None):
114104
"""
115105
Set the region of the support the interpolator is working on
@@ -724,7 +714,7 @@ def _solve_pyamg(self, A, B, tol=1e-12, x0=None, verb=False, **kwargs):
724714
logger.info("Solving using pyamg: tol {}".format(tol))
725715
return pyamg.solve(A, B, tol=tol, x0=x0, verb=verb)[: self.nx]
726716

727-
def _solve(self, solver="cg", **kwargs):
717+
def solve_system(self, solver="cg", **kwargs):
728718
"""
729719
Main entry point to run the solver and update the node value
730720
attribute for the
@@ -784,24 +774,18 @@ def _solve(self, solver="cg", **kwargs):
784774
P, A, q, l, u, mkl=kwargs.get("mkl", False)
785775
) # , **kwargs)
786776
# check solution is not nan
787-
# self.support.properties[self.propertyname] = self.c
788777
if np.all(self.c == np.nan):
789778
self.valid = False
790779
logger.warning("Solver not run, no scalar field")
791780
return
792781
# if solution is all 0, probably didn't work
793782
if np.all(self.c[self.region] == 0):
794783
self.valid = False
795-
logger.warning(
796-
"No solution, {} scalar field 0. Add more data.".format(
797-
self.propertyname
798-
)
799-
)
784+
logger.warning("No solution, scalar field 0. Add more data.")
785+
800786
return
801787
self.valid = True
802-
logging.info(
803-
f"Solving interpolation: {self.propertyname} took: {time()-starttime}"
804-
)
788+
logging.info(f"Solving interpolation took: {time()-starttime}")
805789
self.up_to_date = True
806790

807791
def update(self):
@@ -822,7 +806,7 @@ def update(self):
822806
return False
823807
if not self.up_to_date:
824808
self.setup_interpolator()
825-
return self._solve(self.solver)
809+
return self.solve_system(self.solver)
826810

827811
def evaluate_value(self, evaluation_points: np.ndarray) -> np.ndarray:
828812
"""Evaluate the value of the interpolator at location
@@ -864,3 +848,12 @@ def evaluate_gradient(self, evaluation_points: np.ndarray) -> np.ndarray:
864848
if evaluation_points.shape[0] > 0:
865849
return self.support.evaluate_gradient(evaluation_points, self.c)
866850
return np.zeros((0, 3))
851+
852+
def to_dict(self):
853+
return {
854+
"type": self.type.name,
855+
"support": self.support.to_dict(),
856+
"c": self.c,
857+
**super().to_dict(),
858+
# 'region_function':self.region_function,
859+
}

LoopStructural/interpolators/_finite_difference_interpolator.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
class FiniteDifferenceInterpolator(DiscreteInterpolator):
20-
def __init__(self, grid):
20+
def __init__(self, grid, data={}):
2121
"""
2222
Finite difference interpolation on a regular cartesian grid
2323
@@ -26,7 +26,7 @@ def __init__(self, grid):
2626
grid : StructuredGrid
2727
"""
2828
self.shape = "rectangular"
29-
DiscreteInterpolator.__init__(self, grid)
29+
DiscreteInterpolator.__init__(self, grid, data=data)
3030
# default weights for the interpolation matrix are 1 in x,y,z and
3131
# 1/
3232
self.set_interpolation_weights(
@@ -52,7 +52,7 @@ def __init__(self, grid):
5252
# grid.step_vector[2]
5353
self.type = InterpolatorType.FINITE_DIFFERENCE
5454

55-
def _setup_interpolator(self, **kwargs):
55+
def setup_interpolator(self, **kwargs):
5656
"""
5757
5858
Parameters
@@ -95,7 +95,6 @@ def _setup_interpolator(self, **kwargs):
9595
self.assemble_inner(o[0], o[1])
9696
# otherwise just use defaults
9797
if "operators" not in kwargs:
98-
9998
operator = Operator.Dxy_mask
10099
weight = (
101100
self.interpolation_weights["dxy"] / 4
@@ -179,7 +178,7 @@ def add_vaue_constraints(self, w=1.0):
179178
)
180179
if np.sum(inside) <= 0:
181180
logger.warning(
182-
f"{self.propertyname}: {np.sum(~inside)} \
181+
f"{np.sum(~inside)} \
183182
value constraints not added: outside of model bounding box"
184183
)
185184

@@ -218,7 +217,7 @@ def add_inequality_constraints(self, w=1.0):
218217
)
219218
if np.sum(inside) <= 0:
220219
logger.warning(
221-
f"{self.propertyname}: {np.sum(~inside)} \
220+
f"{np.sum(~inside)} \
222221
value constraints not added: outside of model bounding box"
223222
)
224223

@@ -335,7 +334,7 @@ def add_gradient_constraints(self, w=1.0):
335334
)
336335
if np.sum(inside) <= 0:
337336
logger.warning(
338-
f"{self.propertyname}: {np.sum(~inside)} \
337+
f" {np.sum(~inside)} \
339338
norm constraints not added: outside of model bounding box"
340339
)
341340

@@ -400,7 +399,7 @@ def add_norm_constraints(self, w=1.0):
400399
)
401400
if np.sum(inside) <= 0:
402401
logger.warning(
403-
f"{self.propertyname}: {np.sum(~inside)} \
402+
f"{np.sum(~inside)} \
404403
norm constraints not added: outside of model bounding box"
405404
)
406405
self.up_to_date = False
@@ -459,7 +458,7 @@ def add_gradient_orthogonal_constraints(
459458
)
460459
if np.sum(inside) <= 0:
461460
logger.warning(
462-
f"{self.propertyname}: {np.sum(~inside)} \
461+
f"{np.sum(~inside)} \
463462
gradient constraints not added: outside of model bounding box"
464463
)
465464
self.up_to_date = False

LoopStructural/interpolators/_geological_interpolator.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Base geological interpolator
33
"""
4+
from abc import ABC, ABCMeta, abstractmethod
5+
from LoopStructural.utils.exceptions import LoopTypeError
46
from ..interpolators import InterpolatorType
57
import numpy as np
68

@@ -9,22 +11,24 @@
911
logger = getLogger(__name__)
1012

1113

12-
class GeologicalInterpolator:
14+
class GeologicalInterpolator(metaclass=ABCMeta):
1315
"""
1416
Attributes
1517
----------
1618
data : dict
1719
a dictionary with np.arrays for gradient, value, normal, tangent data
1820
"""
1921

20-
def __init__(self):
22+
@abstractmethod
23+
def __init__(self, data={}, up_to_date=False):
2124
"""
2225
This class is the base class for a geological interpolator and contains
2326
all of the main interface functions. Any class that is inheriting from
2427
this should be callable by using any of these functions. This will
2528
enable interpolators to be interchanged.
2629
"""
27-
self.data = {} # None
30+
self._data = {}
31+
self.data = data # None
2832
self.clean() # init data structure
2933

3034
self.n_g = 0
@@ -33,12 +37,22 @@ def __init__(self):
3337
self.n_t = 0
3438

3539
self.type = InterpolatorType.BASE
36-
self.up_to_date = False
40+
self.up_to_date = up_to_date
3741
self.constraints = []
38-
self.propertyname = "defaultproperty"
3942
self.__str = "Base Geological Interpolator"
4043
self.valid = False
4144

45+
@property
46+
def data(self):
47+
return self._data
48+
49+
@data.setter
50+
def data(self, data):
51+
if data is None:
52+
data = {}
53+
for k, v in data.items():
54+
self._data[k] = np.array(v)
55+
4256
def __str__(self):
4357
name = f"{self.type} \n"
4458
name += f"{self.n_g} gradient points\n"
@@ -48,23 +62,16 @@ def __str__(self):
4862
name += f"{self.n_g + self.n_i + self.n_n + self.n_t} total points\n"
4963
return name
5064

65+
def check_array(self, array: np.ndarray):
66+
try:
67+
return np.array(array)
68+
except Exception as e:
69+
raise LoopTypeError(str(e))
70+
71+
@abstractmethod
5172
def set_region(self, **kwargs):
5273
pass
5374

54-
def set_property_name(self, name):
55-
"""
56-
Set the name of the interpolated property
57-
Parameters
58-
----------
59-
name : string
60-
name of the property to be saved on a mesh
61-
62-
Returns
63-
-------
64-
65-
"""
66-
self.propertyname = name
67-
6875
def set_value_constraints(self, points: np.ndarray):
6976
"""
7077
@@ -78,6 +85,7 @@ def set_value_constraints(self, points: np.ndarray):
7885
-------
7986
8087
"""
88+
points = self.check_array(points)
8189
if points.shape[1] < 4:
8290
raise ValueError("Value points must at least have X,Y,Z,val")
8391
self.data["value"] = points
@@ -208,34 +216,48 @@ def get_interface_constraints(self):
208216
def get_inequality_constraints(self):
209217
return self.data["inequality"]
210218

219+
# @abstractmethod
211220
def setup(self, **kwargs):
212221
"""
213222
Runs all of the required setting up stuff
214223
"""
215-
self._setup_interpolator(**kwargs)
224+
self.setup_interpolator(**kwargs)
216225

226+
@abstractmethod
217227
def setup_interpolator(self, **kwargs):
218228
"""
219229
Runs all of the required setting up stuff
220230
"""
221-
self._setup_interpolator(**kwargs)
231+
self.setup_interpolator(**kwargs)
222232

233+
@abstractmethod
223234
def solve_system(self, **kwargs):
224235
"""
225236
Solves the interpolation equations
226237
"""
227238
self._solve(**kwargs)
228239
self.up_to_date = True
229240

241+
@abstractmethod
230242
def update(self):
231243
return False
232244

245+
@abstractmethod
233246
def evaluate_value(self, locations: np.ndarray):
234247
raise NotImplementedError("evaluate_value not implemented")
235248

249+
@abstractmethod
236250
def evaluate_gradient(self, locations: np.ndarray):
237251
raise NotImplementedError("evaluate_gradient not implemented")
238252

253+
def to_dict(self):
254+
return {
255+
"type": self.type,
256+
"data": self.data,
257+
"up_to_date": self.up_to_date,
258+
"valid": self.valid,
259+
}
260+
239261
def clean(self):
240262
"""
241263
Removes all of the data from an interpolator

0 commit comments

Comments
 (0)