From c22249be99331e69a705b4a2bf261f91435d991b Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Wed, 19 Jun 2013 15:28:27 +0800 Subject: [PATCH] Modified crf --- crf.py | 33 ++++++++++++++++++++++++++++++--- test.py | 13 +++++++++---- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/crf.py b/crf.py index eaab611..20261ad 100644 --- a/crf.py +++ b/crf.py @@ -138,11 +138,38 @@ def predict(self,x_vec, debug=False): """ all_features = self.all_features(x_vec) log_potential = np.dot(all_features,self.theta) - N = len(x_vec) - return self.log_predict(log_potential,N,K) + return [ self.labels[i] for i in self._predict(log_potential,len(x_vec),len(self.labels)) ] + + def _predict(self,log_potential,N,K,debug=False): + """ + Find the most likely assignment to labels given parameters using the + Viterbi algorithm. + """ + g0 = log_potential[0,0] + g = log_potential[1:] + + B = np.ones((N,K), dtype=np.int32) * -1 + # compute max-marginals and backtrace matrix + V = g0 + for t in xrange(1,N): + U = np.empty(K) + for y in xrange(K): + w = V + g[t-1,:,y] + B[t,y] = b = w.argmax() + U[y] = w[b] + V = U + # extract the best path by back-tracking + y = V.argmax() + trace = [] + for t in reversed(xrange(N)): + trace.append(y) + y = B[t, y] + trace.reverse() + return trace + + - def log_predict(self,log_potential,N,K): + def log_predict(self,log_potential,N,K,debug=False): if debug: print print diff --git a/test.py b/test.py index 1b393f6..4b30d41 100644 --- a/test.py +++ b/test.py @@ -108,7 +108,7 @@ class TestCRF(unittest.TestCase): def setUp(self): self.matrix = 0.001 + np.random.poisson(lam=1.5, size=(3,3)).astype(np.float) self.vector = 0.001 + np.random.poisson(lam=1.5, size=(3,)).astype(np.float) - self.M = 0.001 + np.random.poisson(lam=1.5, size=(3,3,3)).astype(np.float) + self.M = 0.001 + np.random.poisson(lam=1.5, size=(10,3,3)).astype(np.float) self.crf = crf.CRF([],[]) def test_log_dot_mv(self): @@ -138,9 +138,14 @@ def test_forward(self): self.assertTrue((res 
== res_true).all()) def test_predict(self): - print self.log_predict(self.M,self.M.shape[0],self.M.shape[1]) - print argmax(self.M) + label_pred = self.crf.log_predict(self.M,self.M.shape[0],self.M.shape[1]) + label_act = argmax(self.M) + print label_pred + print label_act + self.assertTrue(label_pred == label_act) + if __name__ == '__main__': - unittest.main() + for i in range(10): + unittest.main()