forked from OFA-Sys/OFA
-
Notifications
You must be signed in to change notification settings - Fork 1
/
preprocess_vizwiz.py
117 lines (105 loc) · 4.54 KB
/
preprocess_vizwiz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
A script to transform the VizWiz dataset into the format of the VQA dataset.
The format is a tab-separated (tsv) file with the following columns:
question-id, image-id, question text (lowercase), answer with confidence (like 1.0|!+no),
object label (can be blank), image as base64 encoded string
Further, a trainval_ans2label.pkl file has to be generated which is simply a pickled dict,
mapping the most frequent answers to a (random) label.
Author: Jan Willruth
"""
import base64
import glob
import json
import _pickle as pickle
import sys
from io import BytesIO
from collections import Counter
from PIL import Image
from tqdm import tqdm
def img2base64(fn):
    """
    Convert an image to a base64 encoded string. Code from a maintainer of the OFA repo
    (https://github.com/OFA-Sys/OFA/issues/56).

    The image is decoded with PIL and re-encoded in its original format into an in-memory
    buffer; the bytes of that buffer (not the raw file bytes) are what gets base64 encoded.

    :param fn: filename of image
    :return: base64 encoded string
    """
    img_buffer = BytesIO()
    # Use the image as a context manager so the underlying file handle is released
    # as soon as the image has been re-encoded, instead of waiting for garbage collection.
    with Image.open(fn) as img:
        img.save(img_buffer, format=img.format)
    return base64.b64encode(img_buffer.getvalue()).decode('utf-8')
def _generate_ans2label(data_dir, pkl_path, top_k=3129):
    """
    Build the answer -> label dict from the train and val annotations and pickle it.

    :param data_dir: root directory of the VizWiz data
    :param pkl_path: path the pickled dict is written to
    :param top_k: number of most frequent answers to keep (VQA uses 3129)
    :return: dict mapping each kept answer to an integer label
    """
    print('Generating trainval_ans2label.pkl file...')
    # Load annotations of both subsets
    annotations = []
    for subset in ('train', 'val'):
        with open(f'{data_dir}/Annotations/{subset}.json', encoding='utf-8') as f:
            annotations += json.load(f)
    # Count occurrences of every answer
    answer_counts = Counter(ans['answer'] for question in annotations for ans in question['answers'])
    # Pick the top_k most frequent answers; the label is the frequency rank
    freq_answers = sorted(answer_counts, key=answer_counts.get, reverse=True)[:top_k]
    ans2label = {answer: i for i, answer in enumerate(freq_answers)}
    # Save to file
    with open(pkl_path, 'wb') as f:
        pickle.dump(ans2label, f)
    print('Finished generating trainval_ans2label.pkl file...')
    return ans2label


def main(data_dir='vizwiz_data', single_answer=True):
    """
    Create one tsv file per subset (train, val, test) in the VQA format expected by OFA.

    Every row consists of: question-id, image-id, lowercased question text, answer(s) with
    confidence (like 1.0|!+no, multiple answers joined by &&), an empty object label and the
    base64 encoded image. For the test subset the placeholder answer 1.0|!+no is written.

    :param data_dir: root directory containing Annotations/<subset>.json and <subset>/*.jpg;
        the tsv files and trainval_ans2label.pkl are read from / written to this directory
    :param single_answer: if True, write only the most common answer of each question with
        confidence 1.0; if False, write every answer that is part of the ans2label dict
        together with its annotated confidence
    :return: status message
    """
    pkl_path = f'{data_dir}/trainval_ans2label.pkl'
    # Load pickled dict, generating it first if it does not exist yet.
    # NOTE: pickle is only safe for trusted files; this one is produced by this script.
    try:
        with open(pkl_path, 'rb') as f:
            ans2label = pickle.load(f)
    except FileNotFoundError:
        ans2label = _generate_ans2label(data_dir, pkl_path)

    # Dict to map answer confidence to value
    conf_values = {'yes': '1.0', 'maybe': '0.5', 'no': '0.0'}
    # Tabs and line breaks inside a question would break the tsv row layout
    tsv_unsafe = str.maketrans('\t\r\n', '   ')

    # Iterate over subsets
    for subset in ['train', 'val', 'test']:
        print(f'Generating rows for {subset} tsv file...')
        # Load corresponding json file
        with open(f'{data_dir}/Annotations/{subset}.json', encoding='utf-8') as f:
            annotations = json.load(f)
        # Sorted so that the row order is reproducible between runs
        file_names = sorted(glob.glob(f'{data_dir}/{subset}/*.jpg'))

        # Rows are written as they are produced, so the base64 images of a whole
        # subset never have to be held in memory at once.
        print(f'Writing {subset} tsv file...')
        with open(f'{data_dir}/vizwiz_{subset}.tsv', 'w', encoding='utf-8') as out:
            for fn in tqdm(file_names, file=sys.stdout):
                # Some string manipulation to get img_id: take the part after the last
                # underscore, drop its first three characters and the '.jpg' suffix.
                # NOTE(review): presumably names look like VizWiz_<subset>_<8 digits>.jpg - confirm.
                fn = fn.replace('\\', '/')
                img_id = int(fn.split('/')[-1].split('_')[-1][3:-4])
                # The annotation list is indexed by img_id; skip images without an entry
                try:
                    entry = annotations[img_id]
                except IndexError:
                    continue
                question = entry['question'].lower().translate(tsv_unsafe)

                if subset == 'test':
                    # Test subset has no usable answers, write a placeholder
                    answer_field = '1.0|!+no'
                elif single_answer:
                    answers = [ca['answer'] for ca in entry['answers']]
                    if not answers:
                        continue
                    # Most common answer with full confidence
                    answer_field = f'1.0|!+{Counter(answers).most_common(1)[0][0]}'
                else:
                    # Keep only answers that have a label (the dict keys are the answers)
                    parts = [f'{conf_values[ca["answer_confidence"]]}|!+{ca["answer"]}'
                             for ca in entry['answers'] if ca['answer'] in ans2label]
                    if not parts:
                        continue
                    answer_field = '&&'.join(parts)

                row = (img_id, img_id, question, answer_field, '', img2base64(fn))
                out.write('\t'.join(map(str, row)) + '\n')
    return 'Finished creating tsv files!'
if __name__ == '__main__':
    # Run the conversion and report the status message returned by main()
    status = main()
    print(status)
    print('All done!')