Skip to content

Commit

Permalink
Modified crf
Browse files Browse the repository at this point in the history
  • Loading branch information
shawntan committed Jun 19, 2013
1 parent 95e52bf commit c22249b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
33 changes: 30 additions & 3 deletions crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,38 @@ def predict(self,x_vec, debug=False):
"""
all_features = self.all_features(x_vec)
log_potential = np.dot(all_features,self.theta)
N = len(x_vec)
return self.log_predict(log_potential,N,K)
return [ self.labels[i] for i in self._predict(log_potential,len(x_vec),len(self.labels)) ]

def _predict(self,log_potential,N,K,debug=False):
    """
    Find the most likely assignment to labels given parameters using the
    Viterbi algorithm.

    Parameters
    ----------
    log_potential : array-like
        Log-scores indexed as ``[t, i, j]``.  ``log_potential[0, 0]`` is read
        as the length-K vector of scores for the first position, and
        ``log_potential[t, i, j]`` (for ``1 <= t < N``) is added to the best
        score ending in label ``i`` at position ``t - 1`` to score label ``j``
        at position ``t``.  Presumably of shape (N, K, K) -- confirm against
        the caller.
    N : int
        Length of the sequence to decode.
    K : int
        Number of labels.
    debug : bool
        Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    list of int
        The N label indices (each in ``range(K)``) of the highest-scoring
        path, or an empty list when ``N <= 0``.
    """
    # An empty sequence has no labels; without this guard the indexing
    # below would raise IndexError.
    if N <= 0:
        return []

    log_potential = np.asarray(log_potential)
    scores = log_potential[0, 0]
    transitions = log_potential[1:]

    # backptr[t, j] is the best predecessor label of label j at position t.
    # Row 0 is never followed and keeps the -1 sentinel.
    backptr = np.full((N, K), -1, dtype=np.int32)

    # compute max-marginals and backtrace matrix
    for t in range(1, N):
        # candidates[i, j] = best score ending in i at t-1, then moving to j.
        candidates = scores[:, np.newaxis] + transitions[t - 1]
        # argmax returns the first maximum, matching a per-column scan.
        backptr[t] = candidates.argmax(axis=0)
        scores = candidates.max(axis=0)

    # extract the best path by back-tracking
    state = int(scores.argmax())
    path = [state]
    for t in range(N - 1, 0, -1):
        state = int(backptr[t, state])
        path.append(state)
    path.reverse()
    return path



def log_predict(self,log_potential,N,K):
def log_predict(self,log_potential,N,K,debug=False):
if debug:
print
print
Expand Down
13 changes: 9 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class TestCRF(unittest.TestCase):
def setUp(self):
# Build fresh random fixtures for each test: Poisson(lam=1.5) draws shifted
# by 0.001, so every entry is strictly positive.
# NOTE(review): np.float was only an alias for the builtin float; it was
# deprecated in NumPy 1.20 and removed in 1.24, so these astype calls fail
# on current NumPy -- confirm the NumPy version this targets.
self.matrix = 0.001 + np.random.poisson(lam=1.5, size=(3,3)).astype(np.float)
self.vector = 0.001 + np.random.poisson(lam=1.5, size=(3,)).astype(np.float)
# NOTE(review): self.M is assigned twice in a row; only the second
# (10,3,3) value survives. The first line looks like the pre-change side
# of the diff -- confirm.
self.M = 0.001 + np.random.poisson(lam=1.5, size=(3,3,3)).astype(np.float)
self.M = 0.001 + np.random.poisson(lam=1.5, size=(10,3,3)).astype(np.float)
# CRF under test, constructed with two empty lists; the meaning of the two
# arguments is not visible here.
self.crf = crf.CRF([],[])

def test_log_dot_mv(self):
Expand Down Expand Up @@ -138,9 +138,14 @@ def test_forward(self):
self.assertTrue((res == res_true).all())

def test_predict(self):
# Compare the CRF's log_predict output on self.M against a reference
# argmax over the same array.
# NOTE(review): the next two lines call log_predict on the TestCase itself
# (self.log_predict) rather than on self.crf; they look like the
# pre-change side of the diff -- confirm before relying on them.
print self.log_predict(self.M,self.M.shape[0],self.M.shape[1])
print argmax(self.M)
# N is taken from the first axis of self.M and K from the second.
label_pred = self.crf.log_predict(self.M,self.M.shape[0],self.M.shape[1])
# `argmax` is a bare name not defined in the visible code; presumably a
# reference helper in this test module -- verify.
label_act = argmax(self.M)
print label_pred
print label_act
# NOTE(review): if both values are numpy arrays, `==` is elementwise and
# assertTrue on a multi-element array raises ValueError -- confirm the
# return types.
self.assertTrue(label_pred == label_act)


if __name__ == '__main__':
# Script entry point: run the test suite.
# NOTE(review): unittest.main() calls sys.exit() by default once the run
# finishes, so nothing after the first call executes; the loop below only
# repeats the suite if exit=False is passed. The bare call on the next line
# looks like the pre-change side of the diff -- confirm.
unittest.main()
for i in range(10):
unittest.main()

0 comments on commit c22249b

Please sign in to comment.