Skip to content

Commit 39bd273

Browse files
committed
correcting some errors in sampling
1 parent 5c03142 commit 39bd273

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ptm/slda_gibbs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,11 @@ def stochasticEM(self, max_iter, docs, responses):
7676
self.doc_topic_sum[di,old_topic] -= 1
7777

7878
z_bar = np.zeros([self.K,self.K]) + self.doc_topic_sum[di,:] + np.identity(self.K)
79-
z_bar /= z_bar.sum(1)
79+
# this seems more straightforward than z_bar/z_bar.sum(1)
80+
z_bar /= self.doc_topic_sum[di,:].sum() + 1
8081

8182
#update
82-
prob = (self.WK[word, :])/(self.sumK[:]) * (self.doc_topic_sum[di,:]) * np.exp(np.negative((responses[di] - np.dot(z_bar.T,self.eta))**2)/2/self.sigma)
83+
prob = (self.WK[word, :])/(self.sumK[:]) * (self.doc_topic_sum[di,:]) * np.exp(np.negative((responses[di] - np.dot(z_bar,self.eta))**2)/2/self.sigma)
8384

8485
new_topic = sampling_from_dist(prob)
8586

@@ -90,7 +91,7 @@ def stochasticEM(self, max_iter, docs, responses):
9091

9192
#estimate parameters
9293
z_bar = self.doc_topic_sum / self.doc_topic_sum.sum(1)[:,np.newaxis] # DxK
93-
self.eta = solve(np.dot(z_bar.T,z_bar), np.dot(z_bar.T, responses) )
94+
self.eta = solve(np.dot(z_bar.T,z_bar), np.dot(z_bar, responses) )
9495

9596
#compute mean absolute error
9697
mae = np.abs(responses - np.dot(z_bar, self.eta)).sum()

0 commit comments

Comments
 (0)