BERT admissible command code
bryonkucharski committed Jun 10, 2020
1 parent 8de7668 commit af9a77a
Showing 9 changed files with 17,884 additions and 0 deletions.
Empty file.
@@ -0,0 +1 @@
{'carrot': 0, 'toolbox': 1, 'cilantro': 2, 'sliding door': 3, 'water': 4, 'patio table': 5, 'stove': 6, 'patio chair': 7, 'commercial glass door': 8, 'table': 9, 'cookbook': 10, 'sliding patio door': 11, 'orange bell pepper': 12, 'chicken leg': 13, 'black pepper': 14, 'wooden door': 15, 'block of cheese': 16, 'sofa': 17, 'front door': 18, 'workbench': 19, 'red apple': 20, 'chicken wing': 21, 'yellow bell pepper': 22, 'red hot pepper': 23, 'red onion': 24, 'patio door': 25, 'oven': 26, 'screen door': 27, 'bed': 28, 'meal': 29, 'fiberglass door': 30, 'barn door': 31, 'shelf': 32, 'BBQ': 33, 'purple potato': 34, 'toilet': 35, 'yellow potato': 36, 'parsley': 37, 'flour': 38, 'red potato': 39, 'olive oil': 40, 'white onion': 41, 'counter': 42, 'fridge': 43, 'showcase': 44, 'salt': 45, 'plain door': 46, 'pork chop': 47, 'banana': 48, 'knife': 49, 'frosted-glass door': 50, "NO_OBJECT": 51, "NOT_ADMISSIBLE":52}
17,039 changes: 17,039 additions & 0 deletions bert_admissible_command_generator/data/example_dataset.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions bert_admissible_command_generator/data/template2id.txt
@@ -0,0 +1 @@
{'chop OBJ with OBJ': 0, 'close OBJ': 1, 'cook OBJ with OBJ': 2, 'dice OBJ with OBJ': 3, 'drink OBJ': 4, 'drop OBJ': 5, 'eat OBJ': 6, 'examine OBJ': 7, 'go east': 8, 'go north': 9, 'go south': 10, 'go west': 11, 'insert OBJ into OBJ': 12, 'inventory': 13, 'lock OBJ with OBJ': 14, 'look': 15, 'open OBJ': 16, 'prepare OBJ': 17, 'put OBJ on OBJ': 18, 'slice OBJ with OBJ': 19, 'take OBJ': 20, 'take OBJ from OBJ': 21, 'unlock OBJ with OBJ': 22 }
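
The two id files above (the object-to-id map, whose filename is not shown in this diff, and template2id.txt) are stored as Python dict literals with mixed quoting rather than strict JSON, so json.load would reject them. A minimal loading sketch, assuming repo-root-relative paths under bert_admissible_command_generator/data/ and an assumed object2id.txt filename:

import ast

def load_vocab(path):
    # The id files are Python dict literals (single- and double-quoted keys),
    # so ast.literal_eval is used instead of json.load.
    with open(path) as f:
        return ast.literal_eval(f.read())

template2id = load_vocab("bert_admissible_command_generator/data/template2id.txt")
object2id = load_vocab("bert_admissible_command_generator/data/object2id.txt")  # filename assumed
print(len(template2id), len(object2id))  # 23 templates, 53 objects (incl. NO_OBJECT, NOT_ADMISSIBLE)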
118 changes: 118 additions & 0 deletions bert_admissible_command_generator/data_generator.py
@@ -0,0 +1,118 @@
import glob
import json
import os
import re

import gym
import numpy as np
import textworld
import textworld.gym
from tqdm import tqdm

def clean_game_state(state):
    """Strip TextWorld banners, separators, and escape characters from a raw observation string."""
    lines = state.split("\n")
    cur = [a.strip() for a in lines]
    cur = ' '.join(cur).strip().replace('\n', '').replace('---------', '')
    cur = re.sub(r"(?<=-\=).*?(?=\=-)", '', cur)  # drop room-name banner text between -= and =-
    cur = re.sub(r'[$_|/>]', '', cur)
    cur = cur.replace("-==-", '').strip()
    cur = cur.replace("\\", "").strip()
    return cur


def generate_data(games, seed, branching_depth):
rng = np.random.RandomState(seed)
dataset = []
seen_states = set()
for game in tqdm(games):
# Ignore the following commands.
commands_to_ignore = ["look", "examine", "inventory"]

        request_infos = textworld.EnvInfos(admissible_commands=True, last_action=True, game=True, inventory=True,
                                           description=True, entities=True, facts=True,
                                           extras=["recipe", "walkthrough", "goal"])
env_id = textworld.gym.register_game(game, request_infos, max_episode_steps=10000)
env = gym.make(env_id)

_, infos = env.reset()
walkthrough = infos["extra.walkthrough"]
if walkthrough[0] != "inventory": # Make sure we start with listing the inventory.
walkthrough = ["inventory"] + walkthrough


done = False
cmd = "restart" # The first previous_action is like [re]starting a new game.
for i in range(len(walkthrough) + 1):
obs, infos = env.reset()
obs = infos["description"] # `obs` would contain the banner and objective text which we don't want.

# Follow the walkthrough for a bit.
for cmd in walkthrough[:i]:

obs, _, done, infos = env.step(cmd)
state = "DESCRIPTION: "+ infos['description'] + " INVENTORY: "+ infos['inventory']
state = clean_game_state(state)

if state not in seen_states:

acs = infos['admissible_commands']
for ac in acs[:]:
                        if (ac.startswith('examine') and ac != 'examine cookbook') or ac == 'look' or ac == 'inventory':
acs.remove(ac)
data = acs
data_name = 'admissible_commands'


dataset += [{
"game": os.path.basename(game),
"step": (i, 0),
"state": state,
data_name : data
}]

seen_states.add(state)

if done:
break # Stop collecting data if game is done.

if i == 0:
continue # No random commands before 'inventory'

# Then, take N random actions.
for j in range(branching_depth):
cmd = rng.choice([c for c in infos["admissible_commands"] if (c == "examine cookbook" or c.split()[0] not in commands_to_ignore)])
obs, _, done, infos = env.step(cmd)
if done:
break # Stop collecting data if game is done.
state = "DESCRIPTION: "+ infos['description'] + " INVENTORY: "+ infos['inventory']
state = clean_game_state(state)
if state not in seen_states:

acs = infos['admissible_commands']
for ac in acs[:]:
if (ac.startswith('examine') and ac != 'examine cookbook') or ac == 'look' or ac == 'inventory':
acs.remove(ac)
data = acs
data_name = 'admissible_commands'

dataset += [{
"game": os.path.basename(game),
"step": (i, j),
"state": state,
data_name : data
}]
seen_states.add(state)

with open('data.json', 'w') as fp:
json.dump(dataset, fp)


if __name__ == "__main__":

    # First download the TextWorld games from https://aka.ms/ftwp/dataset.zip, then set PATH_TO_GAMES below.

PATH_TO_GAMES = ''
path = PATH_TO_GAMES + "/train/*.ulx"
games = glob.glob(path)

    generate_data(games, 5154, 5)
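
generate_data writes its records to data.json. A short inspection sketch, using only the field names built in the loop above (it assumes data.json exists from a prior run):

import json

with open("data.json") as fp:
    dataset = json.load(fp)

# Each record holds the source game file, the (walkthrough step, random-branch step) pair,
# the cleaned "DESCRIPTION: ... INVENTORY: ..." state string, and the filtered admissible commands.
sample = dataset[0]
print(len(dataset))
print(sample["game"], sample["step"])
print(sample["state"][:100])
print(sample["admissible_commands"])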
231 changes: 231 additions & 0 deletions bert_admissible_command_generator/datasets.py
@@ -0,0 +1,231 @@

import itertools
import json
import re

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

class AdmissibleCommandsClassificationDataset(Dataset):
def __init__(self, data_file, template2id, object2id, max_seq_length, bert_model_type):
with open(data_file) as json_file:
self.data = json.load(json_file)

self.template2id = template2id
self.object2id = object2id

self.template_size = len(template2id)
self.object_size = len(object2id)

self.tokenizer = BertTokenizer.from_pretrained(bert_model_type, do_lower_case=False)

self.max_seq_length = max_seq_length

def __len__(self):
return len(self.data)

def __getitem__(self, idx):

datapoint = self.data[idx]
input_admissible = datapoint['admissible_commands']
state = datapoint['state']

#if self.encoder_type == 'bert':

x = '[CLS] ' + state + ' [SEP]'
x = self.tokenizer.tokenize(x)
input_ids = self.tokenizer.convert_tokens_to_ids(x)

input_mask = [1] * len(input_ids)
diff = (self.max_seq_length - len(input_ids))
if diff > 0:
padding = [0] * (self.max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
else:
input_ids = input_ids[:self.max_seq_length]
input_mask = input_mask[:self.max_seq_length]

input_ids = torch.LongTensor(input_ids).unsqueeze(0)
input_mask = torch.LongTensor(input_mask).unsqueeze(0)
state = torch.cat([input_ids, input_mask],dim=0)

# elif self.encoder_type == 'rnn':
# state = self.prepare_state_rnn(state)


template_targets = torch.zeros(self.template_size) #y_t|s
o1_template_targets = torch.zeros(self.template_size, self.object_size) #y_o1|s,t
o2_o1_template_targets = torch.zeros(self.template_size,self.object_size, self.object_size) #y_o2|o1,s,t

valid_acts = self.convert_commands_to_lists(input_admissible)

assert 'NO_OBJECT' in list(self.object2id.keys())
assert 'NOT_ADMISSIBLE' in list(self.object2id.keys())
no_obj_id = self.object2id["NO_OBJECT"]
not_admissible_obj_id = self.object2id["NOT_ADMISSIBLE"]

#fill in objects from admissible commands
for act in valid_acts:

#act is [template, obj1, obj2]
t = act[0]
template_idx = self.template2id[t]
template_targets[template_idx] = 1

#check how many objects template has
num_objs = len(act) - 1
if num_objs == 0:
#continue
o1_template_targets[template_idx][no_obj_id] = 1 #this template does not require any objects
o2_o1_template_targets[template_idx][no_obj_id][no_obj_id] = 1

elif num_objs == 1:
obj_id = self.object2id[act[1]]
o1_template_targets[template_idx][obj_id] = 1
o2_o1_template_targets[template_idx][obj_id][no_obj_id] = 1

elif num_objs == 2:
obj1_id = self.object2id[act[1]]
obj2_id = self.object2id[act[2]]
o1_template_targets[template_idx][obj1_id] = 1
o2_o1_template_targets[template_idx][obj1_id][obj2_id] = 1

        #fill in targets for templates that are not admissible in this state
valid_templates = [valid_acts[i][0] for i in range(len(valid_acts))]
for t in self.template2id.keys():
if t not in valid_templates:
template_idx = self.template2id[t]
                o1_template_targets[template_idx][not_admissible_obj_id] = 1  # this template is not admissible; flag its object targets
                o2_o1_template_targets[template_idx][not_admissible_obj_id][not_admissible_obj_id] = 1

return torch.LongTensor(state), template_targets, o1_template_targets, o2_o1_template_targets

    def convert_commands_to_lists(self, admissible_commands):
        '''
        input:  ['open fridge', 'close fridge', 'take onion from fridge', ...]
        output: [['open OBJ', 'fridge'], ['close OBJ', 'fridge'], ['take OBJ from OBJ', 'onion', 'fridge'], ...]
        '''
valid_acts = []

for act in admissible_commands:
ents = self.extract_entities(act)
for ent in ents:
if ent is not None:
act = act.replace(ent, "OBJ")
cmd = [act] + ents
valid_acts.append(cmd)
return valid_acts

    def extract_entities(self, input_command):
        """
        Extract entities, in order, from the given input command.
        Examples:
            input:  'cut apple with knife'
            output: ['apple', 'knife']
            input:  'close fridge'
            output: ['fridge']
            input:  'look'
            output: []
        """

#find which ents from all_ents are present in command
all_ents = list(self.object2id.keys())

starting_command = input_command

ents = []
idxs = []

#check combo of three words
three_words = [" ".join(item) for item in itertools.combinations(input_command.split(" "), 3)]
for combo in three_words:
if combo in all_ents:
ents.append(combo)
input_command = input_command.replace(combo,"OBJ")


two_words = [" ".join(item) for item in itertools.combinations(input_command.split(" "), 2)]
for combo in two_words:
if combo in all_ents:
ents.append(combo)
input_command = input_command.replace(combo,"OBJ")


words = re.findall(r'\w+', input_command)
for word in words:
if word in all_ents:
ents.append(word)


if len(ents) == 0:
return []
elif len(ents) == 1:
return [ents[0]]

#if more than one ent, determine which ent goes to position 1 or position 2
else:
ent1 = starting_command.replace(ents[0], "OBJ")
ent2 = starting_command.replace(ents[1], "OBJ")
ent1_pos = ent1.find("OBJ")
ent2_pos = ent2.find("OBJ")
            if ent1_pos < ent2_pos:
                return [ents[0], ents[1]]
            elif ent1_pos > ent2_pos:
                return [ents[1], ents[0]]
            else:
                # Positions coincide (e.g. one entity name contains the other); keep the discovered order.
                return [ents[0], ents[1]]

    def prepare_state_rnn(self, state_description):
        # Note: this RNN path expects a SentencePiece processor assigned to `self.sp`, which is not
        # created in __init__; the method is not called from __getitem__ in this version.
        remove = ['=', '-', '\'', ':', '[', ']', 'eos', 'EOS', 'SOS', 'UNK', 'unk', 'sos', '<', '>']
for rm in remove:
state_description = state_description.replace(rm, '')

state_description = state_description.split('|')

ret = [self.sp.encode_as_ids('<s>' + s_desc + '</s>') for s_desc in state_description]

return self.pad_sequences(ret, maxlen=self.max_seq_length)

    def pad_sequences(self, sequences, maxlen=None, dtype='int32', value=0.):
'''
Partially borrowed from Keras
# Arguments
sequences: list of lists where each element is a sequence
maxlen: int, maximum length
dtype: type to cast the resulting sequence.
value: float, value to pad the sequences to the desired value.
# Returns
x: numpy array with dimensions (number_of_sequences, maxlen)
'''
lengths = [len(s) for s in sequences]
nb_samples = len(sequences)
if maxlen is None:
maxlen = np.max(lengths)
# take the sample shape from the first non empty sequence
# checking for consistency in the main loop below.
sample_shape = tuple()
for s in sequences:
if len(s) > 0:
sample_shape = np.asarray(s).shape[1:]
break
x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
for idx, s in enumerate(sequences):
if len(s) == 0:
continue # empty list was found
# pre truncating
trunc = s[-maxlen:]
# check `trunc` has expected shape
trunc = np.asarray(trunc, dtype=dtype)
if trunc.shape[1:] != sample_shape:
raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
(trunc.shape[1:], idx, sample_shape))
# post padding
x[idx, :len(trunc)] = trunc

return x
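
A usage sketch for AdmissibleCommandsClassificationDataset with a PyTorch DataLoader, run from inside the bert_admissible_command_generator directory. The file paths, BERT variant, max_seq_length, and batch size below are assumptions, not values taken from this commit:

import ast
from torch.utils.data import DataLoader

from datasets import AdmissibleCommandsClassificationDataset  # this file's module name

# The id files are Python dict literals, so ast.literal_eval is used instead of json.load.
with open("data/template2id.txt") as f:
    template2id = ast.literal_eval(f.read())
with open("data/object2id.txt") as f:  # filename assumed
    object2id = ast.literal_eval(f.read())

dataset = AdmissibleCommandsClassificationDataset(
    data_file="data/example_dataset.json",
    template2id=template2id,
    object2id=object2id,
    max_seq_length=256,                  # assumed value
    bert_model_type="bert-base-cased",   # assumed; the tokenizer is built with do_lower_case=False
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# `state` stacks token ids and attention mask; the three targets cover template,
# object-1 given template, and object-2 given object-1 and template.
state, template_targets, o1_targets, o2_targets = next(iter(loader))
print(state.shape, template_targets.shape, o1_targets.shape, o2_targets.shape)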