
Commit 46476bb

Author: John Halloran (authored and committed)
style: don't store objective function in a class attribute, just use history
1 parent b0e44f8 commit 46476bb

File tree

1 file changed (+31, -32)


src/diffpy/snmf/snmf_class.py

Lines changed: 31 additions & 32 deletions
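
The commit removes the objective_function attribute and makes self._objective_history the single source of truth: __init__ seeds the history once through update_objective(), and every later read uses self._objective_history[-1]. Below is a minimal sketch of that pattern, not the real class; the toy objective (a plain Frobenius norm of the residuals) stands in for the actual residual, regularization, and sparsity terms.

import numpy as np


class HistoryOnlyDemo:
    """Toy version of the new bookkeeping: no objective_function attribute."""

    def __init__(self):
        self.residuals = np.ones((3, 3))  # stand-in for get_residual_matrix()
        self._objective_history = []
        self.update_objective()  # seed the history with the starting value
        print(f"Start, Objective function: {self._objective_history[-1]:.5e}")

    def update_objective(self):
        value = 0.5 * np.linalg.norm(self.residuals, "fro") ** 2  # simplified objective
        self._objective_history.append(value)


HistoryOnlyDemo()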
@@ -52,10 +52,6 @@ class SNMFOptimizer:
     num_updates : int
         The total number of times that any of (stretch, components, and weights) have had their values changed.
         If not terminated by other means, this value is used to stop when reaching max_iter.
-    objective_function: float
-        The value corresponding to the minimization of the difference between the source_matrix and the
-        products of (stretch, components, and weights). For full details see the sNMF paper. Smaller corresponds to
-        better agreement and is desirable.
     objective_difference : float
         The change in the objective function value since the last update. A negative value
         means that the result improved.
@@ -165,9 +161,9 @@ def __init__(
 
         # Set up residual matrix, objective function, and history
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
+        self._objective_history = []
+        self.update_objective()
         self.objective_difference = None
-        self._objective_history = [self.objective_function]
 
         # Set up tracking variables for update_components()
         self._prev_components = None
@@ -177,8 +173,8 @@ def __init__(
         regularization_term = 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
         sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
         print(
-            f"Start, Objective function: {self.objective_function:.5e}"
-            f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}"
+            f"Start, Objective function: {self._objective_history[-1]:.5e}"
+            f", Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}"
         )
 
         # Main optimization loop
@@ -191,15 +187,15 @@ def __init__(
             sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
             print(
                 f"Num_updates: {self.num_updates}, "
-                f"Obj fun: {self.objective_function:.5e}, "
-                f"Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}, "
+                f"Obj fun: {self._objective_history[-1]:.5e}, "
+                f"Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}, "
                 f"Iter: {iter}"
             )
 
             # Convergence check: decide when to terminate for small/no improvement
-            print(self.objective_difference, " < ", self.objective_function * tol)
-            if self.objective_difference < self.objective_function * tol and iter >= 20:
+            if self.objective_difference < self._objective_history[-1] * tol and iter >= 20:
                 break
+            print(self.objective_difference, " < ", self._objective_history[-1] * tol)
 
         # Normalize our results
         weights_row_max = np.max(self.weights, axis=1, keepdims=True)
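
The convergence test in the main loop now measures improvement against the most recent history entry rather than the removed attribute. A small worked example of that check with invented numbers (the history values, tol, and the iteration count are placeholders; in the class they are self._objective_history, the tol argument, and iter):

# Invented values for illustration.
objective_history = [1.1500e4, 1.1495e4]
objective_difference = objective_history[-2] - objective_history[-1]  # 5.0
tol = 1e-3
iteration = 25

# Mirrors the new check: stop once the improvement falls below a fraction of
# the current objective value, but never before 20 iterations have run.
if objective_difference < objective_history[-1] * tol and iteration >= 20:
    print("converged: improvement below tolerance")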
@@ -214,17 +210,17 @@ def __init__(
         self.grad_components = np.zeros_like(self.components)
         self._prev_grad_components = np.zeros_like(self.components)
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
         self.objective_difference = None
-        self._objective_history = [self.objective_function]
+        self._objective_history = []
+        self.update_objective()
         for norm_iter in range(100):
             self.update_components()
             self.residuals = self.get_residual_matrix()
-            self.objective_function = self.get_objective_function()
-            print(f"Objective function after normX: {self.objective_function:.5e}")
-            self._objective_history.append(self.objective_function)
+            self.update_objective()
+            print(f"Objective function after normalize_components: {self._objective_history[-1]:.5e}")
+            self._objective_history.append(self._objective_history[-1])
             self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
-            if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
+            if self.objective_difference < self._objective_history[-1] * tol and norm_iter >= 20:
                 break
         # end of normalization (and program)
         # note that objective function may not fully recover after normalization, this is okay
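
During the normalization pass the history is reset and re-seeded, so objective_difference only compares values from within that phase. A hypothetical sketch with invented numbers (in the class the reset is self._objective_history = [] followed by self.update_objective()):

objective_history = []              # fresh history for the normalization phase
objective_history.append(6.60e3)    # value appended by the re-seeding update_objective()
objective_history.append(6.55e3)    # value after one normalization update
objective_difference = objective_history[-2] - objective_history[-1]
print(f"improvement within normalization: {objective_difference:.3e}")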
@@ -238,29 +234,27 @@ def optimize_loop(self):
 
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_components: {self._objective_history[-1]:.5e}")
 
         if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
 
         # Now we update weights
         self.update_weights()
 
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_weights: {self._objective_history[-1]:.5e}")
 
         # Now we update stretch
         self.update_stretch()
+
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_stretch: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_stretch: {self._objective_history[-1]:.5e}")
 
         self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
 
@@ -333,7 +327,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
             residuals += weights[k, :] * stretched_components  # Element-wise scaling and sum
         return residuals
 
-    def get_objective_function(self, residuals=None, stretch=None):
+    def update_objective(self, residuals=None, stretch=None):
+        to_return = not (residuals is None and stretch is None)
         if residuals is None:
             residuals = self.residuals
         if stretch is None:
@@ -343,7 +338,11 @@ def get_objective_function(self, residuals=None, stretch=None):
         sparsity_term = self.eta * np.sum(np.sqrt(self.components))  # Square root penalty
         # Final objective function value
         function = residual_term + regularization_term + sparsity_term
-        return function
+
+        if to_return:
+            return function  # Get value directly for use
+        else:
+            self._objective_history.append(function)  # Store value
 
     def apply_interpolation_matrix(self, components=None, weights=None, stretch=None, return_derivatives=False):
         """
@@ -595,7 +594,7 @@ def update_components(self):
         )
         self.components = mask * self.components
 
-        objective_improvement = self._objective_history[-1] - self.get_objective_function(
+        objective_improvement = self._objective_history[-1] - self.update_objective(
             residuals=self.get_residual_matrix()
         )
 
@@ -650,7 +649,7 @@ def regularize_function(self, stretch=None):
         stretch_difference = stretch_difference - self.source_matrix
 
         # Compute objective function
-        reg_func = self.get_objective_function(stretch_difference, stretch)
+        reg_func = self.update_objective(stretch_difference, stretch)
 
         # Compute gradient
         tiled_derivative = np.sum(
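
After a run, the full trajectory of objective values stays available on _objective_history. A hypothetical post-run inspection, using a plain list in place of a finished optimizer's attribute (constructing SNMFOptimizer itself needs real input data, so it is not shown here):

# Stand-in for optimizer._objective_history after a completed run (invented values).
objective_history = [1.00e4, 7.50e3, 6.90e3, 6.85e3]

print(f"objective evaluations recorded: {len(objective_history)}")
print(f"final objective:                {objective_history[-1]:.5e}")
print(f"total improvement:              {objective_history[0] - objective_history[-1]:.5e}")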
