Skip to content

[ENH] Add numpy array validator and fix type annotations #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
4 changes: 4 additions & 0 deletions gempy_engine/core/data/encoders/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np
from pydantic import BeforeValidator

numpy_array_short_validator = BeforeValidator(lambda v: np.array(v) if v is not None else None)
4 changes: 2 additions & 2 deletions gempy_engine/core/data/options/interpolation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def from_args(
# @on

@classmethod
def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
def init_octree_options(cls, range=1.7, c_o=10., refinement: int = 1):
return InterpolationOptions.from_args(
range=range,
c_o=c_o,
Expand All @@ -118,7 +118,7 @@ def init_octree_options(cls, range=1.7, c_o=10, refinement: int = 1):
def init_dense_grid_options(cls):
options = InterpolationOptions.from_args(
range=1.7,
c_o=10,
c_o=10.,
mesh_extraction=False,
number_octree_levels=1
)
Expand Down
3 changes: 2 additions & 1 deletion gempy_engine/core/data/options/kernel_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class KernelOptions:
optimizing_condition_number: bool = False
condition_number: Optional[float] = None

@field_validator('kernel_function', mode='before')

@field_validator('kernel_function', mode='before', json_schema_input_type=str)
@classmethod
def _deserialize_kernel_function_from_name(cls, value):
"""
Expand Down
55 changes: 28 additions & 27 deletions gempy_engine/core/data/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pprint
import warnings
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Union
from typing import Optional

import numpy as np
from dataclasses import dataclass
from typing_extensions import Annotated

from .encoders.converters import numpy_array_short_validator


class TransformOpsOrder(Enum):
Expand All @@ -13,16 +16,16 @@ class TransformOpsOrder(Enum):


class GlobalAnisotropy(Enum):
CUBE = auto() # * Transform data to be as close as possible to a cube
NONE = auto() # * Do not transform data
MANUAL = auto() # * Use the user defined transform
CUBE = auto() # * Transform data to be as close as possible to a cube
NONE = auto() # * Do not transform data
MANUAL = auto() # * Use the user defined transform


@dataclass
class Transform:
position: np.ndarray
rotation: np.ndarray
scale: np.ndarray
position: Annotated[np.ndarray, numpy_array_short_validator]
rotation: Annotated[np.ndarray, numpy_array_short_validator]
scale: Annotated[np.ndarray, numpy_array_short_validator]

_is_default_transform: bool = False
_cached_pivot: Optional[np.ndarray] = None
Expand Down Expand Up @@ -68,11 +71,10 @@ def from_matrix(cls, matrix: np.ndarray):
])
return cls(position, rotation_degrees, scale)


@property
def cached_pivot(self):
return self._cached_pivot

@cached_pivot.setter
def cached_pivot(self, pivot: np.ndarray):
self._cached_pivot = pivot
Expand All @@ -96,7 +98,7 @@ def from_input_points(cls, surface_points: 'gempy.data.SurfacePointsTable', orie

# The scaling factor for each dimension is the inverse of its range
scaling_factors = 1 / range_coord

# ! Be careful with toy models
center: np.ndarray = (max_coord + min_coord) / 2
return cls(
Expand Down Expand Up @@ -127,14 +129,14 @@ def apply_anisotropy(self, anisotropy_type: GlobalAnisotropy, anisotropy_limit:
)
else:
raise NotImplementedError

@staticmethod
def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
# Calculate the ratios
ratios = [
s[0] / s[1], s[0] / s[2],
s[1] / s[0], s[1] / s[2],
s[2] / s[0], s[2] / s[1]
s[0] / s[1], s[0] / s[2],
s[1] / s[0], s[1] / s[2],
s[2] / s[0], s[2] / s[1]
]

# Adjust the scales based on the index of the max ratio
Expand All @@ -158,9 +160,9 @@ def _adjust_scale_to_limit_ratio(s, anisotropic_limit=np.array([10, 10, 10])):
@staticmethod
def _max_scale_ratio(s):
ratios = [
s[0] / s[1], s[0] / s[2],
s[1] / s[0], s[1] / s[2],
s[2] / s[0], s[2] / s[1]
s[0] / s[1], s[0] / s[2],
s[1] / s[0], s[1] / s[2],
s[2] / s[0], s[2] / s[1]
]
return max(ratios)

Expand Down Expand Up @@ -223,7 +225,7 @@ def apply(self, points: np.ndarray, transform_op_order: TransformOpsOrder = Tran

def scale_points(self, points: np.ndarray):
return points * self.scale

def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
# * NOTE: to compare with legacy we would have to add 0.5 to the coords
assert points.shape[1] == 3
Expand All @@ -233,12 +235,11 @@ def apply_inverse(self, points: np.ndarray, transform_op_order: TransformOpsOrde
transformed_points = (inv @ homogeneous_points.T).T
return transformed_points[:, :3]


def apply_with_cached_pivot(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
if self._cached_pivot is None:
raise ValueError("A pivot must be set before calling this method")
return self.apply_with_pivot(points, self._cached_pivot, transform_op_order)

def apply_inverse_with_cached_pivot(self, points: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
if self._cached_pivot is None:
raise ValueError("A pivot must be set before calling this method")
Expand Down Expand Up @@ -269,7 +270,7 @@ def apply_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT):
assert points.shape[1] == 3

# Translation matrices to and from the pivot
T_to_origin = self._translation_matrix(-pivot[0], -pivot[1], -pivot[2])
T_back = self._translation_matrix(*pivot)
Expand All @@ -284,10 +285,10 @@ def apply_inverse_with_pivot(self, points: np.ndarray, pivot: np.ndarray,
@staticmethod
def _translation_matrix(tx, ty, tz):
return np.array([
[1, 0, 0, tx],
[0, 1, 0, ty],
[0, 0, 1, tz],
[0, 0, 0, 1]
[1, 0, 0, tx],
[0, 1, 0, ty],
[0, 0, 1, tz],
[0, 0, 0, 1]
])

def transform_gradient(self, gradients: np.ndarray, transform_op_order: TransformOpsOrder = TransformOpsOrder.SRT,
Expand Down