miamiCorpusPOS.py
import os
import eval
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# model_name = '/scratch/gpfs/ca2992/robertuito-base-cased'
model_name = '/scratch/gpfs/ca2992/codeswitch-spaeng-pos-lince'
tokenizer_name = '/scratch/gpfs/ca2992/codeswitch-spaeng-lid-lince'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# despite the name, out_dir is a single output file, opened in append mode below
out_dir = '/scratch/gpfs/ca2992/jpLLM/jpLLM/pos_dict_out'
data_dir = '/scratch/gpfs/ca2992/jpLLM/bangor/crowdsourced_bangor'

# token-classification pipeline used for POS tagging
pos_model = pipeline('ner', model=model, tokenizer=tokenizer)

# word-level gold POS tags, predicted POS tags, and gold language IDs
pos_truth = []
pos_pred = []
lid_truth = []
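# Note on the pipeline output this script relies on: each call
# pos_model(text) returns a list of dicts, one per sub-token, whose
# fields include 'word' (the sub-token text) and 'entity' (the
# predicted tag). The merging logic below assumes exactly that shape.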
# given a token containing the '#' symbol,
# remove the symbol for preprocessing
def cleanPoundSign(word):
    tempTok = ""
    for i in range(len(word)):
        if (word[i] != '#'):
            tempTok = tempTok + word[i]
    return tempTok
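# Quick sanity check with hypothetical inputs: WordPiece tokenizers mark
# continuation sub-tokens with '##', which this strips.
#   cleanPoundSign('##ming') -> 'ming'
#   cleanPoundSign('hello')  -> 'hello'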
# words in the annotated Bangor Corpus contain "'" if they are part of
# a contraction; check so the pieces can be concatenated
def isContraction(word):
    for char in word:
        if (char == '\''):
            return True
    return False
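# Hypothetical examples: a clitic like "'s" counts as a contraction and
# is glued onto the previous word when the message is rebuilt.
#   isContraction("'s")   -> True
#   isContraction("yeah") -> False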
# convert token-level pipeline predictions to word-level predictions
def tokenToWordPred(message, trueWords):
    posResult = pos_model(message)
    index = 0
    for word in trueWords:
        posToken = posResult[index].get('word')
        # take the POS predicted for this token as the word-level
        # prediction (the tag of the first sub-token wins)
        pos = posResult[index].get('entity')
        pos_pred.append([pos])
        # if the token and word disagree from the first character,
        # the mismatch cannot be repaired by merging sub-tokens
        if (word != posToken and word[0] != posToken[0]):
            print("MISMATCH", word, posToken)
            continue
        # otherwise merge sub-tokens until they spell the full word
        while (word != posToken and word[0] == posToken[0]):
            index += 1
            posToken = posToken + posResult[index].get('word')
            # get rid of '#' symbols added by the tokenizer
            posToken = cleanPoundSign(posToken)
        index += 1
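# Sketch of the merge, assuming the pipeline output shape noted above
# (the sentence, sub-tokens, and tags here are illustrative, not taken
# from the corpus or this model):
#   message   = "I like swimming ."
#   posResult = [{'word': 'I',      'entity': 'PRON'},
#                {'word': 'like',   'entity': 'VERB'},
#                {'word': 'swim',   'entity': 'VERB'},
#                {'word': '##ming', 'entity': 'VERB'},
#                {'word': '.',      'entity': 'PUNCT'}]
# For trueWords = ['I', 'like', 'swimming', '.'], 'swim' and '##ming'
# are concatenated (with '#' stripped) until they match 'swimming',
# which then gets the single word-level tag 'VERB'.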
with open(out_dir, "a") as output:
    for file in os.listdir(data_dir):
        # skip directories and the readme
        if os.path.isdir(data_dir + '/' + file):
            continue
        if (file == "README.md"):
            continue
        # open the current file in the directory
        with open(data_dir + '/' + file, "r") as read:
            numWords = 0
            words = []
            message = ""
            for line in read:
                values = line.split()
                # skip blank lines or placeholder lines
                if (len(values) <= 3):
                    continue
                pos = values[3]   # POS tag at index 3 of each line
                lid = values[2]   # language ID at index 2 of each line
                word = values[1]  # word at index 1 of each line
                numWords += 1
                if isContraction(word):
                    # a contraction is glued onto the previous word and
                    # implicitly reuses that word's truth tags
                    message = message + word
                    lastWord = words.pop()
                    words.append(lastWord + word)
                else:
                    # otherwise record the word and its truth tags
                    message = message + " " + word
                    words.append(word)
                    pos_truth.append([pos])
                    lid_truth.append([lid])
                # at the end of each sentence, pass it to the model
                if (word == '.'):
                    tokenToWordPred(message, words)
                    numWords = 0
                    words = []
                    message = ""
            # tag and score any words left over after the last '.'
            if (len(message) != 0):
                tokenToWordPred(message, words)
                numWords = 0
                words = []
                message = ""
        # after each file, pos_truth and pos_pred must stay aligned
        assert len(pos_truth) == len(pos_pred)
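    # Each data line is assumed (from the indices used above) to be a
    # CoNLL-style row: token number, word, language ID, POS tag, e.g.
    #   3  swimming  eng  VERB
    # The column values in this example are illustrative only.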
    # note: pos_truth could be concatenated with lid_truth (and likewise
    # lid_truth with pos_pred) to get statistics broken down by language
    print(eval.getMetrics(pos_truth, pos_pred), file=output)
    print("Length of pos_pred:", len(pos_pred))
    print("Length of pos_truth:", len(pos_truth))
    print("Length of lid_truth:", len(lid_truth))
    assert len(pos_truth) == len(lid_truth)
    assert len(pos_pred) == len(pos_truth)
    # bucket every prediction by (predicted tag, gold LID, gold tag) so
    # errors and correct calls can be inspected per language
    index = 0
    error_dict = {}
    correct_dict = {}
    for pred in pos_pred:
        lid = lid_truth[index]
        truth = pos_truth[index]
        key = (pred[0], tuple(lid), tuple(truth))
        if pred[0] != truth[0]:
            error_dict[key] = error_dict.get(key, 0) + 1
        else:
            correct_dict[key] = correct_dict.get(key, 0) + 1
        index += 1
    print(error_dict, file=output)
    print('\n\n\n\n', file=output)
    print(correct_dict, file=output)
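# Expected usage (assuming the cluster-specific model and data paths
# above exist):
#   python miamiCorpusPOS.py
# Word-level metrics plus the error/correct dictionaries are appended
# to the file at out_dir.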