Handle Box(X, Y, 3) observation space
lcswillems committed Nov 7, 2018
1 parent 8349b9a commit 94f0218
Showing 5 changed files with 58 additions and 66 deletions.
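In short: the MiniGrid-only `utils.ObssPreprocessor` class is replaced by a `utils.get_obss_preprocessor` factory that dispatches on the environment id, so plain `Box(X, Y, 3)` image observation spaces are handled alongside MiniGrid's dict observations, and `utils.Agent` takes the environment id as a new first argument. A minimal sketch of the call-site change (the env id below is illustrative, not taken from the diff):

```python
# Before this commit: a single MiniGrid-only preprocessor class.
preprocess_obss = utils.ObssPreprocessor(model_dir, env.observation_space)
obs_space = preprocess_obss.obs_space

# After: a factory returns the model-facing obs_space and the preprocess function.
obs_space, preprocess_obss = utils.get_obss_preprocessor(
    "MiniGrid-DoorKey-5x5-v0",  # illustrative env id; any Box(X, Y, 3) env works too
    env.observation_space,
    model_dir,
)
```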
2 changes: 1 addition & 1 deletion scripts/evaluate.py
@@ -46,7 +46,7 @@
# Define agent

model_dir = utils.get_model_dir(args.model)
-agent = utils.Agent(model_dir, env.observation_space, args.argmax, args.procs)
+agent = utils.Agent(args.env, env.observation_space, model_dir, args.argmax, args.procs)
print("CUDA available: {}\n".format(torch.cuda.is_available()))

# Initialize logs
4 changes: 2 additions & 2 deletions scripts/train.py
@@ -102,7 +102,7 @@

# Define obss preprocessor

-preprocess_obss = utils.ObssPreprocessor(model_dir, envs[0].observation_space)
+obs_space, preprocess_obss = utils.get_obss_preprocessor(args.env, envs[0].observation_space, model_dir)

# Load training status

@@ -117,7 +117,7 @@
    acmodel = utils.load_model(model_dir)
    logger.info("Model successfully loaded\n")
except OSError:
-    acmodel = ACModel(preprocess_obss.obs_space, envs[0].action_space, args.mem, args.instr)
+    acmodel = ACModel(obs_space, envs[0].action_space, args.mem, args.instr)
    logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

5 changes: 2 additions & 3 deletions scripts/visualize.py
@@ -38,7 +38,7 @@
# Define agent

model_dir = utils.get_model_dir(args.model)
-agent = utils.Agent(model_dir, env.observation_space, args.argmax)
+agent = utils.Agent(args.env, env.observation_space, model_dir, args.argmax)

# Run the agent

@@ -47,10 +47,9 @@
while True:
    if done:
        obs = env.reset()
-        print("Instr:", obs["mission"])

    time.sleep(args.pause)
-    renderer = env.render("human")
+    renderer = env.render()

    action = agent.get_action(obs)
    obs, reward, done, _ = env.step(action)
4 changes: 2 additions & 2 deletions utils/agent.py
@@ -7,8 +7,8 @@ class Agent:
    - to choose an action given an observation,
    - to analyze the feedback (i.e. reward and done state) of its action."""

-    def __init__(self, model_dir, observation_space, argmax=False, num_envs=1):
-        self.preprocess_obss = utils.ObssPreprocessor(model_dir, observation_space)
+    def __init__(self, env_id, obs_space, model_dir, argmax=False, num_envs=1):
+        _, self.preprocess_obss = utils.get_obss_preprocessor(env_id, obs_space, model_dir)
        self.acmodel = utils.load_model(model_dir)
        self.argmax = argmax
        self.num_envs = num_envs
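With the reordered constructor, building an agent now starts from the environment id, which `Agent` forwards to `get_obss_preprocessor`. A hedged usage sketch (env id and model name are illustrative; `gym_minigrid` is assumed installed so the env is registered):

```python
import gym
import gym_minigrid  # noqa: F401 -- registers the MiniGrid envs (assumed installed)

import utils

env = gym.make("MiniGrid-DoorKey-5x5-v0")   # illustrative env id
model_dir = utils.get_model_dir("DoorKey")  # illustrative model name

# env_id first, then the raw observation space, then the model directory.
agent = utils.Agent("MiniGrid-DoorKey-5x5-v0", env.observation_space, model_dir)

obs = env.reset()
action = agent.get_action(obs)  # same call the visualize.py loop uses
```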
109 changes: 51 additions & 58 deletions utils/format.py
@@ -4,16 +4,64 @@
import re
import torch
import torch_rl
+import gym

import utils

+def get_obss_preprocessor(env_id, obs_space, model_dir):
+    # Check if it is a MiniGrid environment
+    if re.match("MiniGrid-.*", env_id):
+        obs_space = {"image": obs_space.spaces['image'].shape, "instr": 100}
+
+        vocab = Vocabulary(model_dir, obs_space["instr"])
+        def preprocess_obss(obss, device=None):
+            return torch_rl.DictList({
+                "image": preprocess_images([obs["image"] for obs in obss], device=device),
+                "instr": preprocess_instrs([obs["mission"] for obs in obss], vocab, device=device)
+            })
+
+    # Check if the obs_space is of type Box([X, Y, 3])
+    elif isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 3 and obs_space.shape[2] == 3:
+        obs_space = {"image": obs_space.shape}
+
+        def preprocess_obss(obss, device=None):
+            return torch_rl.DictList({
+                "image": preprocess_images([obs["image"] for obs in obss], device=device)
+            })
+
+    else:
raise "Unknown observation space: " + obs_space

return obs_space, preprocess_obss

def preprocess_images(images, device=None):
images = numpy.array(images)
return torch.tensor(images, device=device, dtype=torch.float)

def preprocess_instrs(instrs, vocab, device=None):
var_indexed_instrs = []
max_instr_len = 0

for instr in instrs:
tokens = re.findall("([a-z]+)", instr.lower())
var_indexed_instr = numpy.array([vocab[token] for token in tokens])
var_indexed_instrs.append(var_indexed_instr)
max_instr_len = max(len(var_indexed_instr), max_instr_len)

indexed_instrs = numpy.zeros((len(instrs), max_instr_len))

for i, indexed_instr in enumerate(var_indexed_instrs):
indexed_instrs[i, :len(indexed_instr)] = indexed_instr

return torch.tensor(indexed_instrs, device=device, dtype=torch.long)

class Vocabulary:
"""A mapping from tokens to ids with a capacity of `max_size` words.
It can be saved in a `vocab.json` file."""

def __init__(self, model_dir):
def __init__(self, model_dir, max_size):
self.path = utils.get_vocab_path(model_dir)
self.max_size = 100
self.max_size = max_size
self.vocab = {}
if os.path.exists(self.path):
self.vocab = json.load(open(self.path))
@@ -27,59 +27,4 @@ def __getitem__(self, token):

    def save(self):
        utils.create_folders_if_necessary(self.path)
-        json.dump(self.vocab, open(self.path, "w"))
-
-class ObssPreprocessor:
-    """A preprocessor of observations returned by the environment.
-    It converts MiniGrid observation space and MiniGrid observations
-    into the format that the model can handle."""
-
-    def __init__(self, model_dir, obs_space):
-        self.vocab = Vocabulary(model_dir)
-        self.obs_space = {
-            "image": obs_space.spaces['image'].shape,
-            "instr": self.vocab.max_size
-        }
-
-    def __call__(self, obss, device=None):
-        """Converts a list of MiniGrid observations, i.e. a list of
-        (image, instruction) tuples into two PyTorch tensors.
-
-        The images are concatenated. The instructions are tokenified, then
-        tokens are converted into lists of ids using a Vocabulary object, and
-        finally, the lists of ids are concatenated.
-
-        Returns
-        -------
-        preprocessed_obss : DictList
-            Contains preprocessed images and preprocessed instructions.
-        """
-
-        preprocessed_obss = torch_rl.DictList()
-
-        if "image" in self.obs_space.keys():
-            images = numpy.array([obs["image"] for obs in obss])
-            images = torch.tensor(images, device=device, dtype=torch.float)
-
-            preprocessed_obss.image = images
-
-        if "instr" in self.obs_space.keys():
-            raw_instrs = []
-            max_instr_len = 0
-
-            for obs in obss:
-                tokens = re.findall("([a-z]+)", obs["mission"].lower())
-                instr = numpy.array([self.vocab[token] for token in tokens])
-                raw_instrs.append(instr)
-                max_instr_len = max(len(instr), max_instr_len)
-
-            instrs = numpy.zeros((len(obss), max_instr_len))
-
-            for i, instr in enumerate(raw_instrs):
-                instrs[i, :len(instr)] = instr
-
-            instrs = torch.tensor(instrs, device=device, dtype=torch.long)
-
-            preprocessed_obss.instr = instrs
-
-        return preprocessed_obss
+        json.dump(self.vocab, open(self.path, "w"))
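For reference, a sketch of how the new `Box(X, Y, 3)` branch behaves, assuming the module is importable as `utils.format` (the space, env id, and model path are made up; note that even this branch reads `obs["image"]`, so observations are expected as dicts):

```python
import numpy
import gym

from utils.format import get_obss_preprocessor

# A fake 7x7 RGB observation space; any Box with shape (X, Y, 3) takes this branch.
space = gym.spaces.Box(low=0, high=255, shape=(7, 7, 3), dtype=numpy.uint8)

# "SomePixelEnv-v0" does not match "MiniGrid-.*", so no Vocabulary is created
# and the model directory is effectively unused.
obs_space, preprocess_obss = get_obss_preprocessor("SomePixelEnv-v0", space, "storage/test")
assert obs_space == {"image": (7, 7, 3)}

# preprocess_obss takes a list of observations and stacks the images into one tensor.
batch = preprocess_obss([{"image": numpy.zeros((7, 7, 3))}])
print(batch.image.shape)  # torch.Size([1, 7, 7, 3])
```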
