diff --git a/td3/networks.py b/td3/networks.py
index ce08129..c24fa14 100644
--- a/td3/networks.py
+++ b/td3/networks.py
@@ -50,10 +50,10 @@ def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier,
         super().__init__()
         self.name = name
         self.max_action = max_action
+        self.order = order
         convs = []
-        prev_ch = in_channels*order
-        ch = multiplier
+        prev_ch, ch = in_channels, multiplier
         for i in range(depth):
             if i == depth - 1:
                 convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
@@ -61,27 +61,27 @@ def __init__(self, in_channels, n_actions, max_action, order, depth, multiplier,
             convs.append(ConvNormAct(in_channels=prev_ch, out_channels=ch, mode='down'))
             prev_ch = ch
             ch *= 2
-        self.actor = nn.Sequential(
+        self.convs = nn.Sequential(
             *convs,
             nn.AdaptiveAvgPool2d(1),
-            nn.Flatten(),
-            nn.Linear(ch, n_actions),
-            nn.Tanh()
-        )
+            nn.Flatten())
+        self.fc = nn.Linear(order*ch, n_actions)
 
     def forward(self, imgs):
-        return self.actor(imgs) * self.max_action
+        img_feature = [self.convs(imgs[:, i*self.order:(i+1)*self.order, :, :]) for i in range(self.order)]
+        img_feature = torch.cat(img_feature, 1)
+        return torch.tanh(self.fc(img_feature)) * self.max_action
 
 
 class ImageCritic(nn.Module):
     def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order, depth, multiplier, name):
         super().__init__()
         self.name = name
+        self.order = order
         # constructed simple cnn
         convs = []
-        prev_ch = in_channels*order
-        ch = multiplier
+        prev_ch, ch = in_channels, multiplier
         for i in range(depth):
             if i == depth - 1:
                 convs.append(nn.Conv2d(in_channels=prev_ch, out_channels=ch, kernel_size=4, padding=1, stride=2))
@@ -99,13 +99,15 @@ def __init__(self, in_channels, n_actions, hidden_dim, action_embed_dim, order,
             nn.Linear(hidden_dim, action_embed_dim)
         )
         self.combined_critic_head = nn.Sequential(
-            nn.Linear(ch + action_embed_dim, hidden_dim),
+            nn.Linear(ch*order + action_embed_dim, hidden_dim),
             nn.ReLU(),
             nn.Linear(hidden_dim, 1)
         )
 
     def forward(self, state, action):
-        img_embedding = self.avg_pool(self.convs(state)).squeeze()
+        img_embedding = [self.avg_pool(self.convs(
+            state[:, i*self.order:(i+1)*self.order, :, :])).squeeze() for i in range(self.order)]
+        img_embedding = torch.cat(img_embedding, 1)
         action_embedding = self.action_head(action)
         combined_embedding = torch.cat([img_embedding, action_embedding], dim=1)
         return self.combined_critic_head(combined_embedding)
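
For reference, a minimal standalone sketch of the pattern this diff moves to: instead of feeding all stacked frames into one wide conv stack, each frame's channel slice is pushed through a single shared encoder and the per-frame features are concatenated before the output head. Names here (TinyEncoder, frames, channels_per_frame) are illustrative, not from the repo. One caveat worth noting: the diff slices with self.order as the chunk width, so the split lines up with the encoder's in_channels only when in_channels == order; the sketch slices by channels-per-frame explicitly.

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Stand-in for the repo's per-frame conv stack (hypothetical, for illustration)."""
    def __init__(self, channels_per_frame, feat_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels_per_frame, feat_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)

frames, channels_per_frame, feat_dim = 4, 3, 32
encoder = TinyEncoder(channels_per_frame, feat_dim)
obs = torch.randn(2, frames * channels_per_frame, 84, 84)  # batch of 2 stacked observations

# Slice out each frame's channels, encode every frame with the *same* weights,
# then concatenate the per-frame features (frames * feat_dim) for the head.
chunks = [obs[:, i * channels_per_frame:(i + 1) * channels_per_frame] for i in range(frames)]
features = torch.cat([encoder(c) for c in chunks], dim=1)
print(features.shape)  # torch.Size([2, 128])

The upside of this factoring is that the conv weights are shared across frames, so the encoder's parameter count no longer grows with the stack depth (order); only the final linear layer widens, which is what the order*ch input sizes in the diff account for.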