Skip to content

Commit 2790d76

Browse files
committed
fix: moving get interpolator to separate function
1 parent 0d05ebc commit 2790d76

File tree

6 files changed

+104
-14
lines changed

6 files changed

+104
-14
lines changed

LoopStructural/api/_interpolate.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import Optional
66
from LoopStructural.interpolators import (
7+
GeologicalInterpolator,
78
get_interpolator,
89
)
910
from LoopStructural.utils import BoundingBox
@@ -39,7 +40,9 @@ def __init__(
3940
"""
4041
self.dimensions = dimensions
4142
self.type = "FDI"
42-
self.interpolator = get_interpolator(bounding_box, type, dimensions, nelements)
43+
self.interpolator: GeologicalInterpolator = get_interpolator(
44+
bounding_box, type, dimensions, nelements
45+
)
4346

4447
def fit(
4548
self,
@@ -48,6 +51,19 @@ def fit(
4851
normal_vectors: Optional[np.ndarray] = None,
4952
inequality_constraints: Optional[np.ndarray] = None,
5053
):
54+
"""_summary_
55+
56+
Parameters
57+
----------
58+
values : Optional[np.ndarray], optional
59+
Value constraints for implicit function, by default None
60+
tangent_vectors : Optional[np.ndarray], optional
61+
tangent constraints for implicit function, by default None
62+
normal_vectors : Optional[np.ndarray], optional
63+
gradient norm constraints for implicit function, by default None
64+
inequality_constraints : Optional[np.ndarray], optional
65+
_description_, by default None
66+
"""
5167
if values is not None:
5268
self.interpolator.set_value_constraints(values)
5369
if tangent_vectors is not None:
@@ -59,11 +75,35 @@ def fit(
5975

6076
self.interpolator.setup()
6177

62-
def evalute_scalar_value(self, locations: np.ndarray):
78+
def evalute_scalar_value(self, locations: np.ndarray) -> np.ndarray:
79+
"""Evaluate the value of the interpolator at locations
80+
81+
Parameters
82+
----------
83+
locations : np.ndarray
84+
Nx3 array of locations to evaluate the interpolator at
85+
86+
Returns
87+
-------
88+
np.ndarray
89+
value of implicit function at locations
90+
"""
6391
self.interpolator.update()
6492
return self.interpolator.evaluate_value(locations)
6593

66-
def evaluate_gradient(self, locations: np.ndarray):
94+
def evaluate_gradient(self, locations: np.ndarray) -> np.ndarray:
95+
"""Evaluate the gradient of the interpolator at locations
96+
97+
Parameters
98+
----------
99+
locations : np.ndarray
100+
Nx3 locations
101+
102+
Returns
103+
-------
104+
np.ndarray
105+
Nx3 gradient of implicit function
106+
"""
67107
self.interpolator.update()
68108
return self.interpolator.evaluate_gradient(locations)
69109

@@ -81,7 +121,7 @@ def fit_and_evaluate_value(
81121
normal_vectors=normal_vectors,
82122
inequality_constraints=inequality_constraints,
83123
)
84-
124+
locations = self.interpolator.get_data_locations()
85125
return self.evalute_scalar_value(locations)
86126

87127
def fit_and_evaluate_gradient(
@@ -97,4 +137,5 @@ def fit_and_evaluate_gradient(
97137
normal_vectors=normal_vectors,
98138
inequality_constraints=inequality_constraints,
99139
)
140+
locations = self.interpolator.get_data_locations()
100141
return self.evaluate_gradient(locations)

LoopStructural/interpolators/_builders.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from LoopStructural.utils.exceptions import LoopException
12
import numpy as np
23
from typing import Optional, Union
34
from LoopStructural.interpolators import (
@@ -18,12 +19,12 @@
1819
def get_interpolator(
1920
bounding_box: BoundingBox,
2021
interpolatortype: str,
21-
dimensions: int,
2222
nelements: int,
2323
element_volume: Optional[float] = None,
2424
buffer: float = 0.2,
25+
dimensions: int = 3,
2526
support=None,
26-
):
27+
) -> GeologicalInterpolator:
2728
# add a buffer to the interpolation domain, this is necessary for
2829
# faults but also generally a good
2930
# idea to avoid boundary problems
@@ -97,7 +98,7 @@ def get_interpolator(
9798
"for modelling using FDI"
9899
)
99100
return FiniteDifferenceInterpolator(grid)
100-
101+
raise LoopException("No interpolator")
101102
# fi interpolatortype == "DFI" and dfi is True:
102103
# if element_volume is None:
103104
# nelements /= 5

LoopStructural/interpolators/_geological_interpolator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,12 @@ def solve_system(self, **kwargs):
230230
def update(self):
231231
return False
232232

233+
def evaluate_value(self, locations: np.ndarray):
234+
raise NotImplementedError("evaluate_value not implemented")
235+
236+
def evaluate_gradient(self, locations: np.ndarray):
237+
raise NotImplementedError("evaluate_gradient not implemented")
238+
233239
def clean(self):
234240
"""
235241
Removes all of the data from an interpolator

LoopStructural/modelling/core/geological_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def get_interpolator(
890890
raise InterpolatorError("Could not create interpolator")
891891

892892
def create_and_add_foliation(
893-
self, series_surface_data, tol=None, faults=None, **kwargs
893+
self, series_surface_data:str, interpolatortype:str='FDI',nelements:int=1000, tol=None, faults=None, **kwargs
894894
):
895895
"""
896896
Parameters
@@ -923,9 +923,8 @@ def create_and_add_foliation(
923923
if tol is None:
924924
tol = self.tol
925925

926-
interpolator = self.get_interpolator(**kwargs)
927926
series_builder = GeologicalFeatureBuilder(
928-
interpolator, name=series_surface_data, **kwargs
927+
bounding_box=self.bounding_box, interpolatortype=interpolatortype,nelements=nelements, name=series_surface_data, **kwargs
929928
)
930929
# add data
931930
series_data = self.data[self.data["feature_name"] == series_surface_data]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
class GeologicalModelBuilderFactory:
2+
def __init__(self):
3+
pass
4+
5+
def create_builder(self):
6+
pass
7+
8+
class GeologicalModelBuilder:
9+
def __init__(self):
10+
pass
11+
12+
def set_model_name(self, name):
13+
pass
14+
15+
def add_surface(self, surface):
16+
pass
17+
18+
def add_fault(self, fault):
19+
pass
20+
21+
def build(self):
22+
pass
23+
24+
class GeologicalModel:
25+
def __init__(self, name):
26+
pass
27+
28+
def add_surface(self, surface):
29+
pass
30+
31+
def add_fault(self, fault):
32+
pass
33+
34+
def get_surface(self, name):
35+
pass
36+
37+
def get_fault(self, name):
38+
pass

LoopStructural/modelling/features/builders/_geological_feature_builder.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@
2626
)
2727
from ....utils import RegionEverywhere
2828
from ....interpolators import DiscreteInterpolator
29+
from ....interpolators import get_interpolator
2930

3031
logger = getLogger(__name__)
3132

3233

3334
class GeologicalFeatureBuilder(BaseBuilder):
3435
def __init__(
3536
self,
36-
interpolator: GeologicalInterpolator,
37+
interpolatortype: str,
38+
bounding_box,
39+
nelements: int = 1000,
3740
name="Feature",
3841
interpolation_region=None,
3942
**kwargs,
@@ -56,7 +59,11 @@ def __init__(
5659
type(interpolator)
5760
)
5861
)
59-
self._interpolator = interpolator
62+
self._interpolator = get_interpolator(
63+
bounding_box=bounding_box,
64+
interpolatortype=interpolatortype,
65+
nelements=nelements,
66+
)
6067
self._interpolator.set_property_name(self._name)
6168
# everywhere region is just a lambda that returns true for all locations
6269

@@ -198,7 +205,6 @@ def add_data_to_interpolator(
198205
# change gradient constraints to normal vector constraints
199206
mask = np.all(~np.isnan(data.loc[:, gradient_vec_names()]), axis=1)
200207
if mask.shape[0] > 0:
201-
202208
data.loc[mask, normal_vec_names()] = data.loc[
203209
mask, gradient_vec_names()
204210
].to_numpy(float)
@@ -288,7 +294,6 @@ def install_gradient_constraint(self):
288294
)
289295

290296
def add_equality_constraints(self, feature, region, scalefactor=1.0):
291-
292297
self._equality_constraints[feature.name] = [feature, region, scalefactor]
293298
self._up_to_date = False
294299

0 commit comments

Comments
 (0)