-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsingleturn_dataset.py
297 lines (257 loc) · 12.1 KB
/
singleturn_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Originally taken from: https://github.com/iglu-contest/gridworld
# Author: Milagro Teruel, Artem Zholus
from ipaddress import ip_address
import os
import json
import re
import pandas as pd
import numpy as np
import pickle
import bz2
from collections import defaultdict
from pprint import pprint
from .task import Subtasks, Task, Tasks
from .download import download
from .common import VOXELWORLD_GROUND_LEVEL, fix_log, fix_xyz
from .multiturn_dataset import MultiturnDataset
from zipfile import ZipFile
from tqdm import tqdm
class SingleturnDataset(MultiturnDataset):
    """Single-turn IGLU dataset: one instruction per (initial, target) grid pair.

    Each sample is derived from one step of a multi-turn game: the initial
    world is the grid at that step, the target world is the result of a
    single extra instruction, and the dialog history up to that step is
    attached for context.
    """
    SINGLE_TURN_INSTRUCTION_FILENAME = 'single_turn_instructions.csv'
    MULTI_TURN_INSTRUCTION_FILENAME = 'multi_turn_dialogs.csv'
    DATASET_URL = {
        "v1.0": 'https://github.com/microsoft/iglu-datasets/raw/main/datasets/single_turn_dataset.zip',
    }
    BLOCK_MAP = {
        # voxelworld's colour id : iglu colour id
        00: 0,  # air
        57: 1,  # blue
        50: 6,  # yellow
        59: 2,  # green
        47: 4,  # orange
        56: 5,  # purple
        60: 3,  # red
        # voxelworld (freeze version)'s colour id : iglu colour id
        86: 1,  # blue
        87: 6,  # yellow
        88: 2,  # green
        89: 4,  # orange
        90: 5,  # purple
        91: 3,  # red
    }

    def __init__(self, dataset_version='v1.0', task_kwargs=None,
                 force_download=False, limit=None,
                 keep_unmodified=True, keep_wrong_builds=True) -> None:
        """Initialize the dataset, downloading and parsing it if needed.

        Parameters
        ----------
        dataset_version : str
            Key into ``DATASET_URL`` selecting which release to fetch.
        task_kwargs : dict, optional
            Extra keyword arguments forwarded to the parent constructor.
        force_download : bool
            Re-download the dataset even when a cached copy exists.
        limit : int, optional
            If given, only the first `limit` instruction rows are used.
        keep_unmodified : bool
            Keep samples whose target grid equals the initial grid.
        keep_wrong_builds : bool
            Keep samples whose rebuilt target differs from the original
            multi-turn target structure.
        """
        self.limit = limit
        self.keep_unmodified = keep_unmodified
        self.keep_wrong_builds = keep_wrong_builds
        super().__init__(dataset_version=dataset_version,
                         task_kwargs=task_kwargs, force_download=force_download)

    def get_instructions(self, data_path):
        """Load the single-turn instruction table, truncated to `self.limit` rows."""
        single_turn_df = pd.read_csv(os.path.join(
            data_path, self.SINGLE_TURN_INSTRUCTION_FILENAME))
        if self.limit is not None:
            return single_turn_df[:self.limit]
        return single_turn_df

    def get_multiturn_dialogs(self, data_path):
        """Load the multi-turn dialog table used to reconstruct chat history."""
        return pd.read_csv(os.path.join(
            data_path, self.MULTI_TURN_INSTRUCTION_FILENAME))

    @classmethod
    def get_data_path(cls):
        """Returns the path where iglu dataset will be downloaded and cached.

        It can be set using the environment variable IGLU_DATA_PATH. Otherwise,
        it will be `~/.iglu/data/single_turn_dataset`.

        Returns
        -------
        (str, bool)
            The absolute path to the data folder, and whether it was
            customized via the IGLU_DATA_PATH environment variable.
        """
        if 'IGLU_DATA_PATH' in os.environ:
            data_path = os.environ['IGLU_DATA_PATH']
            custom = True
        elif 'HOME' in os.environ:
            data_path = os.path.join(
                os.environ['HOME'], '.iglu', 'data', 'single_turn_dataset')
            custom = False
        else:
            # Fallback for platforms without $HOME (e.g. Windows).
            data_path = os.path.join(
                os.path.expanduser('~'), '.iglu', 'data', 'single_turn_dataset')
            custom = False
        return data_path, custom

    def download_dataset(self, data_path, force_download):
        """Download and extract the dataset zip unless a cached copy exists."""
        instruction_filepath = os.path.join(
            data_path, self.SINGLE_TURN_INSTRUCTION_FILENAME)
        path = os.path.join(data_path, 'single_turn_dataset.zip')
        # The instruction CSV doubles as the "already extracted" marker.
        if os.path.exists(instruction_filepath) and not force_download:
            print("Using cached dataset")
            return
        url = self.DATASET_URL[self.dataset_version]
        if not isinstance(url, str):
            # Some versions may map to a sequence of mirrors; take the first.
            url = url[0]
        print(f"Downloading dataset from {url}")
        download(
            url=url,
            destination=path,
            data_prefix=data_path
        )
        print('Dataset downloaded!')
        with ZipFile(path) as zfile:
            zfile.extractall(data_path, members=tqdm(zfile.namelist(), desc='Extracting zip file'))

    def create_task(self, previous_chat, initial_grid, target_grid,
                    last_instruction):
        """Build a Task from dialog context plus initial/target block lists."""
        task = Task(
            dialog=previous_chat,
            instruction=last_instruction,
            target_grid=Tasks.to_dense(target_grid),
            starting_grid=Tasks.to_sparse(initial_grid),
            full_grid=Tasks.to_dense(target_grid),
        )
        # To properly init max_int and prev_grid_size fields
        task.sample()
        return task

    def get_previous_dialogs(self, single_turn_row, multiturn_dialogs):
        """Reconstruct the architect chat history preceding this sample.

        Parses the multi-turn game id and step number out of
        ``InitializedWorldPath`` and collects the qualified architect
        utterances (instructions or clarifying-question answers) from all
        earlier steps of that game. Returns a possibly-empty list of
        ``'<Architect> ...'`` / ``'<Builder> ...'`` strings.
        """
        # Filter multiturn rows with this game id and previous to step
        utterances = []
        mturn_data_path = single_turn_row.InitializedWorldPath.split('/')[-2:]
        if len(mturn_data_path) != 2 or '-' not in mturn_data_path[1]:
            print(f"Error with initial data path {single_turn_row.InitializedWorldPath}. "
                  "Could not parse data path to get previous dialogs.")
            return utterances
        mturn_game_id = mturn_data_path[0]
        try:
            mturn_last_step = int(mturn_data_path[1].replace("step-", ""))
        except ValueError:
            # Narrowed from a bare `except Exception`: only int() can fail here.
            print(f"Error with initial data path {single_turn_row.InitializedWorldPath}. "
                  "Could not parse step id to get previous dialogs.")
            return utterances
        dialog_rows = multiturn_dialogs[
            (multiturn_dialogs.PartitionKey == mturn_game_id) &
            (multiturn_dialogs.StepId < mturn_last_step) &
            (multiturn_dialogs.IsHITQualified == True)]
        for _, row in dialog_rows.sort_values('StepId')\
                .reset_index(drop=True).iterrows():
            if row.StepId % 2 == 1:
                # Architect step
                if isinstance(row.instruction, str):
                    utterance = row.instruction
                    utterances.append(
                        f'<Architect> {self.process(utterance)}')
                elif isinstance(row.Answer4ClarifyingQuestion, str):
                    utterance = row.Answer4ClarifyingQuestion
                    utterances.append(
                        f'<Architect> {self.process(utterance)}')
                elif isinstance(row.ClarifyingQuestion, str):
                    utterances.append(
                        f'<Builder> {self.process(row.ClarifyingQuestion)}')
        return utterances

    def parse_tasks(self, dialogs, path):
        """Fills attribute `self.tasks` with instances of Task.

        A Task contains an initial world state, a target world state and a
        single instruction.

        Parameters
        ----------
        dialogs : pandas.DataFrame
            Contains information of each session, originally stored
            in database tables. The information includes:
            - InitializedWorldStructureId or InitializedWorldGameId:
              Original target structure id of the initial world.
            - InitializedWorldPath: Path to a json file that contains the
              initial blocks of the world.
            - ActionDataPath: Path relative to dataset location with the
              target world.
            - InputInstruction: Session instruction
            - IsHITQualified: boolean indicating if the step is valid.
        path : str
            Path with the state of the VoxelWorld grid after each session.
            Each session should have an associated directory named with the
            session id, with json files that describe the world state after
            each step.
        """
        # Work on a copy: `dialogs` may be a slice of a larger frame, and
        # assigning columns into a slice triggers SettingWithCopyWarning
        # (and can silently fail to write through).
        dialogs = dialogs[dialogs.InitializedWorldPath.notna()].copy()
        dialogs['InitializedWorldPath'] = dialogs['InitializedWorldPath'] \
            .apply(lambda x: x.replace('\\', os.path.sep))
        dialogs['InitializedWorldPath'] = dialogs['InitializedWorldPath'] \
            .apply(lambda x: x.replace('/', os.path.sep))
        # Get the list of games for which the instructions were clear.
        # NOTE: previously `str.match("CQ-*")`, where `-*` made the dash
        # optional, so any id merely starting with "CQ" matched. A literal
        # "CQ-" prefix is what the `[len("CQ-"):]` lookups below assume.
        turns = dialogs[dialogs.GameId.str.startswith("CQ-")]

        # Util function to read structure from disk.
        def _load_structure(structure_path):
            filepath = os.path.join(path, structure_path)
            if not os.path.exists(filepath):
                return None
            with open(filepath) as structure_file:
                structure_data = json.load(structure_file)
            blocks = structure_data['worldEndingState']['blocks']
            structure = [self.transform_block(block) for block in blocks]
            return structure

        multiturn_dialogs = self.get_multiturn_dialogs(path)
        tasks_count = 0
        pbar = tqdm(turns.iterrows(), total=len(turns), desc='parsing dataset')
        errors = defaultdict(int)
        for k, row in pbar:
            pbar.set_postfix_str(f"{tasks_count} tasks")
            assert row.InitializedWorldStructureId is not None
            # Read initial structure
            initial_world_blocks = _load_structure(row.InitializedWorldPath)
            if initial_world_blocks is None:
                errors['start_world_missing'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Can't load starting structure from '{row.InitializedWorldPath}'.")
                continue
            target_world_blocks = _load_structure(row.TargetWorldPath)
            if target_world_blocks is None:
                errors['target_world_missing'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Can't load target structure from '{row.TargetWorldPath}'.")
                continue
            # Check if target structure matches the initial structure.
            if sorted(initial_world_blocks) == sorted(target_world_blocks) and not self.keep_unmodified:
                errors['target_unchanged'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Target structure is the same as the initial one.")
                continue
            # Get the original game.
            # TODO: here we need to remove `[len("CQ-"):]`. Otherwise CQ-* sessions will be dropped.
            # but this might get duplicates in the dataset.
            orig = dialogs[dialogs.GameId == row.GameId[len("CQ-"):]]
            if len(orig) == 0:
                errors['dialog_not_found'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Can't find its original game '{row.GameId[len('CQ-'):]}'.")
                continue
            assert len(orig) == 1
            orig = orig.iloc[0]
            # Load original structure.
            orig_target_world_blocks = _load_structure(orig.TargetWorldPath)
            if orig_target_world_blocks is None:
                errors['multiturn_target_not_found'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Can't load original target structure from '{orig.TargetWorldPath}'.")
                continue
            # Check if original structure matches the rebuilt one.
            if sorted(orig_target_world_blocks) != sorted(target_world_blocks) and not self.keep_wrong_builds:
                errors['multiturn_target_mismatch'] += 1
                pbar.write(f"Skipping '{row.GameId}'. Target structure doesn't match the one in '{orig.GameId}'.")
                continue
            last_instruction = '<Architect> ' + self.process(row.InputInstruction)
            # Read utterances
            utterances = self.get_previous_dialogs(row, multiturn_dialogs)
            utterances.append(last_instruction)
            utterances = '\n'.join(utterances)
            # Construct task
            task = self.create_task(
                utterances, initial_world_blocks, target_world_blocks,
                last_instruction=last_instruction)
            # e.g. initial_world_states\builder-data/8-c92/step-4 -> 8-c92/step-4
            task_id, step_id = row.InitializedWorldPath.split("/")[-2:]
            self.tasks[f"{task_id}/{step_id}"].append(task)
            tasks_count += 1
        if len(errors) != 0:
            print('Some samples are failed to load. Here are the stats:')
            pprint(dict(errors))
        print(f'no error: {tasks_count}')

    def __iter__(self):
        """Yield (task_id, sample_index, 1, task) for every parsed task."""
        for task_id, tasks in self.tasks.items():
            for j, task in enumerate(tasks):
                yield task_id, j, 1, task

    def __len__(self):
        # Sum the lengths directly: the previous `sum(self.tasks.values(), [])`
        # concatenated every list just to count, which is quadratic.
        return sum(len(tasks) for tasks in self.tasks.values())
if __name__ == "__main__":
    # Smoke test: constructing the dataset downloads (or reuses the cached
    # copy of) the data and parses every task. The trailing `pass` was dead
    # code and has been removed.
    dataset = SingleturnDataset()