Skip to content

Commit

Permalink
TODO: Pickle closures.
Browse files Browse the repository at this point in the history
  • Loading branch information
shawntan committed Jun 27, 2013
1 parent 5eec5f1 commit 6fb39cf
Show file tree
Hide file tree
Showing 4 changed files with 859 additions and 53 deletions.
5 changes: 2 additions & 3 deletions crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def predict(self,x_vec, debug=False):
"""
all_features = self.all_features(x_vec)
log_potential = np.dot(all_features,self.theta)
return [ self.labels[i] for i in self.log_predict(log_potential,len(x_vec),len(self.labels)) ]
return [ self.labels[i] for i in self.slow_predict(log_potential,len(x_vec),len(self.labels)) ]

def slow_predict(self,log_potential,N,K,debug=False):
"""
Expand Down Expand Up @@ -214,8 +214,7 @@ def train(self,x_vecs,y_vecs,debug=False):
vectorised_x_vecs,vectorised_y_vecs = self.create_vector_list(x_vecs,y_vecs)
l = lambda theta: self.neg_likelihood_and_deriv(vectorised_x_vecs,vectorised_y_vecs,theta)
val = optimize.fmin_l_bfgs_b(l,self.theta)
if debug:
print val
if debug: print val
self.theta,_,_ = val
return self.theta

Expand Down
58 changes: 34 additions & 24 deletions features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from collections import defaultdict

alphas = re.compile('^[a-zA-Z]+$')
def listify(gen):
"Convert a generator into a function which returns a list"
def patched(*args, **kwargs):
Expand All @@ -17,36 +19,44 @@ def fit_dataset(filename):
for line in open(filename,'r'):
sent_words = []
sent_labels = []
for token in line.strip().split():
word,label= token.split('/')
word = word.lower()
labels.add(label)
obsrvs.add(word)
word_sets[label].add(word)
sent_words.append(word)
sent_labels.append(label)
sents_words.append(sent_words)
sents_labels.append(sent_labels)
try:
for token in line.strip().split():
word,label= token.rsplit('/',2)
if alphas.match(word):
orig_word = word
word = word.lower()
labels.add(label)
obsrvs.add(word)
word_sets[label].add(word)
sent_words.append(orig_word)
sent_labels.append(label)
else:
continue
sents_words.append(sent_words)
sents_labels.append(sent_labels)
except Exception:
print line
return (labels,obsrvs,word_sets,sents_words,sents_labels)

@listify
def set_membership(word_sets):
for tag in word_sets:
def fun(yp,y,x_v,i):

if i < len(x_v):
if x_v[i].lower() in word_sets[tag]:
def set_membership(labels,*word_sets):
for lbl in labels:
for ws in word_sets:
def fun(yp,y,x_v,i,lbl=lbl,s=ws):
if i < len(x_v) and y==lbl and (x_v[i].lower() in s):
#print lbl, ws,x_v[i]
return 1
else: return 0
yield fun
else: return 0
yield fun

@listify
def match_regex(*regexps):
def match_regex(labels,*regexps):
for regexp in regexps:
p = re.compile(regexp)
def fun(yp,y,x_v,i):
if i < len(x_v) and p.match(x_v[i]): return 1
else: return 0
yield fun
for lbl in labels:
def fun(yp,y,x_v,i,lbl=lbl,p=p):
if i < len(x_v) and y==lbl and p.match(x_v[i]):
return 1
else: return 0
yield fun

regex_functions= ['^[^0-9a-zA-Z]+$','^[A-Z\.]+$','^[0-9\.]+$']
Loading

0 comments on commit 6fb39cf

Please sign in to comment.