-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8de7668
commit af9a77a
Showing
9 changed files
with
17,884 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
1 change: 1 addition & 0 deletions
1
bert_admissible_command_generator/data/cooking_games_entities.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{'carrot': 0, 'toolbox': 1, 'cilantro': 2, 'sliding door': 3, 'water': 4, 'patio table': 5, 'stove': 6, 'patio chair': 7, 'commercial glass door': 8, 'table': 9, 'cookbook': 10, 'sliding patio door': 11, 'orange bell pepper': 12, 'chicken leg': 13, 'black pepper': 14, 'wooden door': 15, 'block of cheese': 16, 'sofa': 17, 'front door': 18, 'workbench': 19, 'red apple': 20, 'chicken wing': 21, 'yellow bell pepper': 22, 'red hot pepper': 23, 'red onion': 24, 'patio door': 25, 'oven': 26, 'screen door': 27, 'bed': 28, 'meal': 29, 'fiberglass door': 30, 'barn door': 31, 'shelf': 32, 'BBQ': 33, 'purple potato': 34, 'toilet': 35, 'yellow potato': 36, 'parsley': 37, 'flour': 38, 'red potato': 39, 'olive oil': 40, 'white onion': 41, 'counter': 42, 'fridge': 43, 'showcase': 44, 'salt': 45, 'plain door': 46, 'pork chop': 47, 'banana': 48, 'knife': 49, 'frosted-glass door': 50, "NO_OBJECT": 51, "NOT_ADMISSIBLE":52} |
17,039 changes: 17,039 additions & 0 deletions
17,039
bert_admissible_command_generator/data/example_dataset.json
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{'chop OBJ with OBJ': 0, 'close OBJ': 1, 'cook OBJ with OBJ': 2, 'dice OBJ with OBJ': 3, 'drink OBJ': 4, 'drop OBJ': 5, 'eat OBJ': 6, 'examine OBJ': 7, 'go east': 8, 'go north': 9, 'go south': 10, 'go west': 11, 'insert OBJ into OBJ': 12, 'inventory': 13, 'lock OBJ with OBJ': 14, 'look': 15, 'open OBJ': 16, 'prepare OBJ': 17, 'put OBJ on OBJ': 18, 'slice OBJ with OBJ': 19, 'take OBJ': 20, 'take OBJ from OBJ': 21, 'unlock OBJ with OBJ': 22 } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import numpy as np | ||
import textworld | ||
import textworld.gym | ||
import gym | ||
import re | ||
import sys | ||
import glob | ||
import requests | ||
import json | ||
import numpy as np | ||
import copy | ||
|
||
def clean_game_state(state): | ||
|
||
lines = state.split("\n") | ||
cur = [a.strip() for a in lines] | ||
cur = ' '.join(cur).strip().replace('\n', '').replace('---------', '') | ||
cur = re.sub("(?<=-\=).*?(?=\=-)", '', cur) | ||
cur = re.sub('[$_\\|/>]', '', cur) | ||
cur = cur.replace("-==-", '').strip() | ||
cur = cur.replace("\\", "").strip() | ||
return cur | ||
|
||
|
||
def generate_data(games, seed, branching_depth): | ||
rng = np.random.RandomState(seed) | ||
dataset = [] | ||
seen_states = set() | ||
for game in tqdm(games): | ||
# Ignore the following commands. | ||
commands_to_ignore = ["look", "examine", "inventory"] | ||
|
||
request_infos = textworld.EnvInfos(admissible_commands=True, last_action = True, game = True,inventory=True, description=True, entities=True, facts = True, extras=["recipe","walkthrough","goal"]) | ||
env_id = textworld.gym.register_game(game, request_infos, max_episode_steps=10000) | ||
env = gym.make(env_id) | ||
|
||
_, infos = env.reset() | ||
walkthrough = infos["extra.walkthrough"] | ||
if walkthrough[0] != "inventory": # Make sure we start with listing the inventory. | ||
walkthrough = ["inventory"] + walkthrough | ||
|
||
|
||
done = False | ||
cmd = "restart" # The first previous_action is like [re]starting a new game. | ||
for i in range(len(walkthrough) + 1): | ||
obs, infos = env.reset() | ||
obs = infos["description"] # `obs` would contain the banner and objective text which we don't want. | ||
|
||
# Follow the walkthrough for a bit. | ||
for cmd in walkthrough[:i]: | ||
|
||
obs, _, done, infos = env.step(cmd) | ||
state = "DESCRIPTION: "+ infos['description'] + " INVENTORY: "+ infos['inventory'] | ||
state = clean_game_state(state) | ||
|
||
if state not in seen_states: | ||
|
||
acs = infos['admissible_commands'] | ||
for ac in acs[:]: | ||
if ac.startswith('examine') and ac != 'examine cookbook' or ac == 'look' or ac == 'inventory': | ||
acs.remove(ac) | ||
data = acs | ||
data_name = 'admissible_commands' | ||
|
||
|
||
dataset += [{ | ||
"game": os.path.basename(game), | ||
"step": (i, 0), | ||
"state": state, | ||
data_name : data | ||
}] | ||
|
||
seen_states.add(state) | ||
|
||
if done: | ||
break # Stop collecting data if game is done. | ||
|
||
if i == 0: | ||
continue # No random commands before 'inventory' | ||
|
||
# Then, take N random actions. | ||
for j in range(branching_depth): | ||
cmd = rng.choice([c for c in infos["admissible_commands"] if (c == "examine cookbook" or c.split()[0] not in commands_to_ignore)]) | ||
obs, _, done, infos = env.step(cmd) | ||
if done: | ||
break # Stop collecting data if game is done. | ||
state = "DESCRIPTION: "+ infos['description'] + " INVENTORY: "+ infos['inventory'] | ||
state = clean_game_state(state) | ||
if state not in seen_states: | ||
|
||
acs = infos['admissible_commands'] | ||
for ac in acs[:]: | ||
if (ac.startswith('examine') and ac != 'examine cookbook') or ac == 'look' or ac == 'inventory': | ||
acs.remove(ac) | ||
data = acs | ||
data_name = 'admissible_commands' | ||
|
||
dataset += [{ | ||
"game": os.path.basename(game), | ||
"step": (i, j), | ||
"state": state, | ||
data_name : data | ||
}] | ||
seen_states.add(state) | ||
|
||
with open('data.json', 'w') as fp: | ||
json.dump(dataset, fp) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
#first download textworld games from https://aka.ms/ftwp/dataset.zip | ||
|
||
PATH_TO_GAMES = '' | ||
path = PATH_TO_GAMES + "/train/*.ulx" | ||
games = glob.glob(path) | ||
|
||
generate_data(games,5154,5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
|
||
from torch.utils.data import Dataset | ||
import json | ||
import numpy as np | ||
from transformers import BertTokenizer | ||
import torch | ||
import itertools | ||
|
||
import re | ||
|
||
class AdmissibleCommandsClassificationDataset(Dataset): | ||
def __init__(self, data_file, template2id, object2id, max_seq_length, bert_model_type): | ||
with open(data_file) as json_file: | ||
self.data = json.load(json_file) | ||
|
||
self.template2id = template2id | ||
self.object2id = object2id | ||
|
||
self.template_size = len(template2id) | ||
self.object_size = len(object2id) | ||
|
||
self.tokenizer = BertTokenizer.from_pretrained(bert_model_type, do_lower_case=False) | ||
|
||
self.max_seq_length = max_seq_length | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __getitem__(self, idx): | ||
|
||
datapoint = self.data[idx] | ||
input_admissible = datapoint['admissible_commands'] | ||
state = datapoint['state'] | ||
|
||
#if self.encoder_type == 'bert': | ||
|
||
x = '[CLS] ' + state + ' [SEP]' | ||
x = self.tokenizer.tokenize(x) | ||
input_ids = self.tokenizer.convert_tokens_to_ids(x) | ||
|
||
input_mask = [1] * len(input_ids) | ||
diff = (self.max_seq_length - len(input_ids)) | ||
if diff > 0: | ||
padding = [0] * (self.max_seq_length - len(input_ids)) | ||
input_ids += padding | ||
input_mask += padding | ||
else: | ||
input_ids = input_ids[:self.max_seq_length] | ||
input_mask = input_mask[:self.max_seq_length] | ||
|
||
input_ids = torch.LongTensor(input_ids).unsqueeze(0) | ||
input_mask = torch.LongTensor(input_mask).unsqueeze(0) | ||
state = torch.cat([input_ids, input_mask],dim=0) | ||
|
||
# elif self.encoder_type == 'rnn': | ||
# state = self.prepare_state_rnn(state) | ||
|
||
|
||
template_targets = torch.zeros(self.template_size) #y_t|s | ||
o1_template_targets = torch.zeros(self.template_size, self.object_size) #y_o1|s,t | ||
o2_o1_template_targets = torch.zeros(self.template_size,self.object_size, self.object_size) #y_o2|o1,s,t | ||
|
||
valid_acts = self.convert_commands_to_lists(input_admissible) | ||
|
||
assert 'NO_OBJECT' in list(self.object2id.keys()) | ||
assert 'NOT_ADMISSIBLE' in list(self.object2id.keys()) | ||
no_obj_id = self.object2id["NO_OBJECT"] | ||
not_admissible_obj_id = self.object2id["NOT_ADMISSIBLE"] | ||
|
||
#fill in objects from admissible commands | ||
for act in valid_acts: | ||
|
||
#act is [template, obj1, obj2] | ||
t = act[0] | ||
template_idx = self.template2id[t] | ||
template_targets[template_idx] = 1 | ||
|
||
#check how many objects template has | ||
num_objs = len(act) - 1 | ||
if num_objs == 0: | ||
#continue | ||
o1_template_targets[template_idx][no_obj_id] = 1 #this template does not require any objects | ||
o2_o1_template_targets[template_idx][no_obj_id][no_obj_id] = 1 | ||
|
||
elif num_objs == 1: | ||
obj_id = self.object2id[act[1]] | ||
o1_template_targets[template_idx][obj_id] = 1 | ||
o2_o1_template_targets[template_idx][obj_id][no_obj_id] = 1 | ||
|
||
elif num_objs == 2: | ||
obj1_id = self.object2id[act[1]] | ||
obj2_id = self.object2id[act[2]] | ||
o1_template_targets[template_idx][obj1_id] = 1 | ||
o2_o1_template_targets[template_idx][obj1_id][obj2_id] = 1 | ||
|
||
#fill inadmissible commands | ||
valid_templates = [valid_acts[i][0] for i in range(len(valid_acts))] | ||
for t in self.template2id.keys(): | ||
if t not in valid_templates: | ||
template_idx = self.template2id[t] | ||
o1_template_targets[template_idx][not_admissible_obj_id] = 1 # #this template is not admissible, set flags for object targets | ||
#import pdb;pdb.set_trace() | ||
o2_o1_template_targets[template_idx][not_admissible_obj_id][not_admissible_obj_id] = 1 | ||
|
||
return torch.LongTensor(state), template_targets, o1_template_targets, o2_o1_template_targets | ||
|
||
def convert_commands_to_lists(self,admissible_commands): | ||
''' | ||
input: [open fridge, close fridge, take onion from fridge . . . | ||
output [[open OBJ, fridge, None], [close OBJ, fridge, None], [take OBJ from OBJ, onion, fridge] ... | ||
''' | ||
valid_acts = [] | ||
|
||
for act in admissible_commands: | ||
ents = self.extract_entities(act) | ||
for ent in ents: | ||
if ent is not None: | ||
act = act.replace(ent, "OBJ") | ||
cmd = [act] + ents | ||
valid_acts.append(cmd) | ||
return valid_acts | ||
|
||
def extract_entities(self, input_command): | ||
|
||
""" | ||
Extract entities in order from given input command | ||
Example: | ||
input 'cut apple with knife' | ||
output [apple, knife] | ||
input 'close fridge' | ||
output [fride] | ||
input: 'look' | ||
output [] | ||
""" | ||
|
||
#find which ents from all_ents are present in command | ||
all_ents = list(self.object2id.keys()) | ||
|
||
starting_command = input_command | ||
|
||
ents = [] | ||
idxs = [] | ||
|
||
#check combo of three words | ||
three_words = [" ".join(item) for item in itertools.combinations(input_command.split(" "), 3)] | ||
for combo in three_words: | ||
if combo in all_ents: | ||
ents.append(combo) | ||
input_command = input_command.replace(combo,"OBJ") | ||
|
||
|
||
two_words = [" ".join(item) for item in itertools.combinations(input_command.split(" "), 2)] | ||
for combo in two_words: | ||
if combo in all_ents: | ||
ents.append(combo) | ||
input_command = input_command.replace(combo,"OBJ") | ||
|
||
|
||
words = re.findall(r'\w+', input_command) | ||
for word in words: | ||
if word in all_ents: | ||
ents.append(word) | ||
|
||
|
||
if len(ents) == 0: | ||
return [] | ||
elif len(ents) == 1: | ||
return [ents[0]] | ||
|
||
#if more than one ent, determine which ent goes to position 1 or position 2 | ||
else: | ||
ent1 = starting_command.replace(ents[0], "OBJ") | ||
ent2 = starting_command.replace(ents[1], "OBJ") | ||
ent1_pos = ent1.find("OBJ") | ||
ent2_pos = ent2.find("OBJ") | ||
if ent1_pos < ent2_pos: | ||
return [ents[0], ents[1]] | ||
elif ent1_pos > ent2_pos: | ||
return [ents[1], ents[0]] | ||
|
||
def prepare_state_rnn(self,state_description): | ||
remove = ['=', '-', '\'', ':', '[', ']', 'eos', 'EOS', 'SOS', 'UNK', 'unk', 'sos', '<', '>'] | ||
for rm in remove: | ||
state_description = state_description.replace(rm, '') | ||
|
||
state_description = state_description.split('|') | ||
|
||
ret = [self.sp.encode_as_ids('<s>' + s_desc + '</s>') for s_desc in state_description] | ||
|
||
return self.pad_sequences(ret, maxlen=self.max_seq_length) | ||
|
||
def pad_sequences(self,sequences, maxlen=None, dtype='int32', value=0.): | ||
''' | ||
Partially borrowed from Keras | ||
# Arguments | ||
sequences: list of lists where each element is a sequence | ||
maxlen: int, maximum length | ||
dtype: type to cast the resulting sequence. | ||
value: float, value to pad the sequences to the desired value. | ||
# Returns | ||
x: numpy array with dimensions (number_of_sequences, maxlen) | ||
''' | ||
lengths = [len(s) for s in sequences] | ||
nb_samples = len(sequences) | ||
if maxlen is None: | ||
maxlen = np.max(lengths) | ||
# take the sample shape from the first non empty sequence | ||
# checking for consistency in the main loop below. | ||
sample_shape = tuple() | ||
for s in sequences: | ||
if len(s) > 0: | ||
sample_shape = np.asarray(s).shape[1:] | ||
break | ||
x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype) | ||
for idx, s in enumerate(sequences): | ||
if len(s) == 0: | ||
continue # empty list was found | ||
# pre truncating | ||
trunc = s[-maxlen:] | ||
# check `trunc` has expected shape | ||
trunc = np.asarray(trunc, dtype=dtype) | ||
if trunc.shape[1:] != sample_shape: | ||
raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' % | ||
(trunc.shape[1:], idx, sample_shape)) | ||
# post padding | ||
x[idx, :len(trunc)] = trunc | ||
|
||
return x |
Oops, something went wrong.