Skip to content

Commit 9cbd801

Browse files
john-halloran and John Halloran authored
Replace 1D apply_interpolation with np.interp (#168)
* refactor: get residual matrix without a helper
* perf: remove unused derivatives from apply_interpolation
* chore: remove old residual matrix and reference to derivatives
* refactor: replace remaining apply_interpolation with np.interp
* style: remove references to old variable names

---------

Co-authored-by: John Halloran <jhalloran@oxy.edu>
1 parent dae2f8e commit 9cbd801

File tree

1 file changed

+66
-76
lines changed

1 file changed

+66
-76
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 66 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -303,18 +303,48 @@ def outer_loop(self):
303303
)
304304

305305
def get_residual_matrix(self, components=None, weights=None, stretch=None):
306-
# Initialize residual matrix as negative of source_matrix
306+
"""
307+
Return the residuals (difference) between the source matrix and its reconstruction
308+
from the given components, weights, and stretch factors.
309+
310+
Each component profile is stretched, interpolated to fractional positions,
311+
weighted per signal, and summed to form the reconstruction. The residuals
312+
are the source matrix minus this reconstruction.
313+
314+
Parameters
315+
----------
316+
components : (signal_len, n_components) array, optional
317+
weights : (n_components, n_signals) array, optional
318+
stretch : (n_components, n_signals) array, optional
319+
320+
Returns
321+
-------
322+
residuals : (signal_len, n_signals) array
323+
"""
324+
307325
if components is None:
308326
components = self.components
309327
if weights is None:
310328
weights = self.weights
311329
if stretch is None:
312330
stretch = self.stretch
331+
313332
residuals = -self.source_matrix.copy()
314-
# Compute transformed components for all (k, m) pairs
315-
for k in range(weights.shape[0]): # K
316-
stretched_components, _, _ = apply_interpolation(stretch[k, :], components[:, k]) # Only use Ax
317-
residuals += weights[k, :] * stretched_components # Element-wise scaling and sum
333+
sample_indices = np.arange(components.shape[0]) # (signal_len,)
334+
335+
for comp in range(components.shape[1]): # loop over components
336+
residuals += (
337+
np.interp(
338+
sample_indices[:, None]
339+
/ stretch[comp][None, :], # fractional positions (signal_len, n_signals)
340+
sample_indices, # (signal_len,)
341+
components[:, comp], # component profile (signal_len,)
342+
left=components[0, comp],
343+
right=components[-1, comp],
344+
)
345+
* weights[comp][None, :] # broadcast (n_signals,) over rows
346+
)
347+
318348
return residuals
319349

320350
def get_objective_function(self, residuals=None, stretch=None):
@@ -579,42 +609,47 @@ def update_components(self):
579609

580610
def update_weights(self):
581611
"""
582-
Updates weights using matrix operations, solving a quadratic program to do so.
612+
Updates weights by building the stretched component matrix `stretched_comps` with np.interp
613+
and solving a quadratic program for each signal.
583614
"""
584615

585-
signal_length = self.signal_length
586-
n_signals = self.n_signals
587-
588-
for m in range(n_signals):
589-
t = np.zeros((signal_length, self.n_components))
590-
591-
# Populate t using apply_interpolation
592-
for k in range(self.n_components):
593-
t[:, k] = apply_interpolation(self.stretch[k, m], self.components[:, k])[0].squeeze()
594-
595-
# Solve quadratic problem for y
596-
y = self.solve_quadratic_program(t=t, m=m)
616+
sample_indices = np.arange(self.signal_length)
617+
for signal in range(self.n_signals):
618+
# Stretch factors for this signal across components:
619+
this_stretch = self.stretch[:, signal]
620+
# Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
621+
stretched_comps = np.empty((self.signal_length, self.n_components), dtype=self.components.dtype)
622+
for comp in range(self.n_components):
623+
pos = sample_indices / this_stretch[comp]
624+
stretched_comps[:, comp] = np.interp(
625+
pos,
626+
sample_indices,
627+
self.components[:, comp],
628+
left=self.components[0, comp],
629+
right=self.components[-1, comp],
630+
)
597631

598-
# Update Y
599-
self.weights[:, m] = y
632+
# Solve quadratic problem for a given signal and update its weight
633+
new_weight = self.solve_quadratic_program(t=stretched_comps, m=signal)
634+
self.weights[:, signal] = new_weight
600635

601636
def regularize_function(self, stretch=None):
602637
if stretch is None:
603638
stretch = self.stretch
604639

605-
K = self.n_components
606-
M = self.n_signals
607-
N = self.signal_length
608-
609640
stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch)
610-
intermediate = stretched_components.flatten(order="F").reshape((N * M, K), order="F")
611-
residuals = intermediate.sum(axis=1).reshape((N, M), order="F") - self.source_matrix
641+
intermediate = stretched_components.flatten(order="F").reshape(
642+
(self.signal_length * self.n_signals, self.n_components), order="F"
643+
)
644+
residuals = (
645+
intermediate.sum(axis=1).reshape((self.signal_length, self.n_signals), order="F") - self.source_matrix
646+
)
612647

613648
fun = self.get_objective_function(residuals, stretch)
614649

615-
tiled_res = np.tile(residuals, (1, K))
650+
tiled_res = np.tile(residuals, (1, self.n_components))
616651
grad_flat = np.sum(d_stretch_comps * tiled_res, axis=0)
617-
gra = grad_flat.reshape((M, K), order="F").T
652+
gra = grad_flat.reshape((self.n_signals, self.n_components), order="F").T
618653
gra += self.rho * stretch @ (self._spline_smooth_operator.T @ self._spline_smooth_operator)
619654

620655
# Hessian would go here
@@ -623,10 +658,10 @@ def regularize_function(self, stretch=None):
623658

624659
def update_stretch(self):
625660
"""
626-
Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB).
661+
Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).
627662
"""
628663

629-
# Flatten A for compatibility with the optimizer (since SciPy expects 1D inputs)
664+
# Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input)
630665
stretch_flat_initial = self.stretch.flatten()
631666

632667
# Define the optimization function
@@ -648,7 +683,7 @@ def objective(stretch_vec):
648683
bounds=bounds,
649684
)
650685

651-
# Update A with the optimized values
686+
# Update stretch with the optimized values
652687
self.stretch = result.x.reshape(self.stretch.shape)
653688

654689

@@ -683,48 +718,3 @@ def cubic_largest_real_root(p, q):
683718
y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0
684719

685720
return y
686-
687-
688-
def apply_interpolation(a, x):
689-
"""
690-
Applies an interpolation-based transformation to `x` based on scaling `a`.
691-
Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
692-
"""
693-
x_len = len(x)
694-
695-
# Ensure `a` is an array and reshape for broadcasting
696-
a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D
697-
698-
# Compute fractional indices, broadcasting over `a`
699-
fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M)
700-
701-
integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M))
702-
valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds
703-
704-
# Apply valid_mask to keep correct indices
705-
idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I")
706-
idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i")
707-
708-
# Ensure x is a 1D array
709-
x = np.asarray(x).ravel()
710-
711-
# Compute interpolated_x (linear interpolation)
712-
interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * (
713-
idx_frac - idx_int
714-
)
715-
716-
# Fill the tail with the last valid value
717-
intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :])
718-
interpolated_x = np.vstack([interpolated_x, intr_x_tail])
719-
720-
# Compute first derivative (d_intr_x)
721-
di = -idx_frac / a
722-
d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di
723-
d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))])
724-
725-
# Compute second derivative (dd_intr_x)
726-
ddi = -di / a + idx_frac * a**-2
727-
dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi
728-
dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))])
729-
730-
return interpolated_x, d_intr_x, dd_intr_x

0 commit comments

Comments
 (0)