@@ -134,7 +134,7 @@ def __init__(
134134 # Initialize weights and determine number of components
135135 if init_weights is None :
136136 self .n_components = n_components
137- self .weights = self ._rng .beta (a = 2.5 , b = 1.5 , size = (self .n_components , self .n_signals ))
137+ self .weights = self ._rng .beta (a = 2.0 , b = 2.0 , size = (self .n_components , self .n_signals ))
138138 else :
139139 self .n_components = init_weights .shape [0 ]
140140 self .weights = init_weights
@@ -169,7 +169,7 @@ def __init__(
169169 self .objective_difference = None
170170 self ._objective_history = [self .objective_function ]
171171
172- # Set up tracking variables for updateX ()
172+ # Set up tracking variables for update_components ()
173173 self ._prev_components = None
174174 self .grad_components = np .zeros_like (self .components ) # Gradient of X (zeros for now)
175175 self ._prev_grad_components = np .zeros_like (self .components ) # Previous gradient of X (zeros for now)
@@ -233,17 +233,21 @@ def __init__(
233233 def optimize_loop (self ):
234234 # Update components first
235235 self ._prev_grad_components = self .grad_components .copy ()
236+
236237 self .update_components ()
238+
237239 self .num_updates += 1
238240 self .residuals = self .get_residual_matrix ()
239241 self .objective_function = self .get_objective_function ()
240242 print (f"Objective function after update_components: { self .objective_function :.5e} " )
241243 self ._objective_history .append (self .objective_function )
244+
242245 if self .objective_difference is None :
243246 self .objective_difference = self ._objective_history [- 1 ] - self .objective_function
244247
245248 # Now we update weights
246249 self .update_weights ()
250+
247251 self .num_updates += 1
248252 self .residuals = self .get_residual_matrix ()
249253 self .objective_function = self .get_objective_function ()
@@ -257,6 +261,7 @@ def optimize_loop(self):
257261 self .objective_function = self .get_objective_function ()
258262 print (f"Objective function after update_stretch: { self .objective_function :.5e} " )
259263 self ._objective_history .append (self .objective_function )
264+
260265 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
261266
262267 def apply_interpolation (self , a , x , return_derivatives = False ):
0 commit comments