# This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction
import os
import json
import argparse

question_ids_to_exclude = []

# answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'}
answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span'}
evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout', 'figure': 'Figure', 'map': 'Map'}
reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting': 'Counting'}
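# The three dicts above map the annotation keys expected in the GT file to the display
# names used when reporting the optional per-type score breakdown (enabled with -a).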

def save_json(file_path, data):
    with open(file_path, 'w+') as json_file:
        json.dump(data, json_file)

def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
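
# Illustrative example of the per-question ANLS computed in evaluate_method() below:
# for a ground-truth answer "budget" and a predicted answer "buget",
# levenshtein_distance("budget", "buget") == 1, the normalized distance is
# 1 / max(len("budget"), len("buget")) = 1/6, and the score is 1 - 1/6 ≈ 0.83.
# Scores below the ANLS threshold (default 0.5) are clamped to 0.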

def validate_data(gtFilePath, submFilePath):
    """
    Method validate_data: validates that the submission file has the expected structure and contents
    and that no ground-truth question is missing an answer.
    If an error is detected, the method raises an exception.
    """
    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    if 'data' not in gtJson:
        raise Exception("The GT file is not valid (no data key)")

    if 'dataset_name' not in gtJson:
        raise Exception("The GT file is not valid (no dataset_name key)")

    if not isinstance(submJson, list):
        raise Exception("The Det file is not valid (root item must be an array)")

    if len(submJson) != len(gtJson['data']):
        raise Exception("The Det file is not valid (invalid number of answers. Expected:" + str(len(gtJson['data'])) + " Found:" + str(len(submJson)) + ")")

    gtQuestions = sorted([r['questionId'] for r in gtJson['data']])
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
    detQuestions = sorted([r['questionId'] for r in submJson])

    if gtQuestions != detQuestions:
        raise Exception("The Det file is not valid. Question IDs must match the GT")

    for gtObject in gtJson['data']:
        try:
            q_id = int(gtObject['questionId'])
            res_ix = res_id_to_index[q_id]
        except (KeyError, ValueError):
            raise Exception("The Det file is not valid. Question " + str(gtObject['questionId']) + " not present")
        else:
            detObject = submJson[res_ix]

            # if detObject['questionId'] != gtObject['questionId']:
            #     raise Exception("Answer #" + str(i) + " not valid (invalid question ID. Expected:" + str(gtObject['questionId']) + " Found:" + detObject['questionId'] + ")")

            if 'answer' not in detObject:
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (no answer key)")

            if isinstance(detObject['answer'], list):
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (answer key has to be a single string)")

def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluates the method and returns the results
    Results. Dictionary with the following values:
    - method (required)  Global method metrics. Ex: { 'Precision':0.8, 'Recall':0.9 }
    - samples (optional) Per sample metrics. Ex: { 'sample1': { 'Precision':0.8, 'Recall':0.9 }, 'sample2': { 'Precision':0.8, 'Recall':0.9 } }
    """
    show_scores_per_answer_type = evaluationParams.answer_types

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}

    perSampleMetrics = {}

    totalScore = 0
    row = 0

    if show_scores_per_answer_type:
        answerTypeTotalScore = {x: 0 for x in answer_types.keys()}
        answerTypeNumQuestions = {x: 0 for x in answer_types.keys()}

        evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()}
        evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()}

        reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()}
        reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()}

    for gtObject in gtJson['data']:

        q_id = int(gtObject['questionId'])
        res_ix = res_id_to_index[q_id]
        detObject = submJson[res_ix]

        if q_id in question_ids_to_exclude:
            question_result = 0
            info = 'Question EXCLUDED from the result'

        else:
            info = ''
            values = []
            for answer in gtObject['answers']:
                # preprocess both the answers - gt and prediction
                gt_answer = ' '.join(answer.strip().lower().split())
                det_answer = ' '.join(detObject['answer'].strip().lower().split())

                # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
                dist = levenshtein_distance(gt_answer, det_answer)
                length = max(len(answer.upper()), len(detObject['answer'].upper()))
                values.append(0.0 if length == 0 else float(dist) / float(length))

            question_result = 1 - min(values)

            if question_result < evaluationParams.anls_threshold:
                question_result = 0

            totalScore += question_result

            if show_scores_per_answer_type:
                for q_type in gtObject["answer_type"]:
                    answerTypeTotalScore[q_type] += question_result
                    answerTypeNumQuestions[q_type] += 1

                for q_type in gtObject["evidence"]:
                    evidenceTypeTotalScore[q_type] += question_result
                    evidenceTypeNumQuestions[q_type] += 1

                for q_type in gtObject["operation/reasoning"]:
                    reasoningTypeTotalScore[q_type] += question_result
                    reasoningTypeNumQuestions[q_type] += 1

        perSampleMetrics[str(gtObject['questionId'])] = {
            'score': question_result,
            'question': gtObject['question'],
            'gt': gtObject['answers'],
            'det': detObject['answer'],
            'info': info
        }
        row = row + 1

    methodMetrics = {
        'score': 0 if len(gtJson['data']) == 0 else totalScore / (len(gtJson['data']) - len(question_ids_to_exclude))
    }

    answer_types_scores = {}
    evidence_types_scores = {}
    operation_types_scores = {}

    if show_scores_per_answer_type:
        for a_type, ref in answer_types.items():
            answer_types_scores[ref] = 0 if len(gtJson['data']) == 0 else answerTypeTotalScore[a_type] / (answerTypeNumQuestions[a_type])

        for e_type, ref in evidence_types.items():
            evidence_types_scores[ref] = 0 if len(gtJson['data']) == 0 else evidenceTypeTotalScore[e_type] / (evidenceTypeNumQuestions[e_type])

        for r_type, ref in reasoning_requirements.items():
            operation_types_scores[ref] = 0 if len(gtJson['data']) == 0 else reasoningTypeTotalScore[r_type] / (reasoningTypeNumQuestions[r_type])

    resDict = {
        'result': methodMetrics,
        'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores, 'operation_types': operation_types_scores},
        'per_sample_result': perSampleMetrics
    }

    return resDict
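
# Shape of the dictionary returned above (values are illustrative):
# {
#     'result': {'score': 0.61},
#     'scores_by_types': {'answer_types': {...}, 'evidence_types': {...}, 'operation_types': {...}},
#     'per_sample_result': {'12345': {'score': 1.0, 'question': '...', 'gt': ['...'], 'det': '...', 'info': ''}}
# }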

def display_results(results, show_answer_types):
    print("\nOverall ANLS: {:2.4f}".format(results['result']['score']))
    if show_answer_types:
        print("\nAnswer types:")
        for a_type in answer_types.values():
            print("\t{:12s} {:2.4f}".format(a_type, results['scores_by_types']['answer_types'][a_type]))

        print("\nEvidence types:")
        for e_type in evidence_types.values():
            print("\t{:12s} {:2.4f}".format(e_type, results['scores_by_types']['evidence_types'][e_type]))

        print("\nOperation required:")
        for r_type in reasoning_requirements.values():
            print("\t{:12s} {:2.4f}".format(r_type, results['scores_by_types']['operation_types'][r_type]))

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="InfographicVQA evaluation script.")
    parser.add_argument('-g', '--ground_truth', type=str, help="Path of the Ground Truth file.", required=True)
    parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True)
    parser.add_argument('-t', '--anls_threshold', type=float, default=0.5, help="ANLS threshold to use (See Scene-Text VQA paper for more info.).", required=False)
    # Note: with type=bool, argparse treats any non-empty value (e.g. '-a True') as True.
    parser.add_argument('-a', '--answer_types', type=bool, default=False, help="Score break down by answer types (special GT file required).", required=False)
    parser.add_argument('-o', '--output', type=str, help="Path to a directory where to copy the file 'results.json' that contains per-sample results.", required=False)

    args = parser.parse_args()

    # Validate the format of ground truth and submission files.
    validate_data(args.ground_truth, args.submission_file)

    # Evaluate method
    results = evaluate_method(args.ground_truth, args.submission_file, args)

    display_results(results, args.answer_types)

    if args.output:
        output_dir = args.output

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        resultsOutputname = os.path.join(output_dir, 'results.json')
        save_json(resultsOutputname, results)

        print("All results, including per-sample results, have been correctly saved!")