Skip to content

Commit

Permalink
modified networks to be concatenate features
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 23, 2020
1 parent d6c54ac commit a1f4087
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions td3/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,38 +50,38 @@ def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier,
super().__init__()
self.name = name
self.max_action = max_action
self.order = order

convs = []
prev_ch = in_channels*order
ch = multiplier
prev_ch, ch = in_channels, multiplier
for i in range(depth):
if i == depth - 1:
convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
else:
convs.append(ConvNormAct(in_channels=prev_ch, out_channels=ch, mode='down'))
prev_ch = ch
ch *= 2
self.actor = nn.Sequential(
self.convs = nn.Sequential(
*convs,
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(ch, n_actions),
nn.Tanh()
)
nn.Flatten())
self.fc = nn.Linear(order*ch, n_actions)

def forward(self, imgs):
return self.actor(imgs) * self.max_action
img_feature = [self.convs(imgs[:, i*self.order:(i+1)*self.order, :, :]) for i in range(self.order)]
img_feature = torch.cat(img_feature, 1)
return torch.tanh(self.fc(img_feature)) * self.max_action


class ImageCritic(nn.Module):
def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, name):
super().__init__()
self.name = name
self.order = order

# constructed simple cnn
convs = []
prev_ch = in_channels*order
ch = multiplier
prev_ch, ch = in_channels, multiplier
for i in range(depth):
if i == depth - 1:
convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
Expand All @@ -99,13 +99,15 @@ def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order,
nn.Linear(hidden_dim, action_embed_dim)
)
self.combined_critic_head = nn.Sequential(
nn.Linear(ch + action_embed_dim, hidden_dim),
nn.Linear(ch*order + action_embed_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

def forward(self, state, action):
img_embedding = self.avg_pool(self.convs(state)).squeeze()
img_embedding = [self.avg_pool(self.convs(
state[:, i*self.order:(i+1)*self.order, :, :])).squeeze() for i in range(self.order)]
img_embedding = torch.cat(img_embedding, 1)
action_embedding = self.action_head(action)
combined_embedding = torch.cat([img_embedding, action_embedding], dim=1)
return self.combined_critic_head(combined_embedding)

0 comments on commit a1f4087

Please sign in to comment.