Overhaul improvement
lcswillems committed Jul 25, 2019
1 parent 84640da commit 1f2f5d9
Showing 14 changed files with 237 additions and 228 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -139,12 +139,10 @@ During training, logs are printed in your terminal (and saved in text and CSV fo

**Note:** `U` gives the update number, `F` the total number of frames, `FPS` the number of frames per second, `D` the total duration, `rR:μσmM` the mean, std, min and max reshaped return per episode, `F:μσmM` the mean, std, min and max number of frames per episode, `H` the entropy, `V` the value, `pL` the policy loss, `vL` the value loss and `∇` the gradient norm.

During training, logs might also be plotted in Tensorboard if `--tb` is added.
During training, logs are also plotted in Tensorboard:

<p><img src="README-rsrc/train-tensorboard.png"></p>

**Note:** `tensorboardX` package is required and can be installed with `pip3 install tensorboardX`.
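The `--tb` flag and the separate install note go away because TensorBoard logging is now always on and `tensorboardX` is listed in the new `requirements.txt`. As a rough, hypothetical sketch of how scalars of this kind can be written with `tensorboardX` (the log directory and tag names below are illustrative, not necessarily what `scripts/train.py` does):

```python
from tensorboardX import SummaryWriter

# Hypothetical sketch: directory and tag names are illustrative only.
writer = SummaryWriter("storage/MyModel")
for update, return_mean in enumerate([0.12, 0.35, 0.61]):
    writer.add_scalar("return_mean", return_mean, update)
writer.close()
```

The resulting plots can then be viewed with the standard `tensorboard --logdir <storage dir>` command.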

<h2 id="scripts-visualize">scripts/visualize.py</h2>

An example of use:
22 changes: 10 additions & 12 deletions model.py
Expand Up @@ -3,17 +3,18 @@
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torch_ac
import gym


# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if classname.find("Linear") != -1:
m.weight.data.normal_(0, 1)
m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
if m.bias is not None:
m.bias.data.fill_(0)


class ACModel(nn.Module, torch_ac.RecurrentACModel):
def __init__(self, obs_space, action_space, use_memory=False, use_text=False):
super().__init__()
@@ -53,14 +54,11 @@ def __init__(self, obs_space, action_space, use_memory=False, use_text=False):
self.embedding_size += self.text_embedding_size

# Define actor's model
if isinstance(action_space, gym.spaces.Discrete):
self.actor = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, action_space.n)
)
else:
raise ValueError("Unknown action space: " + str(action_space))
self.actor = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, action_space.n)
)

# Define critic's model
self.critic = nn.Sequential(
@@ -81,7 +79,7 @@ def semi_memory_size(self):
return self.image_embedding_size

def forward(self, obs, memory):
x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3)
x = obs.image.transpose(1, 3).transpose(2, 3)
x = self.image_conv(x)
x = x.reshape(x.shape[0], -1)

@@ -107,4 +105,4 @@ def forward(self, obs, memory):

def _get_embed_text(self, text):
_, hidden = self.text_rnn(self.word_embedding(text))
return hidden[-1]
return hidden[-1]
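A note on the simplified transpose in `forward` above: MiniGrid image observations are stored channels-last, while `nn.Conv2d` expects channels-first, so the two chained `transpose` calls reorder each batch from `(N, H, W, C)` to `(N, C, H, W)`. A minimal sketch with made-up sizes:

```python
import torch

# Made-up batch of 4 channels-last observations (H=5, W=7, C=3).
obs_image = torch.zeros(4, 5, 7, 3)            # (N, H, W, C)
x = obs_image.transpose(1, 3).transpose(2, 3)  # -> (N, C, H, W)
assert x.shape == (4, 3, 5, 7)
```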
3 changes: 0 additions & 3 deletions requirements.pip

This file was deleted.

4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
torch-ac>=1.1.0
gym-minigrid
tensorboardX>=1.6
numpy>=1.3
38 changes: 20 additions & 18 deletions scripts/evaluate.py
@@ -1,19 +1,16 @@
#!/usr/bin/env python3

import argparse
import gym
import gym_minigrid
import time
import torch
from torch_ac.utils.penv import ParallelEnv

import utils


# Parse arguments

parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
help="name of the environment to be run (REQUIRED)")
help="name of the environment (REQUIRED)")
parser.add_argument("--model", required=True,
help="name of the trained model (REQUIRED)")
parser.add_argument("--episodes", type=int, default=100,
@@ -32,50 +29,55 @@

utils.seed(args.seed)

# Generate environment
# Set device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

# Load environments

envs = []
for i in range(args.procs):
env = gym.make(args.env)
env.seed(args.seed + 10000*i)
env = utils.make_env(args.env, args.seed + 10000 * i)
envs.append(env)
env = ParallelEnv(envs)
print("Environments loaded\n")

# Define agent
# Load agent

model_dir = utils.get_model_dir(args.model)
agent = utils.Agent(args.env, env.observation_space, model_dir, args.argmax, args.procs)
print("CUDA available: {}\n".format(torch.cuda.is_available()))
agent = utils.Agent(env.observation_space, env.action_space, model_dir, device, args.argmax, args.procs)
print("Agent loaded\n")

# Initialize logs

logs = {"num_frames_per_episode": [], "return_per_episode": []}

# Run the agent
# Run agent

start_time = time.time()

obss = env.reset()

log_done_counter = 0
log_episode_return = torch.zeros(args.procs, device=agent.device)
log_episode_num_frames = torch.zeros(args.procs, device=agent.device)
log_episode_return = torch.zeros(args.procs, device=device)
log_episode_num_frames = torch.zeros(args.procs, device=device)

while log_done_counter < args.episodes:
actions = agent.get_actions(obss)
obss, rewards, dones, _ = env.step(actions)
agent.analyze_feedbacks(rewards, dones)

log_episode_return += torch.tensor(rewards, device=agent.device, dtype=torch.float)
log_episode_num_frames += torch.ones(args.procs, device=agent.device)
log_episode_return += torch.tensor(rewards, device=device, dtype=torch.float)
log_episode_num_frames += torch.ones(args.procs, device=device)

for i, done in enumerate(dones):
if done:
log_done_counter += 1
logs["return_per_episode"].append(log_episode_return[i].item())
logs["num_frames_per_episode"].append(log_episode_num_frames[i].item())

mask = 1 - torch.tensor(dones, device=agent.device, dtype=torch.float)
mask = 1 - torch.tensor(dones, device=device, dtype=torch.float)
log_episode_return *= mask
log_episode_num_frames *= mask

@@ -102,4 +104,4 @@

indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k])
for i in indexes[:n]:
print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))
print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))