Skip to content

Commit

Permalink
feature(nyz): add encoder in MAVAC (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Aug 14, 2024
1 parent a54d475 commit ae3ddc6
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 34 deletions.
78 changes: 44 additions & 34 deletions ding/model/template/mavac.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, Optional
from typing import Union, Dict, Tuple, Optional
import torch
import torch.nn as nn

Expand Down Expand Up @@ -35,6 +35,7 @@ def __init__(
norm_type: Optional[str] = None,
sigma_type: Optional[str] = 'independent',
bound_type: Optional[str] = None,
encoder: Optional[Tuple[torch.nn.Module, torch.nn.Module]] = None,
) -> None:
"""
Overview:
Expand Down Expand Up @@ -66,6 +67,9 @@ def __init__(
to ``independent``, which means state-independent sigma parameters.
- bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
to ``None``, which means no bound.
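- encoder (:obj:`Optional[Tuple[torch.nn.Module, torch.nn.Module]]`): A pair of encoder modules in the \
order ``(actor_encoder, critic_encoder)``, defaults to ``None``, which means a default Linear layer \
followed by the activation is built for each. You can define your own actor and critic encoder \
modules and pass them into MAVAC to deal with different observation spaces.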
"""
super(MAVAC, self).__init__()
agent_obs_shape: int = squeeze(agent_obs_shape)
Expand All @@ -74,42 +78,38 @@ def __init__(
self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
self.action_space = action_space
# Encoder Type
# We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
# In SMAC tasks this noticeably improves performance.
# Users can change the model according to their own needs.
self.actor_encoder = nn.Identity()
self.critic_encoder = nn.Identity()
# Head Type
self.critic_head = nn.Sequential(
nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
if encoder:
self.actor_encoder, self.critic_encoder = encoder
else:
# We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
# In SMAC tasks this noticeably improves performance.
# Users can change the model according to their own needs.
self.actor_encoder = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size),
activation,
)
self.critic_encoder = nn.Sequential(
nn.Linear(global_obs_shape, critic_head_hidden_size),
activation,
)
# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)
assert self.action_space in ['discrete', 'continuous'], self.action_space
if self.action_space == 'discrete':
self.actor_head = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
DiscreteHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
activation=activation,
norm_type=norm_type
)
self.actor_head = DiscreteHead(
actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
)
elif self.action_space == 'continuous':
self.actor_head = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
)
self.actor_head = ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
)
# must use list, not nn.ModuleList
self.actor = [self.actor_encoder, self.actor_head]
Expand Down Expand Up @@ -261,7 +261,7 @@ def compute_actor_critic(self, x: Dict) -> Dict:
- value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
- logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
and M is ``agent_num``.
and M is ``agent_num``.
- value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
Examples:
Expand All @@ -275,6 +275,16 @@ def compute_actor_critic(self, x: Dict) -> Dict:
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
"""
logit = self.compute_actor(x)['logit']
value = self.compute_critic(x)['value']
x_actor = self.actor_encoder(x['agent_state'])
x_critic = self.critic_encoder(x['global_state'])

if self.action_space == 'discrete':
action_mask = x['action_mask']
x = self.actor_head(x_actor)
logit = x['logit']
logit[action_mask == 0.0] = -99999999
elif self.action_space == 'continuous':
x = self.actor_head(x_actor)
logit = x
value = self.critic_head(x_critic)['pred']
return {'logit': logit, 'value': value}
37 changes: 37 additions & 0 deletions ding/model/template/tests/test_mavac.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
import torch
import torch.nn as nn
from itertools import product

from ding.model import mavac
Expand Down Expand Up @@ -50,3 +51,39 @@ def test_vac(self, agent_obs_shape, global_obs_shape):
value = model(data, mode='compute_critic')['value']
assert value.shape == (B, agent_num)
self.output_check(model.critic, value, action_shape)

def test_vac_with_encoder(self, agent_obs_shape, global_obs_shape):
    """
    Overview:
        Check that MAVAC accepts user-supplied actor/critic encoders and that the
        ``compute_actor_critic``, ``compute_actor`` and ``compute_critic`` modes all
        produce outputs that pass ``output_check``.
    Arguments:
        - agent_obs_shape: Size of the last dimension of ``agent_state``.
        - global_obs_shape: Size of the last dimension of ``global_state``.
    """
    data = {
        'agent_state': torch.randn(B, agent_num, agent_obs_shape),
        'global_state': torch.randn(B, agent_num, global_obs_shape),
        'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
    }

    actor_size, critic_size = 128, 128
    # The pair is unpacked by MAVAC as (actor_encoder, critic_encoder), so order matters.
    # A tuple is used to match the ``encoder`` annotation of the constructor.
    encoder = (
        nn.Linear(agent_obs_shape, actor_size),
        nn.Linear(global_obs_shape, critic_size),
    )
    model = MAVAC(
        agent_obs_shape,
        global_obs_shape,
        action_shape,
        agent_num,
        encoder=encoder,
        actor_head_hidden_size=actor_size,
        critic_head_hidden_size=critic_size
    )

    # A single forward pass returns both entries; running the model twice just to
    # read two keys of the same output dict is wasted work.
    actor_critic_output = model(data, mode='compute_actor_critic')
    logit, value = actor_critic_output['logit'], actor_critic_output['value']
    assert logit.shape == (B, agent_num, action_shape)
    assert value.shape == (B, agent_num)

    outputs = value.sum() + logit.sum()
    self.output_check(model, outputs, action_shape)

    # Clear gradients so the next check starts from a clean state.
    for p in model.parameters():
        p.grad = None
    logit = model(data, mode='compute_actor')['logit']
    self.output_check(model.actor, logit, model.action_shape)

    for p in model.parameters():
        p.grad = None
    value = model(data, mode='compute_critic')['value']
    assert value.shape == (B, agent_num)
    self.output_check(model.critic, value, action_shape)

0 comments on commit ae3ddc6

Please sign in to comment.