-
Notifications
You must be signed in to change notification settings - Fork 0
/
tb.py
423 lines (294 loc) · 11.9 KB
/
tb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
"""tb.py reads, searches and displays trees from Penn Treebank (PTB) format
treebank files.
Mark Johnson, 14th January, 2012, last modified 19th January 2016
Trees are represented in Python as nested list structures in the following
format:
Terminal nodes are represented by strings.
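Nonterminal nodes are represented by lists. The first element of
the list is the node's label (a string), and the remaining elements
of the list are the node's children: lists for nonterminal children
and strings for terminal children. For example, "(S (NP (DT the) (NN dog)))"
is represented as ['S', ['NP', ['DT', 'the'], ['NN', 'dog']]].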
This module also defines two regular expressions.
nonterm_rex matches Penn treebank nonterminal labels, and parses them into
their various parts.
empty_re matches empty elements (terminals), and parses them into their
various parts.
"""
import collections
import glob
import os
import re
import sys
# Matches an optional file header: zero or more lines that start with "*x*"
# and end with "*x*" (optionally followed by spaces/tabs), plus any whitespace
# after them.  read_file() starts parsing where this match ends.
_header_re = re.compile(r"(\*x\*.*\*x\*[ \t]*\n)*\s*")
# "(" with surrounding whitespace; group 1 is the node label that follows it
# (possibly empty, since the label character class uses '*').
_openpar_re = re.compile(r"\s*\(\s*([^ \t\n\r\f\v()]*)\s*")
# ")" with surrounding whitespace.
_closepar_re = re.compile(r"\s*\)\s*")
# A terminal: group 1 is a (possibly empty) run of characters that are
# neither whitespace nor parentheses; surrounding whitespace is consumed.
_terminal_re = re.compile(r"\s*([^ \t\n\r\f\v()]*)\s*")
# nonterm_rex matches a Penn treebank nonterminal label and splits it into
# named parts.  It is complicated enough that it uses the re.VERBOSE form,
# in which whitespace is ignored and '#' starts a comment, so the pattern
# can be laid out and documented inline.
#
nonterm_rex = re.compile(r"""
^(?P<CAT>[A-Z0-9$|^]+) # category comes first
(?: # huge disjunct of optional annotations
- (?:(?P<FORMFUN>ADV|NOM) # stuff beginning with -
|(?P<GROLE>DTV|LGS|PRD|PUT|SBJ|TPC|VOC)
|(?P<ADV>BNF|DIR|EXT|LOC|MNR|PRP|TMP)
|(?P<MISC>CLR|CLF|HLN|SEZ|TTL)
|(?P<TPC>TPC)
|(?P<DYS>UNF|ETC|IMP)
|(?P<INDEX>[0-9]+)
)
| = (?P<EQINDEX>[0-9]+) # stuff beginning with =
)* # Kleene star
$""", re.VERBOSE)
# Notes on nonterm_rex:
# - The annotation groups sit inside a Kleene star, so each named group keeps
#   only its *last* match (e.g. for "NP-SBJ-1-2" INDEX captures "2").
# - The TPC group can never take part in a match: "TPC" is already matched
#   by the earlier GROLE alternative.
# - Labels that do not start with a CAT character, such as "-NONE-" or
#   "-LRB-", do not match at all.
#
# empty_re matches an empty-element terminal such as "*T*-1": CAT is a run of
# capital letters, digits, '?' and '*'; INDEX is the number after '-'.
# NOTE(review): the "-INDEX" group has no '?' quantifier, so the index is
# mandatory (unindexed elements such as "*" or "0" do not match), and there
# is no '$' anchor.  Confirm whether the index was meant to be optional.
empty_re = re.compile(r"^(?P<CAT>[A-Z0-9\?\*]+)(?:-(?P<INDEX>\d+))")
def read_file(filename):
    """Returns the list of trees in the PTB file filename.

    Header lines matched by _header_re at the start of the file are skipped;
    the rest of the file is parsed by _string_trees.
    """
    # Python 3 text mode already translates newlines (the old "U" mode flag
    # was removed in Python 3.11), and the with-block closes the file.
    with open(filename) as inf:
        filecontents = inf.read()
    # _header_re consists only of optional parts, so it always matches.
    pos = _header_re.match(filecontents).end()
    trees = []
    _string_trees(trees, filecontents, pos)
    return trees
def string_trees(s):
    """Parse every PTB-format tree in the string s and return them as a list.

    Unlike read_file, no header is skipped: parsing starts at s[0].
    """
    parsed = []
    _string_trees(parsed, s)
    return parsed
def _string_trees(trees, s, pos=0):
    """Parses a sequence of trees and terminals from s[pos:].

    Appends each parsed item to the list trees (nested lists for bracketed
    nodes, strings for terminals).  Parsing stops at the first unmatched ")"
    or at the end of s.  Returns the position just after the consumed text,
    including any whitespace after that ")".
    """
    while pos < len(s):
        closepar_mo = _closepar_re.match(s, pos)
        if closepar_mo:
            # ")" closes the node whose children are being collected.
            return closepar_mo.end()
        openpar_mo = _openpar_re.match(s, pos)
        if openpar_mo:
            tree = [openpar_mo.group(1)]
            trees.append(tree)
            # Recursively collect this node's children up to its ")".
            pos = _string_trees(tree, s, openpar_mo.end())
        else:
            terminal_mo = _terminal_re.match(s, pos)
            word = terminal_mo.group(1)
            # _terminal_re also matches pure trailing whitespace (e.g. in a
            # whitespace-only string); that is not a terminal, so skip it.
            if word:
                trees.append(word)
            pos = terminal_mo.end()
    return pos
def make_nonterminal(label, children):
    """Build a new nonterminal node labelled label whose children are the
    elements of the list children (the list itself is not modified)."""
    node = [label]
    return node + children
def make_terminal(word):
    """Build a terminal node for word; terminals are just the word itself."""
    terminal = word
    return terminal
def make_preterminal(label, word):
    """Build a preterminal node: a two-element list [label, word]."""
    node = [label]
    node.append(word)
    return node
def is_terminal(subtree):
    """Return True when subtree is a single terminal node (a word or an empty
    element), i.e. anything that is not a list."""
    if isinstance(subtree, list):
        return False
    return True
def is_nonterminal(subtree):
    """Return True when subtree is a nonterminal node (a list), i.e. not a
    single terminal such as a word or an empty element."""
    answer = isinstance(subtree, list)
    return answer
def is_preterminal(subtree):
    """Return True when subtree is a preterminal node: a two-element list
    [label, terminal] (a POS tag over a word, or an empty-element node)."""
    if not isinstance(subtree, list) or len(subtree) != 2:
        return False
    # The single child must be a terminal, i.e. not itself a list.
    return not isinstance(subtree[1], list)
def is_phrasal(subtree):
    """Return True when subtree is a nonterminal that is not a preterminal:
    a list with no children, or whose first child is itself a list."""
    if not isinstance(subtree, list):
        return False
    return len(subtree) == 1 or isinstance(subtree[1], list)
def is_punctuation(subtree):
    """Return True when subtree is a preterminal whose category is one of the
    PTB punctuation tags or -NONE- (empty elements)."""
    if not is_preterminal(subtree):
        return False
    return tree_category(subtree) in ("''", ":", "#", ",", ".", "``", "-LRB-", "-RRB-", "-NONE-")
def tree_children(tree):
    """Return a new list of tree's child subtrees ([] for a terminal)."""
    if not isinstance(tree, list):
        return []
    return tree[1:]
def tree_label(tree):
    """Return the label of tree's root: the first list element for a
    nonterminal, or the terminal string itself."""
    return tree[0] if isinstance(tree, list) else tree
def label_category(label):
    """Return the bare category of label (the CAT group of nonterm_rex), or
    label unchanged when nonterm_rex does not match it."""
    mo = nonterm_rex.match(label)
    if mo is None:
        return label
    return mo.group('CAT')
def tree_category(tree):
    """Return the category of tree's root: label_category() of a nonterminal's
    label, or the terminal string itself."""
    if not isinstance(tree, list):
        return tree
    return label_category(tree[0])
def map_labels(tree, fn):
    """Return a new tree in which every nonterminal label is replaced by
    fn(label); terminals are returned unchanged.  fn is applied to a node's
    label before any of its descendants' labels."""
    if not isinstance(tree, list):
        return tree
    mapped = [fn(tree[0])]
    mapped.extend(map_labels(sub, fn) for sub in tree[1:])
    return mapped
def label_noindices(label):
    """Removes indices in label if present.

    The label is cut just before the earliest "-N" (INDEX) or "=N" (EQINDEX)
    annotation recognised by nonterm_rex, e.g. "NP-SBJ-1" -> "NP-SBJ",
    "NP=2" -> "NP", "NP-SBJ-12" -> "NP-SBJ", "NP-SBJ-1=2" -> "NP-SBJ".
    Labels without an index, or that nonterm_rex does not match (such as
    "-NONE-"), are returned unchanged.

    Note: nonterm_rex keeps only the last match of each repeated group, so for
    a label with several "-N" parts only the last one is located.
    """
    label_mo = nonterm_rex.match(label)
    if label_mo:
        # Match.start() is -1 for a group that did not take part in the match.
        starts = [label_mo.start(group) for group in ('INDEX', 'EQINDEX')
                  if label_mo.start(group) >= 0]
        if starts:
            # Also drop the '-' or '=' that introduces the index.
            return label[:min(starts) - 1]
    return label
def tree_subtrees(tree):
    """Returns a list of the subtrees (children) of tree; [] for a terminal.

    Same behaviour as tree_children.  This used to be a second, overriding
    definition of tree_children, while tree_constituents called
    tree_subtrees, which was never defined; this definition supplies that
    name.
    """
    if isinstance(tree, list):
        return tree[1:]
    else:
        return []
def tree_copy(tree):
    """Return a deep copy of tree: every list is rebuilt, while the (immutable)
    label and terminal strings are shared."""
    return list(map(tree_copy, tree)) if isinstance(tree, list) else tree
def prune(tree, remove_empty=False, collapse_unary=False, binarise=False,
          relabel=lambda x: x,
          binlabelf=lambda labels: '+'.join(labels)):
    """Returns a pruned and optionally transformed copy of tree.

    remove_empty: if true, '-NONE-' preterminals are dropped, as is any
        phrasal node left with no children.
    collapse_unary: if true, a phrasal node left with exactly one child that
        is a nonterminal takes over that child's children, keeping its own
        (relabelled) label.
    binarise: if binarise=='right' then nodes with more than two children are
        right-binarised; otherwise, if binarise is not false, they are
        left-binarised.
    relabel: applied to the label of every phrasal and preterminal node.  The
        default is the identity, so node indices are only removed if relabel
        removes them (e.g. relabel=label_noindices or label_category).
    binlabelf: maps a sequence of child node labels to the label of a new
        binarised node.

    Terminals are returned unchanged.  Returns None when the whole tree is
    removed.
    """
    def left_binarise(cs, rightpos):
        # Node covering cs[:rightpos], nested to the left.
        label = binlabelf(tree_label(cs[i]) for i in range(rightpos))
        if rightpos <= 2:
            return make_nonterminal(label, cs[:rightpos])
        else:
            return make_nonterminal(label, [left_binarise(cs, rightpos-1), cs[rightpos-1]])

    def right_binarise(cs, leftpos, len_cs):
        # Node covering cs[leftpos:], nested to the right.
        label = binlabelf(tree_label(c) for c in cs[leftpos:])
        if leftpos + 2 >= len_cs:
            return make_nonterminal(label, cs[leftpos:])
        else:
            return make_nonterminal(label, [cs[leftpos], right_binarise(cs, leftpos+1, len_cs)])

    label = tree_label(tree)
    if is_phrasal(tree):
        cs = (prune(c, remove_empty, collapse_unary, binarise, relabel, binlabelf)
              for c in tree_children(tree))
        # Removed children come back as None; drop them (and any other falsy
        # result).
        cs = [c for c in cs if c]
        if cs or not remove_empty:
            len_cs = len(cs)
            # Only a nonterminal child has children to inherit; collapsing onto
            # a terminal child would silently discard that terminal.
            if collapse_unary and len_cs == 1 and is_nonterminal(cs[0]):
                return make_nonterminal(relabel(label), tree_children(cs[0]))
            elif binarise and len_cs > 2:
                if binarise == 'right':
                    return make_nonterminal(relabel(label),
                                            [cs[0], right_binarise(cs, 1, len_cs)])
                else:
                    return make_nonterminal(relabel(label),
                                            [left_binarise(cs, len_cs-1), cs[-1]])
            else:
                return make_nonterminal(relabel(label), cs)
        else:
            return None
    elif is_preterminal(tree):
        if remove_empty and label == '-NONE-':
            return None
        else:
            return make_nonterminal(relabel(label), tree_children(tree))
    else:
        return tree
def tree_nodes(tree):
    """Yield every node of tree (nonterminals and terminals, not labels) in
    pre-order: a node comes before its children, children left to right."""
    pending = [tree]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, list):
            # Push children right-to-left so the leftmost one is popped first.
            pending.extend(reversed(current[1:]))
def tree_terminals(tree):
    """Yield the terminal leaves of tree from left to right."""
    pending = [tree]
    while pending:
        current = pending.pop()
        if not isinstance(current, list):
            yield current
            continue
        # Push children right-to-left so the leftmost one is popped first.
        pending.extend(reversed(current[1:]))
def tree_preterminals(tree):
    """Yields the preterminal nodes of tree, from left to right."""
    def visit(node):
        if is_preterminal(node):
            yield node
        elif isinstance(node, list):
            # Only nonterminals have children.  Without this guard, a terminal
            # string would be sliced and its characters visited one by one.
            for child in node[1:]:
                yield from visit(child)
    yield from visit(tree)
def tree_phrasalnodes(tree):
    """Yield the phrasal nodes of tree (nonterminals that are not
    preterminals) in pre-order, descending only through phrasal nodes."""
    def walk(subtree):
        if not is_phrasal(subtree):
            return
        yield subtree
        for kid in subtree[1:]:
            yield from walk(kid)
    yield from walk(tree)
def tree_constituents(tree, collect_root=False, collect_terminals=False,
                      collect_preterminals=False, ignore_punctuation=False):
    """Maps a tree to a list of (category, left, right) tuples, one per
    constituent of the tree.

    left and right are terminal positions counted from 0 at the leftmost
    terminal; right is exclusive.  Tuples are listed in post-order: a node
    comes after all of its descendants.

    If collect_root==True, then the list includes a tuple for the root node
    of the tree (otherwise only the root's descendants are visited).
    If collect_terminals==True, then the list includes tuples for the
    terminal nodes of the tree (their category is the terminal itself).
    If collect_preterminals==True, then the list includes tuples for the
    preterminal nodes of the tree.
    If ignore_punctuation==True, then nodes for which is_punctuation() is true
    (punctuation and -NONE- preterminals) occupy no positions and produce no
    tuples.
    """
    def visitor(node, left, constituents):
        # Visits node starting at position left; returns the position just
        # after its last (counted) terminal.
        if ignore_punctuation and is_punctuation(node):
            return left
        if is_terminal(node):
            if collect_terminals:
                constituents.append((tree_category(node), left, left+1))
            return left+1
        else:
            right = left
            # tree_children: the previously called tree_subtrees is undefined.
            for child in tree_children(node):
                right = visitor(child, right, constituents)
            if collect_preterminals or is_phrasal(node):
                constituents.append((tree_category(node), left, right))
            return right

    constituents = []
    if collect_root:
        visitor(tree, 0, constituents)
    else:
        right = 0
        for child in tree_children(tree):
            right = visitor(child, right, constituents)
    return constituents
def write(tree, outf=None):
    """Write tree to outf in PTB bracketed format.

    A nonterminal is written as "(label child1 child2 ...)", a terminal as the
    string itself; no trailing newline is written.

    outf: any object with a write(str) method.  If omitted or None, the
    *current* sys.stdout is used.  It is looked up at call time rather than at
    import time, so later redirections of sys.stdout are honoured.
    """
    if outf is None:
        outf = sys.stdout
    if isinstance(tree, list):
        outf.write('(')
        for i, subtree in enumerate(tree):
            if i > 0:
                outf.write(' ')
            write(subtree, outf)
        outf.write(')')
    else:
        outf.write(tree)
def read_ptb(basedir="/usr/local/data/LDC/LDC2015T13_eng_news_txt_tbnk-ptb_revised/",
             remove_empty=True, collapse_unary=False, binarise=False, relabel=label_category):
    """Returns a namedtuple (train, dev, test) of the trees in the 2015 PTB.

    train, dev and test are generators that lazily enumerate the pruned trees
    of WSJ sections 02-21, 24 and 23 respectively (files in sorted order).
    Each can be iterated only once.  basedir is the root of the treebank
    distribution, with or without a trailing path separator.
    remove_empty, collapse_unary, binarise and relabel are passed to prune().
    """
    def _read_ptb(patterns):
        for pattern in patterns:
            # os.path.join also works when basedir lacks a trailing '/'
            # (plain concatenation would glob a nonexistent path and
            # silently yield nothing).
            for fname in sorted(glob.glob(os.path.join(basedir, pattern))):
                for tree in read_file(fname):
                    # Assumes each tree in the file is wrapped in an unlabelled
                    # root, "( (S ...) )", whose child tree[1] is the real tree.
                    # TODO confirm this holds for every file.
                    yield prune(tree[1], remove_empty, collapse_unary, binarise, relabel)

    ptb = collections.namedtuple('ptb', 'train dev test')
    # Standard WSJ training split is sections 02-21, so all of 10-19 are
    # included ("1[0-9]", not "1[2-9]").
    return ptb(train=_read_ptb(("data/penntree/0[2-9]/wsj*.tree",
                                "data/penntree/1[0-9]/wsj*.tree",
                                "data/penntree/2[01]/wsj*.tree")),
               dev=_read_ptb(("data/penntree/24/wsj*.tree",)),
               test=_read_ptb(("data/penntree/23/wsj*.tree",)))