-
Notifications
You must be signed in to change notification settings - Fork 141
/
test_transform.py
68 lines (56 loc) · 2.59 KB
/
test_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
import itertools
import nltk
from supar.models.const.crf.transform import Tree
from supar.models.dep.biaffine.transform import CoNLL
class TestCoNLL:
    """Tests for the head-sequence predicates ``CoNLL.isprojective`` and ``CoNLL.istree``."""

    def istree_naive(self, sequence, proj=False, multiroot=True):
        """Brute-force reference used to cross-check ``CoNLL.istree``.

        Positions are numbered from 1 and a head value of 0 marks a root.

        Args:
            sequence: list with one head value per position.
            proj: if true, additionally require ``CoNLL.isprojective(sequence)``.
            multiroot: if false, reject sequences with more than one root.

        Returns:
            ``True`` only if there is at least one root (exactly one unless
            ``multiroot``) and every position is reached by walking from the
            roots along head -> dependent links; ``False`` otherwise.
        """
        if proj and not CoNLL.isprojective(sequence):
            return False

        roots = [pos for pos, head in enumerate(sequence, 1) if head == 0]
        if not roots:
            return False
        if not multiroot and len(roots) > 1:
            return False

        # Pad index 0 so that heads[pos] is the head of 1-based position pos.
        heads = [-1] + sequence
        seen = [False] * len(heads)

        def visit(node):
            # Mark node and everything below it; report False only when the
            # node passed in had already been marked.
            if seen[node]:
                return False
            seen[node] = True
            for child in range(1, len(heads)):
                if heads[child] == node:
                    # The result of the nested walk is intentionally ignored.
                    visit(child)
            return True

        for root in roots:
            if not visit(root):
                return False

        # Any position never reached from a root means this is not a tree.
        return all(seen[1:])

    def test_isprojective(self):
        """Check ``CoNLL.isprojective`` on fixed accepted and rejected sequences."""
        accepted = (
            [2, 4, 2, 0, 5],
            [3, -1, 0, -1, 3],
        )
        rejected = (
            [2, 4, 0, 3, 4],
            [4, -1, 0, -1, 4],
            [2, -1, -1, 1, 0],
            [0, 5, -1, -1, 4],
        )
        for heads in accepted:
            assert CoNLL.isprojective(heads)
        for heads in rejected:
            assert not CoNLL.isprojective(heads)

    def test_istree(self):
        """Compare ``CoNLL.istree`` with ``istree_naive`` for every flag combination.

        The inputs are the first five entries of each permutation of
        ``range(6)``.
        """
        for perm in itertools.permutations(range(6)):
            sequence = list(perm[:5])
            # Yields (False, False), (False, True), (True, False), (True, True).
            for proj, multiroot in itertools.product((False, True), repeat=2):
                actual = CoNLL.istree(sequence, proj, multiroot)
                expected = self.istree_naive(sequence, proj, multiroot)
                assert actual == expected, f"{sequence}"
class TestTree:
    """Round-trip test for the ``Tree`` transform helpers."""

    def test_tree(self):
        """Binarize, factorize and rebuild a parsed tree, expecting the original back."""
        gold = nltk.Tree.fromstring("""
(TOP
(S
(NP (DT This) (NN time))
(, ,)
(NP (DT the) (NNS firms))
(VP (VBD were) (ADJP (JJ ready)))
(. .)))
""")
        # Only the first element of the binarize() result is used.
        binarized = Tree.binarize(gold)[0]
        factors = Tree.factorize(binarized)
        rebuilt = Tree.build(gold, factors)
        assert gold == rebuilt