Skip to content

[ENH] Improve type annotations with short_array_type and field defaults #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions gempy_engine/core/data/centered_grid.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
from dataclasses import dataclass
from typing import Sequence, Union
from dataclasses import dataclass, field

import numpy as np

from gempy_engine.core.backend_tensor import BackendTensor
from gempy_engine.core.utils import cast_type_inplace
from .encoders.converters import short_array_type


@dataclass
class CenteredGrid:
centers: np.ndarray #: This is just used to calculate xyz to interpolate. Tz is independent
resolution: Sequence[float]
radius: Union[float, Sequence[float]]
centers: short_array_type #: This is just used to calculate xyz to interpolate. Tz is independent
resolution: short_array_type
radius: float | short_array_type

kernel_grid_centers: np.ndarray = None
left_voxel_edges: np.ndarray = None
right_voxel_edges: np.ndarray = None
kernel_grid_centers: np.ndarray = field(init=False)
left_voxel_edges: np.ndarray = field(init=False)
right_voxel_edges: np.ndarray = field(init=False)

def __len__(self):
return self.centers.shape[0] * self.kernel_grid_centers.shape[0]
Expand Down
3 changes: 3 additions & 0 deletions gempy_engine/core/data/encoders/converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Annotated

import numpy as np
from pydantic import BeforeValidator

numpy_array_short_validator = BeforeValidator(lambda v: np.array(v) if v is not None else None)
short_array_type = Annotated[np.ndarray, numpy_array_short_validator]
11 changes: 7 additions & 4 deletions gempy_engine/core/data/geophysics_input.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from dataclasses import dataclass
from typing import Annotated

from ..backend_tensor import BackendTensor
import numpy as np

from .encoders.converters import numpy_array_short_validator


@dataclass
class GeophysicsInput():
tz: BackendTensor.t
densities: BackendTensor.t
class GeophysicsInput:
tz: Annotated[np.ndarray, numpy_array_short_validator]
densities: Annotated[np.ndarray, numpy_array_short_validator]
25 changes: 15 additions & 10 deletions gempy_engine/core/data/kernel_classes/faults.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,39 @@
import dataclasses
from typing import Optional
from typing import Optional, Callable

import numpy as np
from pydantic import Field

from gempy_engine.core.data.transforms import Transform
from ..encoders.converters import short_array_type
from ..transforms import Transform


@dataclasses.dataclass
class FiniteFaultData:
implicit_function: callable
implicit_function_transform: Transform
pivot: np.ndarray
implicit_function: Callable | None = Field(exclude=True, default=None)#, default=None)
implicit_function_transform: Transform = Field()
pivot: short_array_type = Field()

def apply(self, points: np.ndarray) -> np.ndarray:
transformed_points = self.implicit_function_transform.apply_inverse_with_pivot(
points=points,
pivot=self.pivot
)
if self.implicit_function is None:
raise ValueError("No implicit function defined. This can happen after deserializing (loading).")

scalar_block = self.implicit_function(transformed_points)
return scalar_block



@dataclasses.dataclass
class FaultsData:
fault_values_everywhere: np.ndarray = None
fault_values_on_sp: np.ndarray = None
fault_values_everywhere: short_array_type | None = None
fault_values_on_sp: short_array_type | None = None

fault_values_ref: np.ndarray = None
fault_values_rest: np.ndarray = None
fault_values_ref: short_array_type | None = None
fault_values_rest: short_array_type | None = None

# User given data:
thickness: Optional[float] = None
Expand Down
3 changes: 2 additions & 1 deletion gempy_engine/core/data/options/interpolation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class CacheMode(enum.Enum):
# region Volatile
temp_interpolation_values: TempInterpolationValues = Field(
default_factory=TempInterpolationValues,
exclude=True
exclude=True,
repr=False
)

# endregion
Expand Down