@@ -76,10 +76,11 @@ def stochasticEM(self, max_iter, docs, responses):
76
76
self .doc_topic_sum [di ,old_topic ] -= 1
77
77
78
78
z_bar = np .zeros ([self .K ,self .K ]) + self .doc_topic_sum [di ,:] + np .identity (self .K )
79
- z_bar /= z_bar .sum (1 )
79
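+ # every row of z_bar is doc_topic_sum[di,:] plus a one-hot row, so each row sums to doc_topic_sum[di,:].sum() + 1; normalise by that scalar rather than z_bar/z_bar.sum(1)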
80
+ z_bar /= self .doc_topic_sum [di ,:].sum () + 1
80
81
81
82
#update
82
- prob = (self .WK [word , :])/ (self .sumK [:]) * (self .doc_topic_sum [di ,:]) * np .exp (np .negative ((responses [di ] - np .dot (z_bar . T ,self .eta ))** 2 )/ 2 / self .sigma )
83
+ prob = (self .WK [word , :])/ (self .sumK [:]) * (self .doc_topic_sum [di ,:]) * np .exp (np .negative ((responses [di ] - np .dot (z_bar ,self .eta ))** 2 )/ 2 / self .sigma )
83
84
84
85
new_topic = sampling_from_dist (prob )
85
86
@@ -90,7 +91,7 @@ def stochasticEM(self, max_iter, docs, responses):
90
91
91
92
#estimate parameters
92
93
z_bar = self .doc_topic_sum / self .doc_topic_sum .sum (1 )[:,np .newaxis ] # DxK
93
- self .eta = solve (np .dot (z_bar .T ,z_bar ), np .dot (z_bar . T , responses ) )
94
+ self .eta = solve (np .dot (z_bar .T ,z_bar ), np .dot (z_bar , responses ) )
94
95
95
96
#compute mean absolute error
96
97
mae = np .abs (responses - np .dot (z_bar , self .eta )).sum ()
0 commit comments