-
Notifications
You must be signed in to change notification settings - Fork 2
/
clevr_utils.py
220 lines (169 loc) · 6.31 KB
/
clevr_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Utilities for CLEVR-Dialog dataset generation.
Author: Satwik Kottur
"""
import copy
def pretty_print_templates(templates, verbosity=1):
    """Pretty prints the name and type of every template, then a total.

    Args:
        templates: Sized iterable of template dicts; each one must have the
            keys 'name' and 'type'.
        verbosity: 1 to print name and type of the templates. NOTE: the value
            is accepted for API compatibility but is not read; the name/type
            listing is always printed.
    """
    separator = '-' * 70

    # Verbosity 1: Name and type.
    print(separator)
    for template in templates:
        print('[Name: %s] [Type: %s]' % (template['name'], template['type']))
    print(separator)
    print('Total of %s templates..' % len(templates))
    print(separator)
def pretty_print_scene_objects(scene):
    """Pretty prints one tab-indented line per object in a scene.

    Each line has the form '<index> : <shape>-<color>-<size>-<material>',
    where index is the object's position in scene['objects'].

    Args:
        scene: Scene graph whose 'objects' entry is a list of object dicts
            with the keys 'shape', 'color', 'size' and 'material'.
    """
    for index, obj in enumerate(scene['objects']):
        print_args = (index, obj['shape'], obj['color'], obj['size'],
                      obj['material'])
        print('\t%d : %s-%s-%s-%s' % print_args)
def pretty_print_dialogs(dialogs):
    """Pretty prints generated dialogs.

    For every dialog the caption is printed first, followed by one
    tab-indented line per round with the question, answer, template and the
    round's coreference dependence.

    Args:
        dialogs: Iterable of per-scene entries. Each entry has a 'dialogs'
            list; each dialog has a 'caption', a 'dialog' list of rounds
            (with 'question', 'answer' and 'template') and a 'graph' whose
            'history' list holds a 'dependence' value per entry.
    """
    for dialog_datum in dialogs:
        for dialog in dialog_datum['dialogs']:
            print(dialog['caption'])
            history = dialog['graph']['history']
            for round_id, round_datum in enumerate(dialog['dialog']):
                # History is read at round_id + 1: round k of the dialog
                # uses history entry k + 1 (entry 0 is never read here).
                coref_id = history[round_id + 1]['dependence']
                in_tuple = (round_id, round_datum['question'],
                            str(round_datum['answer']),
                            round_datum['template'], str(coref_id))
                print('\t[Q-%d: %s] [A: %s] [%s] [%s]' % in_tuple)
def merge_update_scene_graph(orig_graph, graph_item):
    """Merges a new graph item into a copy of a scene graph.

    Args:
        orig_graph: Original scene graph; it is never modified. Its 'objects'
            entry is a dict keyed by object id. When a relation is merged,
            orig_graph['relationships'][relation][object_id] must be a list.
        graph_item: New graph item to add to the scene graph. It has a
            'mergeable' flag, an 'objects' list of attribute dicts (each
            with an 'id') and optionally a 'relation' that holds from the
            first listed object to the second.

    Returns:
        graph: Deep copy of the original scene graph after merging. If the
            item is not mergeable the copy is returned without changes.

    Raises:
        ValueError: If a new object gives an attribute a value that differs
            from the one already stored for that object in the graph.
    """
    graph = copy.deepcopy(orig_graph)

    # If not mergeable, return the same scene graph.
    if not graph_item['mergeable']:
        return graph

    # Local alias; mutations below are visible through graph['objects'].
    objects = graph['objects']

    # 1. Go through each new object
    # 2. Find its match in objects
    #   a. If found, check for a clash of attributes, then update
    #   b. If novel, just add the object as is
    for item_obj in graph_item['objects']:
        # Copy so the returned graph shares no state with graph_item.
        new_obj = copy.deepcopy(item_obj)
        obj_id = new_obj['id']
        known = objects.get(obj_id)

        if known is None:
            # Add the new object.
            objects[obj_id] = new_obj
            continue

        # Attributes present on both sides must agree.
        for attr, value in new_obj.items():
            if attr in known and known[attr] != value:
                raise ValueError(
                    'Some of the attributes do not match! '
                    'Object %r: %r is %r in the graph but %r in the new item'
                    % (obj_id, attr, known[attr], value))

        # Add additional keys.
        known.update(new_obj)

    # If a relation, record object 2's id under object 1.
    if 'relation' in graph_item:
        rel = graph_item['relation']
        id1 = graph_item['objects'][0]['id']
        id2 = graph_item['objects'][1]['id']
        graph['relationships'][rel][id1].append(id2)

    return graph
def add_object_ids(scenes):
    """Adds an 'id' field to every object of the input scenes.

    The id of an object is its position in its scene's 'objects' list.

    Args:
        scenes: Dict whose 'scenes' entry is a list of CLEVR scene graphs,
            each with an 'objects' list of object dicts.

    Returns:
        scenes: The same dict that was passed in; objects are modified
            inplace.
    """
    for scene in scenes['scenes']:
        for obj_id, obj in enumerate(scene['objects']):
            obj['id'] = obj_id
    return scenes
def clean_object_attributes(scenes):
    """Cleans objects, keeping only shape, size, material, color and id.

    Every object dict is replaced, at the same list position, by a new dict
    holding just those five keys.

    Args:
        scenes: Dict whose 'scenes' entry is a list of scene graphs, each
            with an 'objects' list of object dicts.

    Returns:
        scenes: The same dict that was passed in, cleaned up inplace.

    Raises:
        KeyError: If an object lacks one of the kept keys (for example when
            'id' has not been added yet).
    """
    keys = ('shape', 'size', 'material', 'color', 'id')
    for scene in scenes['scenes']:
        objects = scene['objects']
        for obj_id, obj in enumerate(objects):
            objects[obj_id] = {key: obj[key] for key in keys}
    return scenes
def pretty_print_corefs(dialog, coref_groups):
    """Prints coreferences for a dialog, highlighting groups in colors.

    The caption is printed first (prefixed with 'C:'), then every question
    prefixed with its 0-based round index. A coreference group keeps the
    same color across the caption and all rounds.

    Args:
        dialog: Generated dialog with a 'caption' and a 'dialog' list of
            rounds, each with a 'question'.
        coref_groups: Dict of coreference groups; key 0 holds the groups of
            the caption and key round_id + 1 those of each round. Missing
            keys mean no groups.
    """
    # colorama is imported here because the module does not import it at
    # file level. It is only needed to make escape codes render on
    # terminals that lack native support, so a missing package is tolerated.
    try:
        import colorama
    except ImportError:
        colorama = None
    if colorama is not None:
        colorama.init()

    # Mapping of group_id -> color_ids for (foreground, background)
    color_map = {}
    caption_groups = coref_groups.get(0, [])
    colored, color_map = pretty_print_coref_sentence(dialog['caption'],
                                                     caption_groups,
                                                     color_map)
    print('\n\nC: %s' % colored)

    for round_id, round_datum in enumerate(dialog['dialog']):
        round_groups = coref_groups.get(round_id + 1, [])
        colored, color_map = pretty_print_coref_sentence(
            round_datum['question'], round_groups, color_map)
        print('%d: %s' % (round_id, colored))
def pretty_print_coref_sentence(sentence, groups, color_map):
    """Colors the spans of different coreference groups in a sentence.

    Despite the name nothing is printed; the colored sentence is returned.

    Args:
        sentence: Text sentence
        groups: List of coreference groups; each has a 'group_id' and a
            'span' whose first two items are the start and end character
            positions in the sentence.
        color_map: Dict of group_id -> (foreground id, background id) for
            the groups seen so far. It is updated inplace.

    Returns:
        sentence: Text sentence with color escape codes inserted
        color_map: Updated, if new groups in the current sentence
    """
    # ANSI escape sequences, written out so that no import is required.
    # They are the values colorama provides as Fore.RED/GREEN/YELLOW/BLUE/
    # MAGENTA, Back.BLACK/YELLOW/CYAN and Style.RESET_ALL.
    fore_codes = ['\033[31m', '\033[32m', '\033[33m', '\033[34m', '\033[35m']
    back_codes = ['\033[40m', '\033[43m', '\033[46m']
    reset_all = '\033[0m'

    insertions = []
    for group in groups:
        group_id = group['group_id']
        if group_id not in color_map:
            # Foreground cycles fastest; background advances once every
            # len(fore_codes) groups and wraps around so that any number of
            # groups gets a valid color pair.
            num_groups = len(color_map)
            forecolor_id = num_groups % len(fore_codes)
            backcolor_id = (num_groups // len(fore_codes)) % len(back_codes)
            color_map[group_id] = (forecolor_id, backcolor_id)
        forecolor_id, backcolor_id = color_map[group_id]

        start, end = group['span'][0], group['span'][1]
        insertions.append((start, fore_codes[forecolor_id]))
        insertions.append((start, back_codes[backcolor_id]))
        insertions.append((end, reset_all))

    # Perform insertions.
    sentence = insert_into_sentence(sentence, insertions)
    return sentence, color_map
def insert_into_sentence(sentence, insertions):
    """Sorts and performs insertions from right.

    Working from the rightmost position to the leftmost keeps every pending
    position valid, because text is only ever added after it. Of several
    insertions at the same position, the one listed later ends up first in
    the result.

    Args:
        sentence: Sentence to perform insertions into
        insertions: List of insertions, format: (position, text_insert),
            with positions referring to the original sentence.

    Returns:
        sentence: New sentence with all insertions applied
    """
    # sorted() is stable, also with reverse=True, so insertions sharing a
    # position are applied in the order they were given.
    ordered = sorted(insertions, key=lambda item: item[0], reverse=True)
    for position, text in ordered:
        sentence = sentence[:position] + text + sentence[position:]
    return sentence