Skip to content

Commit a093cdc

Browse files
use eye
1 parent eddbb78 commit a093cdc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

unsupervised_class/gmm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def gmm(X, K, max_iter=20):
1414
# initialize M to random, initialize C to spherical with variance 1
1515
for k in xrange(K):
1616
M[k] = X[np.random.choice(N)]
17-
C[k] = np.diag(np.ones(D))
17+
C[k] = np.eye(D)
1818

1919
costs = np.zeros(max_iter)
2020
weighted_pdfs = np.zeros((N, K)) # we'll use these to store the PDF value of sample n and Gaussian k
@@ -33,7 +33,7 @@ def gmm(X, K, max_iter=20):
3333
Nk = R[:,k].sum()
3434
pi[k] = Nk / N
3535
M[k] = R[:,k].dot(X) / Nk
36-
C[k] = np.sum(R[n,k]*np.outer(X[n] - M[k], X[n] - M[k]) for n in xrange(N)) / Nk + np.diag(np.ones(D)*0.001)
36+
C[k] = np.sum(R[n,k]*np.outer(X[n] - M[k], X[n] - M[k]) for n in xrange(N)) / Nk + np.eye(D)*0.001
3737

3838

3939
costs[i] = np.log(weighted_pdfs.sum(axis=1)).sum()

0 commit comments

Comments
 (0)