Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions src/grid/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@

from grid.basegrid import Grid, OneDGrid

from collections import deque
from typing import Optional, Callable
from scipy.spatial import cKDTree


class _HyperRectangleGrid(Grid):
def __init__(self, points, weights, shape):
Expand Down Expand Up @@ -553,6 +557,7 @@ def __init__(self, origin, axes, shape, weight="Trapezoid"):
dim = self._origin.size
# Make an array to store coordinates of grid points
self._points = np.zeros((np.prod(shape), dim))
self._weight_scheme = weight
if dim == 3:
coords = np.array(
np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]))
Expand Down Expand Up @@ -784,6 +789,11 @@ def origin(self):
"""Return the Cartesian coordinates of the uniform grid origin."""
return self._origin

@property
def weight_scheme(self):
r"""Return the weight scheme of the uniform grid."""
return self._weight_scheme

def save(self, filename):
r"""
Save uniform cubic grid attributes as a npz file.
Expand Down Expand Up @@ -1013,3 +1023,276 @@ def generate_cube(self, fname, data, atcoords, atnums, pseudo_numbers=None):
row_data = data.flat[i : i + num_chunks]
f.write((row_data.size * " {:12.5E}").format(*row_data))
f.write("\n")


class AdaptiveUniformGrid:
"""
This is a wrapper class that provides adaptive refinement for a UniformGrid instance.

This class takes a UniformGrid object and applies a recursive subdivision
algorithm to generate a new, non-uniform grid with points concentrated in
regions of high function error, leading to more efficient and accurate integration.

The main entry point is the `refinement` method.
"""

def __init__(self, uniform_grid: UniformGrid, error_estimate = "quadrature"):
"""Initialization.

Parameters
----------
uniform_grid : UniformGrid
The coarse, uniform grid that will serve as the starting point for refinement.
error_estimate: str, optional
The type of error used to choose which points to refine. Either
'quadrature' (suited for integration), or `gradient`. Default is 'quadrature'.
"""
if not isinstance(uniform_grid, UniformGrid):
raise ValueError("The input grid should be a UniformGrid instance.")
if uniform_grid.weight_scheme != "Rectangle":
raise ValueError(f"The weight scheme {uniform_grid.weight_scheme} should be Rectangle.")
self.grid = uniform_grid
self.ndim = uniform_grid.ndim
self.axis_spacings = np.array([np.linalg.norm(axis) for axis in self.grid.axes])
self.axes_norm = self.grid.axes / self.axis_spacings
self.error_estimate = error_estimate

if self.error_estimate == "gradient":
# Build flat indices for (-) and (+) along each axis, grouped together
# makes it faster to do central finite-difference
# This should match the output order of `_generate_subdivision_points`.
# in 3D- it is [[9, 18], [3, 6], [1, 2]] corresponding to each dimension
# in 2D- it is [[3, 6], [1, 2]]
idx_pairs = []
shape = (3,) * self.ndim # grid shape per axis
for i in range(self.ndim):
neg = [0] * self.ndim
neg[i] = 1 # index 1 corresponds to -1 offset
pos = [0] * self.ndim
pos[i] = 2 # index 2 corresponds to +1 offset
idx_pairs.append((
np.ravel_multi_index(neg, shape),
np.ravel_multi_index(pos, shape),
))
idx_pairs = np.array(idx_pairs)
self._idx_pairs = idx_pairs

def _get_func_values(
self, points: np.ndarray, func: Callable, evaluated_points: dict
) -> np.ndarray:
if len(points) == 0:
return np.array([])

# Round points for consistent cache keys
rounded_points = np.round(points, 10)

# Check which points need evaluation
keys = [tuple(p) for p in rounded_points]
missing_indices = []
values = np.zeros(len(points))

for i, key in enumerate(keys):
if key in evaluated_points:
values[i] = evaluated_points[key]
else:
missing_indices.append(i)

# Batch evaluate missing points
if missing_indices:
missing_points = points[missing_indices]
missing_values = func(missing_points)

# Update cache and values array
for i, missing_idx in enumerate(missing_indices):
key = keys[missing_idx]
value = missing_values[i]
evaluated_points[key] = value
values[missing_idx] = value

return values

def _estimate_error(self, point, weight, func_vals, spacings, subdivision_points):
if self.error_estimate == "quadrature":
return self._estimate_error_quadrature(
point, weight, func_vals, spacings, subdivision_points
)
elif self.error_estimate == "gradient":
return self._estimate_error_gradient(
point, weight, func_vals, spacings, subdivision_points
)
raise ValueError(f"Could not recognize the type of error estimate {self.error_estimate}.")

def _estimate_error_quadrature(self, point, weight, func_vals, _, subdivision_points):
child_weight = weight / len(subdivision_points)
quad_at_pt = func_vals[0] * weight # At the center point
quad_child = np.sum(func_vals * child_weight)
err = np.abs(quad_at_pt - quad_child)
return err

def _estimate_error_gradient(
self, point, _, func_vals, spacings, __
) -> float:
"""
Estimate error using finite difference gradient approximation with batch evaluation.
Uses efficient batch function evaluation to minimize function call overhead.
"""
gradient_magnitude = 0.0
for dim in range(self.ndim):
# Grabs the minus and forward index
i_minus = self._idx_pairs[dim][0]
i_plus = self._idx_pairs[dim][1]

f_forward = func_vals[i_plus]
f_backward = func_vals[i_minus]

# The spacing between center point and i_minus/i_plus is spacing/3.0
grad_dim = (f_forward - f_backward) / (2 * spacings[dim] / 3.0)
gradient_magnitude += grad_dim**2

gradient_magnitude = np.sqrt(gradient_magnitude)

# Error estimate: gradient magnitude times spacing
# Use geometric mean of spacings as characteristic length
characteristic_spacing = np.prod(spacings) ** (1 / self.ndim)

return gradient_magnitude * characteristic_spacing

def _generate_subdivision_points(
self, center_point: np.ndarray, spacings: np.ndarray
) -> np.ndarray:
# Generate all subdivision points for uniform 3^D cube subdivision.
# The number of subdivision points is 3^D
child_spacings = spacings / 3.0
# Generate all possible combinations of {0, +, -} such that
# the first point is the `center_point` within the subdivision.
ranges = (np.array([0.0, -1.0, 1.0])[:, None] * child_spacings).T
grids = np.meshgrid(*ranges, indexing="ij")
points_spacing = np.column_stack([grid.ravel() for grid in grids]) # Shape: (3^D, 3)
# With the spacing in each dimension, multiply it by the axes of the cubic grid.
# here k = D, j = D, and the number of subdivision is i = 3^D.
spacing_axes = points_spacing @ self.axes_norm
subdivision_points = center_point + spacing_axes
return subdivision_points

def refinement(
self,
func: Callable,
tolerance: float = 1e-4,
min_spacing: Optional[float] = None,
max_depth: int = 10,
refine_contrib_threshold: float = 1e-4
) -> dict:
"""
Parameters
----------
func : Callable
The real-valued, scalar function to be integrated.
The function must be vectorized, and takes in ndarray(M,3) -> float.
tolerance : float, optional
The error tolerance for a local point.
min_spacing : float, optional
The minimum allowed spacing for subdivision.
max_depth : int, optional
Maximum refinement depth to prevent infinite loops.
refine_contrib_threshold: float, optional
Skip refinement for cells with |f(center)| V < refine_value_threshold,
where V is the volume element.
Prevents work in negligible regions. Units match f. Default: 0.0.

Returns
-------
dict
A dictionary containing the final integral value, refined grid, and statistics.
"""
if min_spacing is None:
min_spacing = np.min(self.axis_spacings) / 100

points = self.grid.points.copy()
weights = self.grid.weights.copy()
initial_spacings = self.axis_spacings.copy()
func_evals = func(points)

# Refinement dequeue takes in 5 arguments:
# (index of point, point, current spacing, current weight, depth)
refinement_queue = deque()

# Potentially do refinement on that satisfies this error criteria
# Speeds up computation
indices = np.where(np.abs(func_evals) * weights > refine_contrib_threshold)[0]
for index in indices:
refinement_queue.append(
(
index,
points[index, :],
initial_spacings.copy(),
weights[index],
0
)
)

# Process refinement queue
while refinement_queue:
index, point, spacings, weight, depth = refinement_queue.popleft()

if depth > max_depth or np.any(spacings < min_spacing):
continue

subdivision_pts = self._generate_subdivision_points(point, spacings)

# Compute function values at the subdivision points
func_vals_center = func_evals[index]
func_vals_extra = func(subdivision_pts[1:])
func_vals_subdiv = np.concatenate(([func_vals_center], func_vals_extra))

# Use the subdivision points and function values to compute the error
local_error = self._estimate_error(
point, weight, func_vals_subdiv, spacings, subdivision_pts
)

# Do refinement on the center point
if local_error > tolerance:
num_sub_points = len(subdivision_pts) # 3^ndim

# Update the weight/spacing of that center point
weights[index] = weight / num_sub_points
child_spacings = spacings / 3

# Add center point back to the queue with updated weight and spacing
refinement_queue.append(
(
index,
point,
child_spacings.copy(),
weights[index],
depth + 1
)
)

# Add all subdivision points back to queue for further processing
# Ignore the first point since it is the center point, and added before.
for i_subpt, sub_point in enumerate(subdivision_pts[1:]):
refinement_queue.append(
(
len(points) + i_subpt,
sub_point,
child_spacings.copy(),
weights[index],
depth + 1
)
)
# Add the points, func_evals and weights to the initial list.
points = np.vstack((points, subdivision_pts[1:]))
weights = np.append(
weights,
np.full((num_sub_points - 1), fill_value = weights[index])
)
func_evals = np.append(func_evals, func_vals_extra)

final_grid = Grid(points, weights)
final_integral = final_grid.integrate(func_evals)
return {
"integral": final_integral,
"final_grid": final_grid,
"num_points": len(final_grid.points),
"num_evaluations": len(final_grid.points), # Accurate count of function evaluations
}
Loading