modified network for td3 image actor & critic
Andrewzh112 committed Dec 28, 2020
1 parent f24e842 commit 101d1df
Showing 3 changed files with 65 additions and 62 deletions.
20 changes: 9 additions & 11 deletions td3/agent.py
@@ -16,7 +16,7 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,
                  img_input, in_channels, order, depth, multiplier,
                  action_embed_dim, hidden_dim, crop_dim):
         if img_input:
-            input_dim = [in_channels*order, crop_dim, crop_dim]
+            input_dim = [in_channels * order, crop_dim, crop_dim]
         else:
             input_dim = env.observation_space.shape
         state_space = input_dim[0]
@@ -49,16 +49,14 @@ def __init__(self, env, alpha, beta, hidden_dims, tau,
 
         # networks & optimizers
         if img_input:
-            self.actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'actor').to(self.device)
-            self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_1').to(self.device)
-            self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'critic_2').to(self.device)
-
-            self.target_actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, 'target_actor').to(self.device)
-            self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_1').to(self.device)
-            self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, 'target_critic_2').to(self.device)
-            print('actor')
-            print(self.actor)
+            self.actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, crop_dim, 'actor').to(self.device)
+            self.critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, 'critic_1').to(self.device)
+            self.critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, 'critic_2').to(self.device)
+
+            self.target_actor = ImageActor(in_channels, n_actions, hidden_dim, self.max_action, order, depth, multiplier, crop_dim, 'target_actor').to(self.device)
+            self.target_critic_1 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, 'target_critic_1').to(self.device)
+            self.target_critic_2 = ImageCritic(in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, crop_dim, 'target_critic_2').to(self.device)
 
         # physics networks
         else:
             self.actor = Actor(state_space, hidden_dims, n_actions, self.max_action, 'actor').to(self.device)
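
The image networks now receive the crop resolution as an extra positional argument (bound to img_size on the network side) so they can size their fully connected layers. A minimal sketch of the new constructor calls; every hyperparameter value below is an illustrative assumption, not a value fixed by this commit:

    from td3.networks import ImageActor, ImageCritic

    crop_dim = 96  # assumed crop resolution
    actor = ImageActor(in_channels=3, n_actions=2, hidden_dim=256, max_action=1.0,
                       order=4, depth=4, multiplier=16, img_size=crop_dim, name='actor')
    critic_1 = ImageCritic(in_channels=3, n_actions=2, hidden_dim=256, action_embed_dim=64,
                           order=4, depth=4, multiplier=16, img_size=crop_dim, name='critic_1')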
87 changes: 44 additions & 43 deletions td3/main.py
@@ -79,46 +79,47 @@
 score_history = deque([], maxlen=args.window_size)
 episodes = tqdm(range(args.n_episodes))
 
-# for e in episodes:
-#     # resetting
-#     state = env.reset()
-#     if args.img_input:
-#         state_queue = deque(
-#             [preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
-#             maxlen=args.order)
-#         state = torch.cat(list(state_queue), 1).cpu().numpy()
-#     done, score = False, 0
-
-#     while not done:
-#         action = agent.choose_action(state)
-#         state_, reward, done, _ = env.step(action)
-#         if isinstance(reward, np.ndarray):
-#             reward = reward[0]
-#         if args.img_input:
-#             state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
-#             state_ = torch.cat(list(state_queue), 1).cpu().numpy()
-#         agent.remember(state, action, reward, state_, done)
-#         agent.learn()
-
-#         # reset, log & render
-#         score += reward
-#         state = state_
-#         episodes.set_postfix({'Reward': reward, 'Iteration': agent.time_step})
-#         if args.no_render:
-#             continue
-#         env.render()
-
-#     # logging
-#     score_history.append(score)
-#     moving_avg = sum(score_history) / len(score_history)
-#     agent.add_scalar('Average Score', moving_avg, global_step=e)
-
-#     # save weights @ best score
-#     if moving_avg > best_score:
-#         best_score = moving_avg
-#         agent.save_networks()
-
-#     tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
-#         Episode Score: {score}, \
-#         Average Score: {moving_avg}, \
-#         Best Score: {best_score}')
+for e in episodes:
+    # resetting
+    state = env.reset()
+    if args.img_input:
+        state_queue = deque(
+            [preprocess_img(state['pixels'], args.crop_dim) for _ in range(args.order)],
+            maxlen=args.order)
+        state = torch.cat(list(state_queue), 1).cpu().numpy()
+    done, score = False, 0
+
+    while not done:
+        action = agent.choose_action(state)
+        state_, reward, done, _ = env.step(action)
+        if isinstance(reward, np.ndarray):
+            reward = reward[0]
+        if args.img_input:
+            state_queue.append(preprocess_img(state_['pixels'], args.crop_dim))
+            state_ = torch.cat(list(state_queue), 1).cpu().numpy()
+        agent.remember(state, action, reward, state_, done)
+        agent.learn()
+
+        # reset, log & render
+        score += reward
+        state = state_
+        episodes.set_postfix({'Reward': reward, 'Iteration': agent.time_step})
+        if args.no_render:
+            continue
+        env.render()
+
+    # logging
+    score_history.append(score)
+    moving_avg = sum(score_history) / len(score_history)
+    agent.add_scalar('Average Score', moving_avg, global_step=e)
+    agent.add_scalar('Episode Score', score, global_step=e)
+
+    # save weights @ best score
+    if moving_avg > best_score:
+        best_score = moving_avg
+        agent.save_networks()
+
+    tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, \
+        Episode Score: {score}, \
+        Average Score: {moving_avg}, \
+        Best Score: {best_score}')
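
The restored loop keeps the last args.order preprocessed frames in a deque and concatenates them along the channel axis to build the state. A minimal sketch of that stacking, assuming preprocess_img returns a (1, C, H, W) tensor; the sizes are illustrative:

    from collections import deque
    import torch

    order, channels, crop_dim = 4, 3, 96  # assumed values
    frame = torch.zeros(1, channels, crop_dim, crop_dim)  # stand-in for preprocess_img output

    # Seed the queue with `order` copies of the reset frame, as the loop does.
    state_queue = deque([frame] * order, maxlen=order)
    state = torch.cat(list(state_queue), 1).cpu().numpy()  # (1, channels * order, crop_dim, crop_dim)

    # Each step, appending the newest frame evicts the oldest one.
    state_queue.append(torch.ones(1, channels, crop_dim, crop_dim))
    state = torch.cat(list(state_queue), 1).cpu().numpy()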
20 changes: 12 additions & 8 deletions td3/networks.py
@@ -46,27 +46,28 @@ def forward(self, state):
 
 
 class ImageActor(nn.Module):
-    def __init__(self, in_channels, n_actions, hidden_dim, max_action, order, depth, multiplier, name):
+    def __init__(self, in_channels, n_actions, hidden_dim, max_action, order, depth, multiplier, img_size, name):
         super().__init__()
         self.name = name
         self.max_action = max_action
         self.order = order
+        self.min_hw = img_size // (2 ** depth)
+
         convs = []
         prev_ch, ch = in_channels, multiplier
         for i in range(depth):
             if i == depth - 1:
                 convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
                 convs.append(nn.BatchNorm2d(ch))
             else:
                 convs.append(ConvNormAct(in_channels=prev_ch, out_channels=ch, mode='down'))
                 prev_ch = ch
                 ch *= 2
         self.convs = nn.Sequential(
             *convs,
-            nn.AdaptiveAvgPool2d(1),
             nn.Flatten())
         self.fc = nn.Sequential(
-            nn.Linear(order * ch, hidden_dim),
+            nn.Linear(order * ch * self.min_hw ** 2, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, n_actions))
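
The substance of the actor change: with nn.AdaptiveAvgPool2d(1) removed, the flattened feature width depends on the input resolution, which is why img_size must now be passed in. Each of the depth stride-2 stages halves height and width, so the conv stack emits min_hw = img_size // 2 ** depth pixels per side. A sanity check of the shape arithmetic with assumed values (96-pixel crop, depth 4, 128 final channels):

    # Assumed values for illustration; the commit does not fix these.
    img_size, depth, final_ch, order = 96, 4, 128, 4

    min_hw = img_size // (2 ** depth)        # 96 // 16 = 6 after four stride-2 stages
    flat_per_frame = final_ch * min_hw ** 2  # 128 * 36 = 4608 features per frame
    fc_in = order * flat_per_frame           # 18432, vs. order * final_ch = 512 with the old pooling
    assert fc_in == 18432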
@@ -77,23 +78,25 @@ def forward(self, imgs):
 
 
 class ImageCritic(nn.Module):
-    def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, name):
+    def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, img_size, name):
         super().__init__()
         self.name = name
         self.order = order
+        self.min_hw = img_size // (2 ** depth)
+
         # constructed simple cnn
         convs = []
         prev_ch, ch = in_channels, multiplier
         for i in range(depth):
             if i == depth - 1:
                 convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
                 convs.append(nn.BatchNorm2d(ch))
             else:
                 convs.append(ConvNormAct(in_channels=prev_ch, out_channels=ch, mode='down'))
                 prev_ch = ch
                 ch *= 2
+        self.ch = ch
         self.convs = nn.Sequential(*convs)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
 
         # embed actions, concat w/ img and output critic
         self.action_head = nn.Sequential(
@@ -102,14 +105,15 @@ def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order,
             nn.Linear(hidden_dim, action_embed_dim)
         )
         self.combined_critic_head = nn.Sequential(
-            nn.Linear(ch*order + action_embed_dim, hidden_dim),
+            nn.Linear(ch * order * self.min_hw ** 2 + action_embed_dim, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, 1)
         )
 
     def forward(self, states, action):
-        img_embedding = [self.avg_pool(
-            self.convs(state)).squeeze() for state in states.chunk(self.order, 1)]
+        batch_size = states.size(0)
+        img_embedding = [self.convs(state).view(
+            batch_size, self.ch * self.min_hw * self.min_hw) for state in states.chunk(self.order, 1)]
         img_embedding = torch.cat(img_embedding, 1)
         action_embedding = self.action_head(action)
         combined_embedding = torch.cat([img_embedding, action_embedding], dim=1)
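
Two things change in the critic's forward: features keep their spatial layout (an explicit view replaces pooling), and the batch axis is preserved. The second point matters because .squeeze() on a batch of one also drops the batch dimension. A quick illustration with assumed shapes:

    import torch

    pooled = torch.zeros(1, 128, 1, 1)       # old path: AdaptiveAvgPool2d(1) output, batch size 1
    print(pooled.squeeze().shape)            # torch.Size([128]) -- batch axis lost
    feat = torch.zeros(1, 128, 6, 6)         # new path: raw conv features
    print(feat.view(1, 128 * 6 * 6).shape)   # torch.Size([1, 4608]) -- batch axis kept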
