@@ -303,18 +303,48 @@ def outer_loop(self):
303
303
)
304
304
305
305
def get_residual_matrix (self , components = None , weights = None , stretch = None ):
306
- # Initialize residual matrix as negative of source_matrix
306
+ """
307
+ Return the residuals (difference) between the source matrix and its reconstruction
308
+ from the given components, weights, and stretch factors.
309
+
310
+ Each component profile is stretched, interpolated to fractional positions,
311
+ weighted per signal, and summed to form the reconstruction. The residuals
312
+ are the source matrix minus this reconstruction.
313
+
314
+ Parameters
315
+ ----------
316
+ components : (signal_len, n_components) array, optional
317
+ weights : (n_components, n_signals) array, optional
318
+ stretch : (n_components, n_signals) array, optional
319
+
320
+ Returns
321
+ -------
322
+ residuals : (signal_len, n_signals) array
323
+ """
324
+
307
325
if components is None :
308
326
components = self .components
309
327
if weights is None :
310
328
weights = self .weights
311
329
if stretch is None :
312
330
stretch = self .stretch
331
+
313
332
residuals = - self .source_matrix .copy ()
314
- # Compute transformed components for all (k, m) pairs
315
- for k in range (weights .shape [0 ]): # K
316
- stretched_components , _ , _ = apply_interpolation (stretch [k , :], components [:, k ]) # Only use Ax
317
- residuals += weights [k , :] * stretched_components # Element-wise scaling and sum
333
+ sample_indices = np .arange (components .shape [0 ]) # (signal_len,)
334
+
335
+ for comp in range (components .shape [1 ]): # loop over components
336
+ residuals += (
337
+ np .interp (
338
+ sample_indices [:, None ]
339
+ / stretch [comp ][None , :], # fractional positions (signal_len, n_signals)
340
+ sample_indices , # (signal_len,)
341
+ components [:, comp ], # component profile (signal_len,)
342
+ left = components [0 , comp ],
343
+ right = components [- 1 , comp ],
344
+ )
345
+ * weights [comp ][None , :] # broadcast (n_signals,) over rows
346
+ )
347
+
318
348
return residuals
319
349
320
350
def get_objective_function (self , residuals = None , stretch = None ):
@@ -579,42 +609,47 @@ def update_components(self):
579
609
580
610
def update_weights (self ):
581
611
"""
582
- Updates weights using matrix operations, solving a quadratic program to do so.
612
+ Updates weights by building the stretched component matrix `stretched_comps` with np.interp
613
+ and solving a quadratic program for each signal.
583
614
"""
584
615
585
- signal_length = self .signal_length
586
- n_signals = self .n_signals
587
-
588
- for m in range (n_signals ):
589
- t = np .zeros ((signal_length , self .n_components ))
590
-
591
- # Populate t using apply_interpolation
592
- for k in range (self .n_components ):
593
- t [:, k ] = apply_interpolation (self .stretch [k , m ], self .components [:, k ])[0 ].squeeze ()
594
-
595
- # Solve quadratic problem for y
596
- y = self .solve_quadratic_program (t = t , m = m )
616
+ sample_indices = np .arange (self .signal_length )
617
+ for signal in range (self .n_signals ):
618
+ # Stretch factors for this signal across components:
619
+ this_stretch = self .stretch [:, signal ]
620
+ # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
621
+ stretched_comps = np .empty ((self .signal_length , self .n_components ), dtype = self .components .dtype )
622
+ for comp in range (self .n_components ):
623
+ pos = sample_indices / this_stretch [comp ]
624
+ stretched_comps [:, comp ] = np .interp (
625
+ pos ,
626
+ sample_indices ,
627
+ self .components [:, comp ],
628
+ left = self .components [0 , comp ],
629
+ right = self .components [- 1 , comp ],
630
+ )
597
631
598
- # Update Y
599
- self .weights [:, m ] = y
632
+ # Solve quadratic problem for a given signal and update its weight
633
+ new_weight = self .solve_quadratic_program (t = stretched_comps , m = signal )
634
+ self .weights [:, signal ] = new_weight
600
635
601
636
def regularize_function (self , stretch = None ):
602
637
if stretch is None :
603
638
stretch = self .stretch
604
639
605
- K = self .n_components
606
- M = self .n_signals
607
- N = self .signal_length
608
-
609
640
stretched_components , d_stretch_comps , dd_stretch_comps = self .apply_interpolation_matrix (stretch = stretch )
610
- intermediate = stretched_components .flatten (order = "F" ).reshape ((N * M , K ), order = "F" )
611
- residuals = intermediate .sum (axis = 1 ).reshape ((N , M ), order = "F" ) - self .source_matrix
641
+ intermediate = stretched_components .flatten (order = "F" ).reshape (
642
+ (self .signal_length * self .n_signals , self .n_components ), order = "F"
643
+ )
644
+ residuals = (
645
+ intermediate .sum (axis = 1 ).reshape ((self .signal_length , self .n_signals ), order = "F" ) - self .source_matrix
646
+ )
612
647
613
648
fun = self .get_objective_function (residuals , stretch )
614
649
615
- tiled_res = np .tile (residuals , (1 , K ))
650
+ tiled_res = np .tile (residuals , (1 , self . n_components ))
616
651
grad_flat = np .sum (d_stretch_comps * tiled_res , axis = 0 )
617
- gra = grad_flat .reshape ((M , K ), order = "F" ).T
652
+ gra = grad_flat .reshape ((self . n_signals , self . n_components ), order = "F" ).T
618
653
gra += self .rho * stretch @ (self ._spline_smooth_operator .T @ self ._spline_smooth_operator )
619
654
620
655
# Hessian would go here
@@ -623,10 +658,10 @@ def regularize_function(self, stretch=None):
623
658
624
659
def update_stretch (self ):
625
660
"""
626
- Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB).
661
+ Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).
627
662
"""
628
663
629
- # Flatten A for compatibility with the optimizer (since SciPy expects 1D inputs )
664
+ # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input )
630
665
stretch_flat_initial = self .stretch .flatten ()
631
666
632
667
# Define the optimization function
@@ -648,7 +683,7 @@ def objective(stretch_vec):
648
683
bounds = bounds ,
649
684
)
650
685
651
- # Update A with the optimized values
686
+ # Update stretch with the optimized values
652
687
self .stretch = result .x .reshape (self .stretch .shape )
653
688
654
689
@@ -683,48 +718,3 @@ def cubic_largest_real_root(p, q):
683
718
y = np .max (real_roots , axis = 0 ) * (delta < 0 ) # Keep only real roots when delta < 0
684
719
685
720
return y
686
-
687
-
688
- def apply_interpolation (a , x ):
689
- """
690
- Applies an interpolation-based transformation to `x` based on scaling `a`.
691
- Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
692
- """
693
- x_len = len (x )
694
-
695
- # Ensure `a` is an array and reshape for broadcasting
696
- a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
697
-
698
- # Compute fractional indices, broadcasting over `a`
699
- fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
700
-
701
- integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
702
- valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
703
-
704
- # Apply valid_mask to keep correct indices
705
- idx_int = np .where (valid_mask , integer_indices , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
706
- idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
707
-
708
- # Ensure x is a 1D array
709
- x = np .asarray (x ).ravel ()
710
-
711
- # Compute interpolated_x (linear interpolation)
712
- interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
713
- idx_frac - idx_int
714
- )
715
-
716
- # Fill the tail with the last valid value
717
- intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
718
- interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
719
-
720
- # Compute first derivative (d_intr_x)
721
- di = - idx_frac / a
722
- d_intr_x = x [idx_int ] * (- di ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * di
723
- d_intr_x = np .vstack ([d_intr_x , np .zeros ((x_len - len (idx_int ), d_intr_x .shape [1 ]))])
724
-
725
- # Compute second derivative (dd_intr_x)
726
- ddi = - di / a + idx_frac * a ** - 2
727
- dd_intr_x = x [idx_int ] * (- ddi ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * ddi
728
- dd_intr_x = np .vstack ([dd_intr_x , np .zeros ((x_len - len (idx_int ), dd_intr_x .shape [1 ]))])
729
-
730
- return interpolated_x , d_intr_x , dd_intr_x
0 commit comments