From c02d0482ad0e3c1f2ac27eea6115afdf92af4e73 Mon Sep 17 00:00:00 2001
From: Swain
Date: Thu, 17 Mar 2022 12:01:59 +0800
Subject: [PATCH] feature(nyz): add stochastic dueling network (#234)

* feature(nyz): add stochastic dueling network

* polish(nyz): polish sdn and add unittest
---
 ding/model/common/head.py            | 139 ++++++++++++++++++++++++---
 ding/model/common/tests/test_head.py |  19 +++-
 2 files changed, 146 insertions(+), 12 deletions(-)

diff --git a/ding/model/common/head.py b/ding/model/common/head.py
index 2d1adf7fa1..2ad3c5e90f 100644
--- a/ding/model/common/head.py
+++ b/ding/model/common/head.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributions import Normal, Independent
 
 from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP
 from ding.rl_utils import beta_function_map
@@ -25,7 +26,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DiscreteHead``
             - output_size (:obj:`int`): The number of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - activation (:obj:`nn.Module`):
@@ -95,7 +96,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DistributionHead``
             - output_size (:obj:`int`): The num of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - activation (:obj:`nn.Module`):
@@ -176,7 +177,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``RainbowHead``
             - output_size (:obj:`int`): The num of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - activation (:obj:`nn.Module`):
@@ -268,7 +269,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``QRDQNHead``
             - output_size (:obj:`int`): The num of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - activation (:obj:`nn.Module`):
@@ -348,7 +349,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``QuantileHead``
             - output_size (:obj:`int`): The num of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - activation (:obj:`nn.Module`):
@@ -532,8 +533,123 @@ def forward(self, x: torch.Tensor) -> Dict:
         """
         a = self.A(x)
         v = self.V(x)
-        logit = a - a.mean(dim=-1, keepdim=True) + v
-        return {'logit': logit}
+        q_value = a - a.mean(dim=-1, keepdim=True) + v
+        return {'logit': q_value}
+
+
+class StochasticDuelingHead(nn.Module):
+
+    def __init__(
+            self,
+            hidden_size: int,
+            action_shape: int,
+            layer_num: int = 1,
+            a_layer_num: Optional[int] = None,
+            v_layer_num: Optional[int] = None,
+            activation: Optional[nn.Module] = nn.ReLU(),
+            norm_type: Optional[str] = None,
+            noise: Optional[bool] = False,
+            last_tanh: Optional[bool] = True,
+    ) -> None:
+        """
+        Overview:
+            The Stochastic Dueling Network proposed in the ACER paper (arxiv 1611.01224), a dueling network \
+            architecture for continuous action spaces. Initialize the head according to the input arguments.
+        Arguments:
+            - hidden_size (:obj:`int`): The num of observation embedding size.
+            - action_shape (:obj:`int`): The dimension of the continuous action space, an integer value.
+            - a_layer_num (:obj:`int`): The num of layers used in the network to compute action output (defaults to ``layer_num``).
+            - v_layer_num (:obj:`int`): The num of layers used in the network to compute value output (defaults to ``layer_num``).
+            - activation (:obj:`nn.Module`): The type of activation function to use in ``MLP`` after ``layer_fn``, \
+              if ``None`` then default set to ``nn.ReLU()``
+            - norm_type (:obj:`str`): The type of normalization to use, see ``ding.torch_utils.fc_block`` for \
+              more details.
+            - noise (:obj:`bool`): Whether to use noisy ``fc_block`` for more exploration.
+        """
+        super(StochasticDuelingHead, self).__init__()
+        if a_layer_num is None:
+            a_layer_num = layer_num
+        if v_layer_num is None:
+            v_layer_num = layer_num
+        layer = NoiseLinearLayer if noise else nn.Linear
+        block = noise_block if noise else fc_block
+        self.A = nn.Sequential(
+            MLP(
+                hidden_size + action_shape,
+                hidden_size,
+                hidden_size,
+                a_layer_num,
+                layer_fn=layer,
+                activation=activation,
+                norm_type=norm_type
+            ), block(hidden_size, 1)
+        )
+        self.V = nn.Sequential(
+            MLP(
+                hidden_size,
+                hidden_size,
+                hidden_size,
+                v_layer_num,
+                layer_fn=layer,
+                activation=activation,
+                norm_type=norm_type
+            ), block(hidden_size, 1)
+        )
+        if last_tanh:
+            self.tanh = nn.Tanh()
+        else:
+            self.tanh = None
+
+    def forward(
+            self,
+            s: torch.Tensor,
+            a: torch.Tensor,
+            mu: torch.Tensor,
+            sigma: torch.Tensor,
+            sample_size: int = 10,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Overview:
+            Use the encoded observation, the behaviour action and actions sampled from the (mu, sigma) output \
+            of the actor head at the current timestep to compute the dueling Q-value, i.e. a continuous dueling head.
+        Arguments:
+            - s (:obj:`torch.Tensor`): The encoded embedding state tensor, determined with given ``hidden_size``, \
+              i.e. shape is ``(B, N=hidden_size)``.
+            - a (:obj:`torch.Tensor`): The original continuous behaviour action, determined with ``action_size``, \
+              i.e. shape is ``(B, N=action_size)``.
+            - mu (:obj:`torch.Tensor`):
+                The mu of the gaussian reparameterization output of the actor head at the current timestep, size ``(B, action_size)``.
+            - sigma (:obj:`torch.Tensor`):
+                The sigma of the gaussian reparameterization output of the actor head at the current timestep, size ``(B, action_size)``.
+            - sample_size (:obj:`int`): The number of samples drawn for the continuous action when computing the Q value.
+        Returns:
+            - outputs (:obj:`Dict[str, torch.Tensor]`): Output dict data, including the ``q_value`` and ``v_value`` \
+              tensors, both of shape ``(B, 1)``.
+        """
+
+        batch_size = s.shape[0]  # batch_size, or batch_size * T if a time dimension is folded into the batch
+        hidden_size = s.shape[1]
+        action_size = a.shape[1]
+        state_cat_action = torch.cat((s, a), dim=1)  # size (B, hidden_size + action_size)
+        a_value = self.A(state_cat_action)  # size (B, 1)
+        v_value = self.V(s)  # size (B, 1)
+        # size (B, sample_size, hidden_size)
+        expand_s = (torch.unsqueeze(s, 1)).expand((batch_size, sample_size, hidden_size))
+
+        # use rsample rather than sample so that gradients can propagate back to mu and sigma
+        dist = Independent(Normal(mu, sigma), 1)
+        action_sample = dist.rsample(sample_shape=(sample_size, ))
+        if self.tanh:
+            action_sample = self.tanh(action_sample)
+        # (sample_size, B, action_size) -> (B, sample_size, action_size)
+        action_sample = action_sample.permute(1, 0, 2)
+
+        # size (B, sample_size, hidden_size + action_size)
+        state_cat_action_sample = torch.cat((expand_s, action_sample), dim=-1)
+        a_val_sample = self.A(state_cat_action_sample)  # size (B, sample_size, 1)
+        q_value = v_value + a_value - a_val_sample.mean(dim=1)  # size (B, 1)
+
+        return {'q_value': q_value, 'v_value': v_value}
 
 
 class RegressionHead(nn.Module):
@@ -551,7 +667,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``RegressionHead``
             - output_size (:obj:`int`): The num of output
             - final_tanh (:obj:`Optional[bool]`): Whether a final tanh layer is needed
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
@@ -617,7 +733,7 @@ def __init__(
         Overview:
             Init the Head according to arguments.
         Arguments:
-            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``DuelingHead``
+            - hidden_size (:obj:`int`): The ``hidden_size`` used before connected to ``ReparameterizationHead``
            - output_size (:obj:`int`): The num of output
             - layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
             - sigma_type (:obj:`Optional[str]`): Sigma type used in ``['fixed', 'independent', 'conditioned']``
@@ -660,8 +776,8 @@ def forward(self, x: torch.Tensor) -> Dict:
         Overview:
             Run ``MLP`` with ``ReparameterizationHead`` setups and return the result prediction dictionary.
         Necessary Keys:
-            - mu (:obj:`torch.Tensor`) Tensor of cells of updated mu values, with same size as ``x``.
-            - sigma (:obj:`torch.Tensor`) Tensor of cells of updated sigma values, with same size as ``x``.
+            - mu (:obj:`torch.Tensor`): Tensor of cells of updated mu values of size ``(B, action_size)``
+            - sigma (:obj:`torch.Tensor`): Tensor of cells of updated sigma values of size ``(B, action_size)``
         Examples:
             >>> head = ReparameterizationHead(64, 64, sigma_type='fixed')
             >>> inputs = torch.randn(4, 64)
@@ -740,6 +856,7 @@ def forward(self, x: torch.Tensor) -> Dict:
     # discrete
     'discrete': DiscreteHead,
     'dueling': DuelingHead,
+    'sdn': StochasticDuelingHead,
     'distribution': DistributionHead,
     'rainbow': RainbowHead,
     'qrdqn': QRDQNHead,
diff --git a/ding/model/common/tests/test_head.py b/ding/model/common/tests/test_head.py
index 22181dc698..3c36640047 100644
--- a/ding/model/common/tests/test_head.py
+++ b/ding/model/common/tests/test_head.py
@@ -2,7 +2,7 @@
 import numpy as np
 import pytest
 
-from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead
+from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead
 from ding.torch_utils import is_differentiable
 
 B = 4
@@ -67,3 +67,20 @@ def test_multi_head(self):
         self.output_check(head, outputs['logit'])
         for i, d in enumerate(output_size_list):
             assert outputs['logit'][i].shape == (B, d)
+
+    @pytest.mark.tmp
+    def test_stochastic_dueling(self):
+        obs = torch.randn(B, embedding_dim)
+        behaviour_action = torch.randn(B, action_shape).clamp(-1, 1)
+        mu = torch.randn(B, action_shape).requires_grad_(True)
+        sigma = torch.rand(B, action_shape).requires_grad_(True)
+        model = StochasticDuelingHead(embedding_dim, action_shape, 3, 3)
+
+        # backprop through q_value must reach mu and sigma via the rsample-d actions
+        assert mu.grad is None and sigma.grad is None
+        outputs = model(obs, behaviour_action, mu, sigma)
+        self.output_check(model, outputs['q_value'])
+        assert isinstance(mu.grad, torch.Tensor)
+        assert isinstance(sigma.grad, torch.Tensor)
+        assert outputs['q_value'].shape == (B, 1)
+        assert outputs['v_value'].shape == (B, 1)
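
Reviewer note, not part of the patch: the new head estimates Q(s, a) ~= V(s) + A(s, a) - (1/n) * sum_i A(s, u_i), where the u_i are drawn from the actor's N(mu, sigma) via rsample() so that gradients reach mu and sigma. Below is a minimal sketch of how the head composes with ReparameterizationHead. The nn.Linear encoder, all sizes, and the use of tanh(mu) as a stand-in behaviour action (in ACER it would be replayed from the buffer) are illustrative assumptions, not ding API.

import torch
import torch.nn as nn

from ding.model.common.head import ReparameterizationHead, StochasticDuelingHead

obs_dim, hidden_size, action_dim, B = 16, 64, 6, 4
encoder = nn.Linear(obs_dim, hidden_size)  # stand-in for a real observation encoder
actor = ReparameterizationHead(hidden_size, action_dim, sigma_type='fixed')
critic = StochasticDuelingHead(hidden_size, action_dim, layer_num=2)

obs = torch.randn(B, obs_dim)
s = encoder(obs)  # encoded state, shape (B, hidden_size)
dist = actor(s)  # {'mu': (B, action_dim), 'sigma': (B, action_dim)}
# illustrative behaviour action; in ACER this comes from the replay buffer
behaviour_action = torch.tanh(dist['mu']).detach()
out = critic(s, behaviour_action, dist['mu'], dist['sigma'], sample_size=16)
assert out['q_value'].shape == (B, 1) and out['v_value'].shape == (B, 1)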