added rendering for img input
Andrewzh112 committed Dec 23, 2020
1 parent 289fde2 commit 6d9c819
Showing 3 changed files with 37 additions and 18 deletions.
td3/agent.py: 16 changes (6 additions, 10 deletions)

@@ -53,10 +53,6 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,
             self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_1').to(self.device)
             self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_2').to(self.device)
 
-            self.critic_optimizer = torch.optim.Adam(
-                chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
-            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)
-
             self.target_actor = ImageActor(in_channels, n_actions, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
             self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_1').to(self.device)
             self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_2').to(self.device)
@@ -67,14 +63,14 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,
             self.critic_1 = Critic(state_space, hidden_dims, n_actions, 'critic_1').to(self.device)
             self.critic_2 = Critic(state_space, hidden_dims, n_actions, 'critic_2').to(self.device)
 
-            self.critic_optimizer = torch.optim.Adam(
-                chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
-            self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)
-
             self.target_actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'target_actor').to(self.device)
             self.target_critic_1 = Critic(state_space, hidden_dims, n_actions, 'target_critic_1').to(self.device)
             self.target_critic_2 = Critic(state_space, hidden_dims, n_actions, 'target_critic_2').to(self.device)
 
+        self.critic_optimizer = torch.optim.Adam(
+            chain(self.critic_1.parameters(), self.critic_2.parameters()), lr=beta)
+        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=alpha)
+
         # copy weights
         self.update_network_parameters(tau=1)

@@ -87,8 +83,8 @@ def _get_noise(self, clip=True):
     def _clamp_action_bound(self, action):
         return action.clamp(self.min_action, self.max_action)
 
-    def choose_action(self, observation):
-        if self.time_step < self.warmup:
+    def choose_action(self, observation, rendering=False):
+        if self.time_step < self.warmup and not rendering:
             mu = self._get_noise(clip=False)
         else:
             state = torch.tensor(observation, dtype=torch.float).to(self.device)
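Two things in this diff are worth unpacking. The optimizer construction moves out of the image/vector if-else so it runs exactly once, with a single Adam stepping both critics through itertools.chain; and choose_action gains a rendering flag so an evaluation run can bypass the warmup noise. Below is a minimal, self-contained sketch of both patterns; SimpleActor, SimpleCritic, and all dimensions and learning rates are simplified stand-ins for illustration, not the repo's actual classes or values.

from itertools import chain

import torch
import torch.nn as nn


class SimpleActor(nn.Module):
    # Stand-in for the repo's Actor/ImageActor: maps a state to a bounded action.
    def __init__(self, state_dim, n_actions, max_action):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(),
                                 nn.Linear(64, n_actions), nn.Tanh())
        self.max_action = max_action

    def forward(self, state):
        return self.net(state) * self.max_action


class SimpleCritic(nn.Module):
    # Stand-in for Critic/ImageCritic: scores a (state, action) pair.
    def __init__(self, state_dim, n_actions):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim + n_actions, 64), nn.ReLU(),
                                 nn.Linear(64, 1))

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))


actor = SimpleActor(state_dim=8, n_actions=2, max_action=1.0)
critic_1 = SimpleCritic(state_dim=8, n_actions=2)
critic_2 = SimpleCritic(state_dim=8, n_actions=2)

# One optimizer for both critics: chain() concatenates the two parameter
# iterators, so a single .step() updates both Q-networks at once.
critic_optimizer = torch.optim.Adam(
    chain(critic_1.parameters(), critic_2.parameters()), lr=1e-3)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)


def choose_action(observation, time_step, warmup, rendering=False):
    # Training explores with pure noise for the first `warmup` steps
    # (a simplified stand-in for the agent's _get_noise); a rendering run
    # should always use the learned policy instead.
    if time_step < warmup and not rendering:
        return torch.randn(2).clamp(-1.0, 1.0)
    with torch.no_grad():
        return actor(torch.as_tensor(observation, dtype=torch.float))


print(choose_action(torch.zeros(8), time_step=0, warmup=1000))                   # noise
print(choose_action(torch.zeros(8), time_step=0, warmup=1000, rendering=True))  # policy

Training code that calls choose_action(state) keeps the warmup behaviour unchanged, while render.py opts out with rendering=True.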
td3/main.py: 2 changes (1 addition, 1 deletion)

@@ -37,7 +37,7 @@

 # training hp params
 parser.add_argument('--n_episodes', type=int, default=1000, help='Number of episodes')
-parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
+parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
 parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate actor')
 parser.add_argument('--beta', type=float, default=0.001, help='Learning rate critic')
 parser.add_argument('--warmup', type=int, default=1000, help='Number of warmup steps')
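The only functional change here raises the default batch size from 100 to 512, so any run that omits --batch_size now trains on larger batches. A small sketch of how the new default interacts with an explicit flag; it mirrors only the add_argument line visible above, not the repo's full parser.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')

print(parser.parse_args([]).batch_size)                       # 512 (new default)
print(parser.parse_args(['--batch_size', '100']).batch_size)  # 100 (old value, opt-in)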
td3/render.py: 37 changes (30 additions, 7 deletions)

@@ -1,5 +1,7 @@
 import gym
+import numpy as np
 from tqdm import tqdm
+import torch
+from collections import deque
 
 from td3.agent import Agent
 from td3.main import args
@@ -10,26 +12,47 @@
 args.checkpoint_dir += f'/{args.env_name}_td3.pth'
 # env & agent
 env = gym.make(args.env_name)
 
+if args.virtual_display:
+    import pyvirtualdisplay
+    _display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
+    _ = _display.start()
+
+if args.img_input:
+    env.reset()
+    env = PixelObservationWrapper(env)
+
 agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau, args.batch_size,
-              args.gamma, args.d, 0, args.max_size, args.c, args.sigma,
-              args.one_device, args.log_dir, args.checkpoint_dir)
+              args.gamma, args.d, args.warmup, args.max_size, args.c, args.sigma,
+              args.one_device, args.log_dir, args.checkpoint_dir, args.img_input,
+              args.in_channels, args.order, args.depth, args.multiplier,
+              args.action_embed_dim, args.hidden_dim, args.crop_dim)
 best_score = env.reward_range[0]
 load_weights(args.checkpoint_dir, [agent.actor], ['actor'])
 episodes = tqdm(range(args.n_episodes))
 
 for e in episodes:
     # resetting
     state = env.reset()
-    done = False
-    score = 0
+    if args.img_input:
+        state_queue = deque(
+            [preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
+            maxlen=args.order)
+        state = torch.cat(list(state_queue), 1).cpu().numpy()
+    done, score = False, 0

     while not done:
-        action = agent.choose_action(state)
+        action = agent.choose_action(state, rendering=True)
         state_, reward, done, _ = env.step(action)
+        if isinstance(reward, np.ndarray):
+            reward = reward[0]
 
         # reset, log & render
         score += reward
-        state = state_
+        if args.img_input:
+            state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
+            state = torch.cat(list(state_queue), 1).cpu().numpy()
+        else:
+            state = state_
         episodes.set_postfix({'Reward': reward})
         env.render()
         if score > best_score:
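On the image path, the agent's input is a stack of the last `order` preprocessed frames: state['pixels'] comes from gym's PixelObservationWrapper (the env is reset before wrapping, presumably because the wrapper renders once at construction to size its observation space), and a fixed-length deque plus a channel-wise torch.cat rebuild the stacked state every step. A minimal sketch of that frame-stacking pattern follows; fake_preprocess is a hypothetical stand-in for the repo's preprocess_img, whose actual transform is not shown in this diff.

from collections import deque

import numpy as np
import torch


def fake_preprocess(frame, crop_dim):
    # Hypothetical stand-in for preprocess_img: center-crop, scale to [0, 1],
    # return a (1, C, H, W) float tensor. The repo's real transform may differ.
    h, w, _ = frame.shape
    top, left = (h - crop_dim) // 2, (w - crop_dim) // 2
    crop = frame[top:top + crop_dim, left:left + crop_dim]
    return torch.from_numpy(crop.copy()).float().permute(2, 0, 1).unsqueeze(0) / 255.0


order, crop_dim = 4, 84
first_frame = np.zeros((96, 96, 3), dtype=np.uint8)  # stands in for state['pixels']

# Seed the queue with `order` copies of the reset frame; maxlen=order makes the
# deque drop the oldest frame automatically as a new one is appended each step.
state_queue = deque(
    [fake_preprocess(first_frame, crop_dim) for _ in range(order)], maxlen=order)

# Concatenating along dim=1 turns `order` (1, C, H, W) tensors into a single
# (1, order * C, H, W) array: channel-stacked frames.
state = torch.cat(list(state_queue), 1).cpu().numpy()
print(state.shape)  # (1, 12, 84, 84)

Stacking along channels gives a feedforward conv net a short motion history in a single observation, which is presumably why the image networks take both in_channels and order as constructor arguments.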
