Skip to content

[ENH] Add soft segment activation function for improved layer segmentation #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions gempy_engine/modules/activator/_soft_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numbers

import numpy as np

from ...core.backend_tensor import BackendTensor as bt, BackendTensor

try:
import torch
except ModuleNotFoundError:
pass


def soft_segment_unbounded(Z, edges, ids, sigmoid_slope):
"""
Z: array of shape (...,) of scalar values
edges: array of shape (K-1,) of finite split points [e1, e2, ..., e_{K-1}]
ids: array of shape (K,) of the id for each of the K bins
sigmoid_slope: scalar target peak slope m > 0
returns: array of shape (...,) of the soft-assigned id
"""
ids = bt.t.array(ids[::-1].copy())

# Check if sigmoid function is num or array
match sigmoid_slope:
case numbers.Number():
membership = _lith_segmentation(Z, edges, ids, sigmoid_slope)
case _ if isinstance(sigmoid_slope, (np.ndarray, torch.Tensor)):
membership = _final_faults_segmentation(Z, edges, sigmoid_slope)
case _:
raise ValueError("sigmoid_slope must be a float or an array")

ids__sum = bt.t.sum(membership * ids, axis=-1)
return ids__sum[None, :]


def _final_faults_segmentation(Z, edges, sigmoid_slope):
first = _sigmoid(
scalar_field=Z,
edges=edges[0],
tau_k=1 / sigmoid_slope
) # shape (...,)
last = _sigmoid(
scalar_field=Z,
edges=edges[-1],
tau_k=1 / sigmoid_slope
)
membership = bt.t.concatenate(
[first[..., None], last[..., None]],
axis=-1
) # shape (...,K)
return membership


def _lith_segmentation(Z, edges, ids, sigmoid_slope):
# 1) per-edge temperatures τ_k = |Δ_k|/(4·m)
jumps = bt.t.abs(ids[1:] - ids[:-1]) # shape (K-1,)
tau_k = jumps / float(sigmoid_slope) # shape (K-1,)
# 2) first bin (-∞, e1) via σ((e1 - Z)/τ₁)
first = _sigmoid(
scalar_field=-Z,
edges=-edges[0],
tau_k=tau_k[0]
) # shape (...,)
# 3) last bin [e_{K-1}, ∞) via σ((Z - e_{K-1})/τ_{K-1})
# last = 1.0 / (1.0 + np.exp(-(Z - edges[-1]) / tau_k[-1])) # shape (...,)
last = _sigmoid(
scalar_field=Z,
edges=edges[-1],
tau_k=tau_k[-1]
)
# 4) middle bins [e_i, e_{i+1}): σ((Z - e_i)/τ_i) - σ((Z - e_{i+1})/τ_{i+1})
# shape (...,1)
left = _sigmoid(
scalar_field=(Z[..., None]),
edges=edges[:-1],
tau_k=tau_k[:-1]
)
right = _sigmoid(
scalar_field=(Z[..., None]),
edges=edges[1:],
tau_k=tau_k[1:]
)
middle = left - right # (...,K-2)
# 5) assemble memberships and weight by ids
membership = bt.t.concatenate(
[first[..., None], middle, last[..., None]],
axis=-1
) # shape (...,K)
return membership


def _sigmoid(scalar_field, edges, tau_k):
return 1.0 / (1.0 + bt.t.exp(-(scalar_field - edges) / tau_k))
31 changes: 20 additions & 11 deletions gempy_engine/modules/activator/activator_interface.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
import warnings

from gempy_engine.config import DEBUG_MODE, AvailableBackends
from gempy_engine.core.backend_tensor import BackendTensor as bt, BackendTensor
import numpy as np
from ...config import DEBUG_MODE, AvailableBackends
from ...core.backend_tensor import BackendTensor as bt, BackendTensor
from ...core.data.exported_fields import ExportedFields
from ._soft_segment import soft_segment_unbounded

from gempy_engine.core.data.exported_fields import ExportedFields
import numpy as np


def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray,
sigmoid_slope: float) -> np.ndarray:
Z_x: np.ndarray = exported_fields.scalar_field_everywhere
scalar_value_at_sp: np.ndarray = exported_fields.scalar_field_at_surface_points

sigmoid_slope_negative = isinstance(sigmoid_slope, float) and sigmoid_slope < 0 # * sigmoid_slope can be array for finite faultskA

if LEGACY := True and not sigmoid_slope_negative: # * Here we branch to the experimental activation function with hard sigmoid
sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
else:
from .torch_activation import activate_formation_block_from_args_hard_sigmoid
sigm = activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp)
sigmoid_slope_negative = isinstance(sigmoid_slope, float) and sigmoid_slope < 0 # * sigmoid_slope can be array for finite faultskA

if LEGACY := False and not sigmoid_slope_negative: # * Here we branch to the experimental activation function with hard sigmoid
sigm = activate_formation_block_from_args(
Z_x=Z_x,
ids=ids,
scalar_value_at_sp=scalar_value_at_sp,
sigmoid_slope=sigmoid_slope
)
else:
sigm = soft_segment_unbounded(
Z=Z_x,
edges=scalar_value_at_sp,
ids=ids,
sigmoid_slope=sigmoid_slope
)
return sigm


Expand Down
5 changes: 5 additions & 0 deletions gempy_engine/modules/activator/torch_activation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import torch
from ...core.backend_tensor import BackendTensor as bt, BackendTensor

Expand All @@ -14,6 +16,9 @@


def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp):

warnings.warn(DeprecationWarning("This function is deprecated. Use activate_formation_block instead."))

element_0 = bt.t.array([0], dtype=BackendTensor.dtype_obj)

min_Z_x = BackendTensor.t.min(Z_x, axis=0).reshape(-1) # ? Is this as good as it gets?
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/complex_geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def one_fault_model():

spi: SurfacePoints = SurfacePoints(sp_coords)
ori: Orientations = Orientations(dip_postions, dip_gradients)
ids = np.array([1, 2, 3, 4, 5, 6])
ids = np.array([1, 2, 3, 4, 5, 6, 7])

resolution = [2, 2, 2]
extent = np.array([-500, 500., -500, 500, -450, 550]) / rescaling_factor
Expand Down Expand Up @@ -172,7 +172,7 @@ def graben_fault_model():

spi = SurfacePoints(sp_coords)
ori = Orientations(dip_postions, dip_gradients)
ids = np.array([1, 2, 3, 4, 5, 6])
ids = np.array([1, 2, 3, 4, 5, 6, 7])

resolution = [2, 2, 2]
extent = np.array([-500, 500., -500, 500, -450, 550]) / rescaling_factor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_graben_fault_model(graben_fault_model):
options.evaluation_options.dual_conturing_fancy = True
options.debug=True

options.evaluation_options.number_octree_levels = 4
options.evaluation_options.number_octree_levels = 5
solutions: Solutions = compute_model(interpolation_input, options, structure)

outputs: list[OctreeLevel] = solutions.octrees_output
Expand Down
4 changes: 2 additions & 2 deletions tests/test_common/test_api/test_faults/test_one_fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gempy_engine.plugins.plotting.helper_functions import plot_block_and_input_2d, plot_scalar_and_input_2d


def test_one_fault_model(one_fault_model, n_oct_levels=3):
def test_one_fault_model(one_fault_model, n_oct_levels=5):
interpolation_input: InterpolationInput
structure: InputDataDescriptor
options: InterpolationOptions
Expand All @@ -44,7 +44,7 @@ def test_one_fault_model(one_fault_model, n_oct_levels=3):
gempy_v2_cov = _covariance_for_one_fault_model_from_gempy_v2()
diff = last_cov - gempy_v2_cov

if plot_2d := False:
if plot_2d := True:
_plot_stack_raw(interpolation_input, outputs, structure)
_plot_stack_squeezed_mask(interpolation_input, outputs, structure)
_plot_stack_mask_component(interpolation_input, outputs, structure)
Expand Down
100 changes: 100 additions & 0 deletions tests/test_common/test_modules/test_activator_fns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import dataclasses
import os

import matplotlib.pyplot as plt
import numpy as np

from gempy_engine.API.interp_single._interp_scalar_field import _solve_interpolation, _evaluate_sys_eq
from gempy_engine.API.interp_single._interp_single_feature import input_preprocess
from gempy_engine.config import AvailableBackends
from gempy_engine.core.data.internal_structs import SolverInput
from gempy_engine.modules.activator.activator_interface import activate_formation_block
from gempy_engine.core.backend_tensor import BackendTensor

dir_name = os.path.dirname(__file__)

plot = True


def test_activator_3_layers_segmentation_function(simple_model_3_layers, simple_grid_3d_more_points_grid):
Z_x, grid, ids_block, interpolation_input = _run_test(
backend=AvailableBackends.numpy,
ids=np.array([1, 20, 3, 4]),
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
simple_model_3_layers=simple_model_3_layers
)

if plot:
_plot_continious(grid, ids_block, interpolation_input)


def test_activator_3_layers_segmentation_function_II(simple_model_3_layers, simple_grid_3d_more_points_grid):
Z_x, grid, ids_block, interpolation_input = _run_test(
backend=AvailableBackends.numpy,
ids=np.array([1, 2, 3, 4]),
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
simple_model_3_layers=simple_model_3_layers
)

BackendTensor.change_backend_gempy(AvailableBackends.numpy)

if plot:
_plot_continious(grid, ids_block, interpolation_input)


def test_activator_3_layers_segmentation_function_torch(simple_model_3_layers, simple_grid_3d_more_points_grid):
Z_x, grid, ids_block, interpolation_input = _run_test(
backend=AvailableBackends.PYTORCH,
ids=np.array([1, 2, 3, 4]),
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
simple_model_3_layers=simple_model_3_layers
)

BackendTensor.change_backend_gempy(AvailableBackends.numpy)
if plot:
_plot_continious(grid, ids_block, interpolation_input)


def _run_test(backend, ids, simple_grid_3d_more_points_grid, simple_model_3_layers):
interpolation_input = simple_model_3_layers[0]
options = simple_model_3_layers[1]
data_shape = simple_model_3_layers[2].tensors_structure
grid = dataclasses.replace(simple_grid_3d_more_points_grid)
interpolation_input.set_temp_grid(grid)
interp_input: SolverInput = input_preprocess(data_shape, interpolation_input)
weights = _solve_interpolation(interp_input, options.kernel_options)
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
exported_fields.set_structure_values(
reference_sp_position=data_shape.reference_sp_position,
slice_feature=interpolation_input.slice_feature,
grid_size=interpolation_input.grid.len_all_grids)
Z_x: np.ndarray = exported_fields.scalar_field
sasp = exported_fields.scalar_field_at_surface_points
print(Z_x, Z_x.shape[0])
print(sasp)
BackendTensor.change_backend_gempy(backend)
ids_block = activate_formation_block(
exported_fields=exported_fields,
ids=ids,
sigmoid_slope=500 * 4
)[0, :-7]
return Z_x, grid, ids_block, interpolation_input


def _plot_continious(grid, ids_block, interpolation_input):
block__ = ids_block[grid.dense_grid_slice]
unique = np.unique(block__)
t = block__.reshape(50, 5, 50)[:, 2, :].T
unique = np.unique(t)

levels = np.linspace(t.min(), t.max(), 40)
plt.contourf(
t,
levels=levels,
cmap="jet",
extent=(.25, .75, .25, .75)
)
xyz = interpolation_input.surface_points.sp_coords
plt.plot(xyz[:, 0], xyz[:, 2], "o")
plt.colorbar()
plt.show()