Skip to content

[ENH] Convert InterpolationOptions to Pydantic model #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gempy_engine/API/server/main_server_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# Default interpolation options
range_ = 1
default_interpolation_options: InterpolationOptions = InterpolationOptions(
default_interpolation_options: InterpolationOptions = InterpolationOptions.from_args(
range=range_,
c_o=(range_ ** 2) / 14 / 3,
number_octree_levels=4,
Expand Down
2 changes: 1 addition & 1 deletion gempy_engine/core/data/options/evaluation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MeshExtractionMaskingOptions(enum.Enum):
class EvaluationOptions:
_number_octree_levels: int = 1
_number_octree_levels_surface: int = 4
octree_curvature_threshold: float = -1 #: Threshold to do octree refinement due to curvature to deal with angular geometries. This curvature assumes that 1 is the maximum curvature of any voxel
octree_curvature_threshold: float = -1. #: Threshold to do octree refinement due to curvature to deal with angular geometries. This curvature assumes that 1 is the maximum curvature of any voxel
octree_error_threshold: float = 1. #: Number of standard deviations to consider a voxel as candidate to refine
octree_min_level: int = 2

Expand Down
97 changes: 60 additions & 37 deletions gempy_engine/core/data/options/interpolation_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import warnings
from dataclasses import dataclass, asdict, field

from pydantic import BaseModel, ConfigDict, Field, model_validator, PrivateAttr

import gempy_engine.config
from .evaluation_options import MeshExtractionMaskingOptions, EvaluationOptions
Expand All @@ -10,40 +11,50 @@
from ..raw_arrays_solution import RawArraysSolution


@dataclass
class InterpolationOptions:
__slots__ = ['kernel_options', 'evaluation_options', 'temp_interpolation_values', 'debug',
'cache_mode', 'cache_model_name', 'block_solutions_type', 'sigmoid_slope']

class InterpolationOptions(BaseModel):
class CacheMode(enum.Enum):
""" Cache mode for the interpolation"""
NO_CACHE: int = enum.auto() #: No cache at all even during the interpolation computation. This is quite expensive for no good reason.
CACHE = enum.auto()
IN_MEMORY_CACHE = enum.auto()
CLEAR_CACHE = enum.auto()

model_config = ConfigDict(
arbitrary_types_allowed=False,
use_enum_values=False,
json_encoders={
CacheMode: lambda e: e.value,
AvailableKernelFunctions: lambda e: e.name
}
)

# @off
kernel_options: KernelOptions # * This is the compression of the fields above and the way to go in the future
evaluation_options: EvaluationOptions
temp_interpolation_values: TempInterpolationValues
kernel_options: KernelOptions = Field(init=True, exclude=False) # * This is the compression of the fields above and the way to go in the future
evaluation_options: EvaluationOptions = Field(init=True, exclude= False)

debug: bool
cache_mode: CacheMode
cache_model_name: str # : Model name for the cache

block_solutions_type: RawArraysSolution.BlockSolutionType

sigmoid_slope: int

debug_water_tight: bool = False

def __init__(
self,
# region Volatile
temp_interpolation_values: TempInterpolationValues = Field(
default_factory=TempInterpolationValues,
exclude=True
)

# endregion

@classmethod
def from_args(
cls,
range: int | float,
c_o: float,
uni_degree: int = 1,
i_res: float = 4,
gi_res: float = 2, # ! This should be DEP
i_res: float = 4.,
gi_res: float = 2., # ! This should be DEP
number_dimensions: int = 3, # ? This probably too
number_octree_levels: int = 1,
kernel_function: AvailableKernelFunctions = AvailableKernelFunctions.cubic,
Expand All @@ -52,7 +63,7 @@ def __init__(
compute_condition_number: bool = False,
):

self.kernel_options = KernelOptions(
kernel_options = KernelOptions(
range=range,
c_o=c_o,
uni_degree=uni_degree,
Expand All @@ -63,7 +74,7 @@ def __init__(
compute_condition_number=compute_condition_number
)

self.evaluation_options = EvaluationOptions(
evaluation_options = EvaluationOptions(
_number_octree_levels=number_octree_levels,
_number_octree_levels_surface=4,
mesh_extraction=mesh_extraction,
Expand All @@ -73,18 +84,30 @@ def __init__(

)

self.temp_interpolation_values = TempInterpolationValues()
self.debug = gempy_engine.config.DEBUG_MODE
self.cache_mode = InterpolationOptions.CacheMode.IN_MEMORY_CACHE
self.cache_model_name = ""
self.block_solutions_type = RawArraysSolution.BlockSolutionType.OCTREE
self.sigmoid_slope = 5_000_000
temp_interpolation_values = TempInterpolationValues()
debug = gempy_engine.config.DEBUG_MODE
cache_mode = InterpolationOptions.CacheMode.IN_MEMORY_CACHE
cache_model_name = ""
block_solutions_type = RawArraysSolution.BlockSolutionType.OCTREE
sigmoid_slope = 5_000_000

return InterpolationOptions(
kernel_options=kernel_options,
evaluation_options=evaluation_options,
# temp_interpolation_values=temp_interpolation_values,
debug=debug,
cache_mode=cache_mode,
cache_model_name=cache_model_name,
block_solutions_type=block_solutions_type,
sigmoid_slope=sigmoid_slope,
debug_water_tight=False,
)

# @on

@classmethod
def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
return InterpolationOptions(
return InterpolationOptions.from_args(
range=range,
c_o=c_o,
mesh_extraction=True,
Expand All @@ -93,7 +116,7 @@ def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):

@classmethod
def init_dense_grid_options(cls):
options = InterpolationOptions(
options = InterpolationOptions.from_args(
range=1.7,
c_o=10,
mesh_extraction=False,
Expand All @@ -107,17 +130,17 @@ def probabilistic_options(cls):
# TODO: This should have the sigmoid slope different
raise NotImplementedError("Probabilistic interpolation is not yet implemented.")

def __repr__(self):
return f"InterpolationOptions({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"

def _repr_html_(self):
html = f"""
<table>
<tr><td colspan='2' style='text-align:center'><b>InterpolationOptions</b></td></tr>
{''.join(f'<tr><td>{k}</td><td>{v._repr_html_() if isinstance(v, KernelOptions) else v}</td></tr>' for k, v in asdict(self).items())}
</table>
"""
return html
# def __repr__(self):
# return f"InterpolationOptions.from_args({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"

# def _repr_html_(self):
# html = f"""
# <table>
# <tr><td colspan='2' style='text-align:center'><b>InterpolationOptions</b></td></tr>
# {''.join(f'<tr><td>{k}</td><td>{v._repr_html_() if isinstance(v, KernelOptions) else v}</td></tr>' for k, v in asdict(self).items())}
# </table>
# """
# return html

def update_options(self, **kwargs):
"""
Expand Down
43 changes: 32 additions & 11 deletions gempy_engine/core/data/options/kernel_options.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import warnings

from dataclasses import dataclass, asdict
from typing import Optional

from pydantic import field_validator

from gempy_engine.core.data.kernel_classes.kernel_functions import AvailableKernelFunctions
from gempy_engine.core.data.kernel_classes.solvers import Solvers


@dataclass(frozen=False)
class KernelOptions:
range: int # TODO: have constructor from RegularGrid
range: int | float # TODO: have constructor from RegularGrid
c_o: float # TODO: This should be a property
uni_degree: int = 1
i_res: float = 4.
Expand All @@ -20,7 +23,25 @@ class KernelOptions:

compute_condition_number: bool = False
optimizing_condition_number: bool = False
condition_number: float = None
condition_number: Optional[float] = None

@field_validator('kernel_function', mode='before')
@classmethod
def _deserialize_kernel_function_from_name(cls, value):
"""
Ensures that a string input (e.g., "cubic" from JSON)
is correctly converted to an AvailableKernelFunctions enum member.
"""
if isinstance(value, str):
try:
return AvailableKernelFunctions[value] # Lookup enum member by name
except KeyError:
# This provides a more specific error if the name doesn't exist
valid_names = [member.name for member in AvailableKernelFunctions]
raise ValueError(f"Invalid kernel function name '{value}'. Must be one of: {valid_names}")
# If it's already an AvailableKernelFunctions member (e.g., during direct model instantiation),
# or if it's another type that Pydantic's later validation will catch as an error.
return value

@property
def n_uni_eq(self):
Expand Down Expand Up @@ -65,16 +86,16 @@ def update_options(self, **kwargs):
def __hash__(self):
# Using a tuple to hash all the values together
return hash((
self.range,
self.c_o,
self.uni_degree,
self.i_res,
self.gi_res,
self.number_dimensions,
self.kernel_function,
self.compute_condition_number,
self.range,
self.c_o,
self.uni_degree,
self.i_res,
self.gi_res,
self.number_dimensions,
self.kernel_function,
self.compute_condition_number,
))

def __repr__(self):
return f"KernelOptions({', '.join(f'{k}={v}' for k, v in asdict(self).items())})"

Expand Down
4 changes: 4 additions & 0 deletions gempy_engine/core/data/options/temp_interpolation_values.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from dataclasses import dataclass


@dataclass
class TempInterpolationValues:
current_octree_level: int = 0 # * Make this a read only property
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def simple_model_2():
ori_i = Orientations(dip_positions, nugget_effect_grad)

range = 5 ** 2
kri = InterpolationOptions(range, 1, 0, i_res=1, gi_res=1,
kri = InterpolationOptions.from_args(range, 1, 0, i_res=1, gi_res=1,
number_dimensions=2)

_ = np.ones(3)
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
pydantic
python-dotenv
2 changes: 1 addition & 1 deletion tests/benchmark/one_fault_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def one_fault_model():
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
c_o = 1

options = InterpolationOptions(
options = InterpolationOptions.from_args(
range_, c_o,
uni_degree=1,
number_dimensions=3,
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/complex_geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def one_fault_model():
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
c_o = 1

options = InterpolationOptions(
options = InterpolationOptions.from_args(
range_, c_o,
uni_degree=1,
number_dimensions=3,
Expand Down Expand Up @@ -144,7 +144,7 @@ def one_finite_fault_model():
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
c_o = 1

options = InterpolationOptions(
options = InterpolationOptions.from_args(
range_, c_o,
uni_degree=1,
number_dimensions=3,
Expand Down Expand Up @@ -211,7 +211,7 @@ def graben_fault_model():
range_ = 7 ** 2 # ? Since we are not getting the square root should we also square this?
c_o = 1

options = InterpolationOptions(
options = InterpolationOptions.from_args(
range_, c_o,
uni_degree=1,
number_dimensions=3,
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/heavy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def moureze_model_factory(path_to_root: str, pick_every=8, octree_lvls=3, solver
# endregion

# region InterpolationOptions
interpolation_options: InterpolationOptions = InterpolationOptions(
interpolation_options: InterpolationOptions = InterpolationOptions.from_args(
range=100.,
c_o=10.,
number_octree_levels=octree_lvls,
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/simple_geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def unconformity() -> Tuple[InterpolationInput, InterpolationOptions, InputDataD
i_r = 4
gi_r = 2

options = InterpolationOptions(range_, c_o, uni_degree=1, i_res=i_r, gi_res=gi_r,
options = InterpolationOptions.from_args(range_, c_o, uni_degree=1, i_res=i_r, gi_res=gi_r,
number_dimensions=3,
kernel_function=AvailableKernelFunctions.cubic)

Expand Down
Loading