@@ -52,10 +52,6 @@ class SNMFOptimizer:
5252 num_updates : int
5353 The total number of times that any of (stretch, components, and weights) have had their values changed.
5454 If not terminated by other means, this value is used to stop when reaching max_iter.
55- objective_function: float
56- The value corresponding to the minimization of the difference between the source_matrix and the
57- products of (stretch, components, and weights). For full details see the sNMF paper. Smaller corresponds to
58- better agreement and is desirable.
5955 objective_difference : float
6056 The change in the objective function value since the last update. A negative value
6157 means that the result improved.
@@ -165,9 +161,9 @@ def __init__(
165161
166162 # Set up residual matrix, objective function, and history
167163 self .residuals = self .get_residual_matrix ()
168- self .objective_function = self .get_objective_function ()
164+ self ._objective_history = []
165+ self .update_objective ()
169166 self .objective_difference = None
170- self ._objective_history = [self .objective_function ]
171167
172168 # Set up tracking variables for update_components()
173169 self ._prev_components = None
@@ -177,8 +173,8 @@ def __init__(
177173 regularization_term = 0.5 * rho * np .linalg .norm (self ._spline_smooth_operator @ self .stretch .T , "fro" ) ** 2
178174 sparsity_term = eta * np .sum (np .sqrt (self .components )) # Square root penalty
179175 print (
180- f"Start, Objective function: { self .objective_function :.5e} "
181- f", Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} "
176+ f"Start, Objective function: { self ._objective_history [ - 1 ] :.5e} "
177+ f", Obj - reg/sparse: { self ._objective_history [ - 1 ] - regularization_term - sparsity_term :.5e} "
182178 )
183179
184180 # Main optimization loop
@@ -191,15 +187,15 @@ def __init__(
191187 sparsity_term = eta * np .sum (np .sqrt (self .components )) # Square root penalty
192188 print (
193189 f"Num_updates: { self .num_updates } , "
194- f"Obj fun: { self .objective_function :.5e} , "
195- f"Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} , "
190+ f"Obj fun: { self ._objective_history [ - 1 ] :.5e} , "
191+ f"Obj - reg/sparse: { self ._objective_history [ - 1 ] - regularization_term - sparsity_term :.5e} , "
196192 f"Iter: { iter } "
197193 )
198194
199195 # Convergence check: decide when to terminate for small/no improvement
200- print (self .objective_difference , " < " , self .objective_function * tol )
201- if self .objective_difference < self .objective_function * tol and iter >= 20 :
196+ if self .objective_difference < self ._objective_history [- 1 ] * tol and iter >= 20 :
202197 break
198+ print (self .objective_difference , " < " , self ._objective_history [- 1 ] * tol )
203199
204200 # Normalize our results
205201 weights_row_max = np .max (self .weights , axis = 1 , keepdims = True )
@@ -214,17 +210,17 @@ def __init__(
214210 self .grad_components = np .zeros_like (self .components )
215211 self ._prev_grad_components = np .zeros_like (self .components )
216212 self .residuals = self .get_residual_matrix ()
217- self .objective_function = self .get_objective_function ()
218213 self .objective_difference = None
219- self ._objective_history = [self .objective_function ]
214+ self ._objective_history = []
215+ self .update_objective ()
220216 for norm_iter in range (100 ):
221217 self .update_components ()
222218 self .residuals = self .get_residual_matrix ()
223- self .objective_function = self . get_objective_function ()
224- print (f"Objective function after normX : { self .objective_function :.5e} " )
225- self ._objective_history .append (self .objective_function )
219+ self .update_objective ()
220+ print (f"Objective function after normalize_components : { self ._objective_history [ - 1 ] :.5e} " )
221+ self ._objective_history .append (self ._objective_history [ - 1 ] )
226222 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
227- if self .objective_difference < self .objective_function * tol and norm_iter >= 20 :
223+ if self .objective_difference < self ._objective_history [ - 1 ] * tol and norm_iter >= 20 :
228224 break
229225 # end of normalization (and program)
230226 # note that objective function may not fully recover after normalization, this is okay
@@ -238,29 +234,27 @@ def optimize_loop(self):
238234
239235 self .num_updates += 1
240236 self .residuals = self .get_residual_matrix ()
241- self .objective_function = self .get_objective_function ()
242- print (f"Objective function after update_components: { self .objective_function :.5e} " )
243- self ._objective_history .append (self .objective_function )
237+ self .update_objective ()
238+ print (f"Objective function after update_components: { self ._objective_history [- 1 ]:.5e} " )
244239
245240 if self .objective_difference is None :
246- self .objective_difference = self ._objective_history [- 1 ] - self .objective_function
241+ self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [ - 1 ]
247242
248243 # Now we update weights
249244 self .update_weights ()
250245
251246 self .num_updates += 1
252247 self .residuals = self .get_residual_matrix ()
253- self .objective_function = self .get_objective_function ()
254- print (f"Objective function after update_weights: { self .objective_function :.5e} " )
255- self ._objective_history .append (self .objective_function )
248+ self .update_objective ()
249+ print (f"Objective function after update_weights: { self ._objective_history [- 1 ]:.5e} " )
256250
257251 # Now we update stretch
258252 self .update_stretch ()
253+
259254 self .num_updates += 1
260255 self .residuals = self .get_residual_matrix ()
261- self .objective_function = self .get_objective_function ()
262- print (f"Objective function after update_stretch: { self .objective_function :.5e} " )
263- self ._objective_history .append (self .objective_function )
256+ self .update_objective ()
257+ print (f"Objective function after update_stretch: { self ._objective_history [- 1 ]:.5e} " )
264258
265259 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
266260
@@ -333,7 +327,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
333327 residuals += weights [k , :] * stretched_components # Element-wise scaling and sum
334328 return residuals
335329
336- def get_objective_function (self , residuals = None , stretch = None ):
330+ def update_objective (self , residuals = None , stretch = None ):
331+ to_return = not (residuals is None and stretch is None )
337332 if residuals is None :
338333 residuals = self .residuals
339334 if stretch is None :
@@ -343,7 +338,11 @@ def get_objective_function(self, residuals=None, stretch=None):
343338 sparsity_term = self .eta * np .sum (np .sqrt (self .components )) # Square root penalty
344339 # Final objective function value
345340 function = residual_term + regularization_term + sparsity_term
346- return function
341+
342+ if to_return :
343+ return function # Get value directly for use
344+ else :
345+ self ._objective_history .append (function ) # Store value
347346
348347 def apply_interpolation_matrix (self , components = None , weights = None , stretch = None , return_derivatives = False ):
349348 """
@@ -595,7 +594,7 @@ def update_components(self):
595594 )
596595 self .components = mask * self .components
597596
598- objective_improvement = self ._objective_history [- 1 ] - self .get_objective_function (
597+ objective_improvement = self ._objective_history [- 1 ] - self .update_objective (
599598 residuals = self .get_residual_matrix ()
600599 )
601600
@@ -650,7 +649,7 @@ def regularize_function(self, stretch=None):
650649 stretch_difference = stretch_difference - self .source_matrix
651650
652651 # Compute objective function
653- reg_func = self .get_objective_function (stretch_difference , stretch )
652+ reg_func = self .update_objective (stretch_difference , stretch )
654653
655654 # Compute gradient
656655 tiled_derivative = np .sum (
0 commit comments