Skip to content

Commit b0e44f8

Browse files
author
John Halloran
committed
fix: use symmetric initial phase fractions
1 parent 27ea989 commit b0e44f8

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(
134134
# Initialize weights and determine number of components
135135
if init_weights is None:
136136
self.n_components = n_components
137-
self.weights = self._rng.beta(a=2.5, b=1.5, size=(self.n_components, self.n_signals))
137+
self.weights = self._rng.beta(a=2.0, b=2.0, size=(self.n_components, self.n_signals))
138138
else:
139139
self.n_components = init_weights.shape[0]
140140
self.weights = init_weights
@@ -169,7 +169,7 @@ def __init__(
169169
self.objective_difference = None
170170
self._objective_history = [self.objective_function]
171171

172-
# Set up tracking variables for updateX()
172+
# Set up tracking variables for update_components()
173173
self._prev_components = None
174174
self.grad_components = np.zeros_like(self.components) # Gradient of X (zeros for now)
175175
self._prev_grad_components = np.zeros_like(self.components) # Previous gradient of X (zeros for now)
@@ -233,17 +233,21 @@ def __init__(
233233
def optimize_loop(self):
234234
# Update components first
235235
self._prev_grad_components = self.grad_components.copy()
236+
236237
self.update_components()
238+
237239
self.num_updates += 1
238240
self.residuals = self.get_residual_matrix()
239241
self.objective_function = self.get_objective_function()
240242
print(f"Objective function after update_components: {self.objective_function:.5e}")
241243
self._objective_history.append(self.objective_function)
244+
242245
if self.objective_difference is None:
243246
self.objective_difference = self._objective_history[-1] - self.objective_function
244247

245248
# Now we update weights
246249
self.update_weights()
250+
247251
self.num_updates += 1
248252
self.residuals = self.get_residual_matrix()
249253
self.objective_function = self.get_objective_function()
@@ -257,6 +261,7 @@ def optimize_loop(self):
257261
self.objective_function = self.get_objective_function()
258262
print(f"Objective function after update_stretch: {self.objective_function:.5e}")
259263
self._objective_history.append(self.objective_function)
264+
260265
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
261266

262267
def apply_interpolation(self, a, x, return_derivatives=False):

0 commit comments

Comments
 (0)