Skip to content

Commit dc2cd50

Browse files
Merge pull request #156 from hyperion-ml/persephone-refactor
debug gmm
2 parents 05843b7 + 44f7abb commit dc2cd50

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

hyperion/np/pdfs/mixtures/exp_family_mixture.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def fit(
106106
elbo_val[epoch] = self.elbo(None, N=N, u_x=u_x, log_h=log_h_val)
107107

108108
print(
109-
elbo[epoch],
109+
elbo[epoch] / x.shape[0],
110110
np.mean(self.log_prob(x, mode="nat")),
111111
np.mean(self.log_prob(x, mode="std")),
112112
)
@@ -210,7 +210,6 @@ def _accum_suff_stats_1batch(self, x, u_x=None, sample_weight=None):
210210

211211
N = np.sum(z, axis=0)
212212
acc_u_x = np.dot(z.T, u_x)
213-
# L_z=gmm.ElnP_z_w(N,gmm.lnw)-gmm.Elnq_z(z);
214213
return N, acc_u_x
215214

216215
def _accum_suff_stats_nbatches(self, x, sample_weight, batch_size):
@@ -473,8 +472,8 @@ def sum_suff_stats(self, N, u_x):
473472
Accumalted N and u_x.
474473
"""
475474
assert len(N) == len(u_x)
476-
acc_N = N[1]
477-
acc_u_x = u_x[1]
475+
acc_N = N[0]
476+
acc_u_x = u_x[0]
478477
for i in range(1, len(N)):
479478
acc_N += N[i]
480479
acc_u_x += u_x[i]

0 commit comments

Comments
 (0)