From 10f66265a0fd716ea3176194de7900eb25fb635f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 19 Sep 2024 10:13:07 +0800
Subject: [PATCH] polish

---
 ding/model/template/language_transformer.py  | 31 ++++++++++++-----
 .../tests/test_language_transformer.py       | 34 ++++++++++++++++---
 ding/policy/prompt_awr.py                    | 21 ++++++++----
 ding/policy/prompt_pg.py                     |  2 ++
 4 files changed, 67 insertions(+), 21 deletions(-)

diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py
index c758ffd753..38e68b318a 100644
--- a/ding/model/template/language_transformer.py
+++ b/ding/model/template/language_transformer.py
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import List, Dict, Optional
 
 import torch
 from torch import nn
@@ -18,13 +18,16 @@ class LanguageTransformer(nn.Module):
     Interfaces:
         ``__init__``, ``forward``
     """
+    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
 
     def __init__(
             self,
             model_name: str = "bert-base-uncased",
             add_linear: bool = False,
             embedding_size: int = 128,
-            freeze_encoder: bool = True
+            freeze_encoder: bool = True,
+            hidden_dim: int = 768,
+            norm_embedding: bool = False
     ) -> None:
         """
         Overview:
@@ -36,12 +39,16 @@ def __init__(
             - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
             - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
                 defaults to be ``True``.
+            - hidden_dim (:obj:`int`): The hidden dimension of the encoder model (e.g. BERT). This value should \
+                match the model you use; for bert-base-uncased, it is 768.
+            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to ``False``.
         """
         super().__init__()
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForTokenClassification.from_pretrained(model_name)
-        in_channel = 768 if not add_linear else embedding_size
+        in_channel = hidden_dim if not add_linear else embedding_size
         self.value_head = nn.Linear(in_channel, 1)
+        self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(normalized_shape=hidden_dim)
 
         # Freeze transformer encoder and only train the linear layer
         if freeze_encoder:
@@ -51,9 +58,7 @@ def __init__(
         if add_linear:
             # Add a small, adjustable linear layer on top of language model tuned through RL
             self.embedding_size = embedding_size
-            self.linear = nn.Linear(
-                self.model.config.hidden_size, embedding_size
-            )  # 768 for bert-base-uncased, distilbert-base-uncased
+            self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
         else:
             self.linear = None
 
@@ -68,19 +73,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
         last_hidden_states = output.hidden_states[-1]
         # Get [CLS] hidden states
         sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
+        sentence_embedding = self.norm(sentence_embedding)
 
         if self.linear:
             sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size
 
         return sentence_embedding
 
-    def forward(self, train_samples: List[str], candidate_samples: List[str], mode='compute_actor') -> Dict:
+    def forward(
+            self,
+            train_samples: List[str],
+            candidate_samples: Optional[List[str]] = None,
+            mode: str = 'compute_actor'
+    ) -> Dict:
         """
         Overview:
             LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
+            Different ``mode`` values use different network modules to produce different outputs and save computation.
         Arguments:
             - train_samples (:obj:`List[str]`): One list of strings.
-            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
+            - candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate the matching scores.
+            - mode (:obj:`str`): The forward mode; all supported modes are defined at the beginning of this class.
         Returns:
             - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                 corresponding ``torch.distributions.Categorical`` object.
@@ -98,7 +111,7 @@ def forward(self, train_samples: List[str], candidate_samples: List[str], mode='
             >>> scores = model(ctxt_list, cands_list)
             >>> assert scores.shape == (1, 3)
         """
-        assert mode in ['compute_actor', 'compute_critic', 'compute_actor_critic']
+        assert mode in self.mode
         prompt_embedding = self._calc_embedding(train_samples)
 
         res_dict = {}
diff --git a/ding/model/template/tests/test_language_transformer.py b/ding/model/template/tests/test_language_transformer.py
index 40095c2ab2..eaaaae5a84 100644
--- a/ding/model/template/tests/test_language_transformer.py
+++ b/ding/model/template/tests/test_language_transformer.py
@@ -17,9 +17,33 @@ def check_model(self):
         cands_list = [problems[pid] for pid in cand_pids]
 
         model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)
 
-        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)
+
+        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)
+
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)
diff --git a/ding/policy/prompt_awr.py b/ding/policy/prompt_awr.py
index a9f2438a17..d1471c7426 100644
--- a/ding/policy/prompt_awr.py
+++ b/ding/policy/prompt_awr.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import List, Dict, Any, Tuple
+from typing import List, Dict, Any, Tuple, Union
 
 import torch
 
@@ -17,6 +17,7 @@ class PromptAWRPolicy(Policy):
     Overview:
         Policy class of AWR (Advantage Weighted Regression) algorithm, proposed in
         https://arxiv.org/abs/1910.00177. Especially, this policy is designed for training a language model policy.
+        In this policy, the environment's observation consists of the current context and a list of candidate actions (strings); the policy outputs ``shot_number`` of these candidate actions.
     """
     config = dict(
         # (str) Name of the registered RL policy (refer to the "register_policy" function).
@@ -31,6 +32,8 @@ class PromptAWRPolicy(Policy):
         priority_IS_weight=False,
         # (str) Type of action space used in the policy, with valid options ['discrete', 'continuous'].
         action_space='discrete',
+        # (int) The number of actions selected simultaneously in one timestep.
+        shot_number=1,
         # learn_mode configuration
         learn=dict(
             # (int) Number of updates per data collection. A2C requires this to be set to 1.
@@ -41,17 +44,18 @@ class PromptAWRPolicy(Policy):
             learning_rate=0.001,
             # (Tuple[float, float]) Coefficients used for computing running averages of gradient and its square.
             betas=(0.9, 0.999),
-            beta=1.0,
             # (float) Term added to the denominator to improve numerical stability in optimizer.
             eps=1e-8,
             # (float) Maximum norm for gradients.
             grad_norm=0.5,
             # (float) Scaling factor for value network loss relative to policy network loss.
             value_weight=0.5,
+            # (float) Temperature coefficient that controls the exponential scale in the AWR weighting.
+            beta=1.0,
             # (float) Weight of entropy regularization in the loss function.
             entropy_weight=0.01,
-            # (bool) Flag to enable normalization of advantages.
-            adv_norm=False,
+            # (Tuple[float, float]) The clipping range of the advantage. Values outside this range are clipped.
+            adv_range=(-0.5, 0.5),
             # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
             # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
             # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
@@ -149,10 +153,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
         for ii in range(self._cfg.shot_number):
             log_prob = output['dist'].log_prob(real_act[:, ii])
-            policy_loss = -(log_prob * torch.exp((return_ - batch['value']) / self._cfg.learn.beta)).mean()
-            value_loss = ((return_ - output['value']) ** 2).mean()
+            adv = torch.clamp(
+                return_ - batch['value'], min=self._cfg.learn.adv_range[0], max=self._cfg.learn.adv_range[1]
+            )
+            policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
             total_policy_loss += policy_loss
-            total_value_loss += value_loss
+            value_loss = ((return_ - output['value']) ** 2).mean()
+            total_value_loss += value_loss
         total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
         total_loss = total_entropy_loss + total_policy_loss + total_value_loss
diff --git a/ding/policy/prompt_pg.py b/ding/policy/prompt_pg.py
index a5e3bfcba2..a76e0e5faf 100644
--- a/ding/policy/prompt_pg.py
+++ b/ding/policy/prompt_pg.py
@@ -26,6 +26,8 @@ class PromptPGPolicy(Policy):
         on_policy=True,  # for pg strictly on policy algorithm, this line should not be modified by users
         # (bool) whether to use deterministic action for evaluation.
        deterministic_eval=True,
+        # (int) The number of actions selected simultaneously in one timestep.
+        shot_number=1,
         learn=dict(
             # (int) the number of samples for one update.
             batch_size=64,
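
Note for reviewers: a minimal, self-contained sketch of the clipped-advantage AWR weighting that the ``_forward_learn`` hunk above computes. The tensor names, toy shapes and values are illustrative assumptions for this note, not identifiers taken from the repository.

    import torch

    # Toy batch: 4 transitions, shot_number = 1 (assumed values, for illustration only).
    log_prob = torch.randn(4)     # log-probability of the chosen action from the policy head
    return_ = torch.randn(4)      # Monte-Carlo return stored in the collected batch
    value_old = torch.randn(4)    # critic value stored in the batch, used as the baseline
    value_pred = torch.randn(4)   # current critic prediction, regressed towards the return

    beta = 1.0                    # temperature of the exponential advantage weighting
    adv_range = (-0.5, 0.5)       # clipping range, mirroring the new ``adv_range`` config field

    # Clip the advantage so that exp(adv / beta) cannot blow up for large positive returns.
    adv = torch.clamp(return_ - value_old, min=adv_range[0], max=adv_range[1])

    # AWR policy loss: advantage-weighted negative log-likelihood.
    policy_loss = -(log_prob * torch.exp(adv / beta)).mean()

    # Value loss: plain regression of the critic towards the return.
    value_loss = ((return_ - value_pred) ** 2).mean()

A second tiny check illustrates why the ``norm_embedding`` LayerNorm is sized with ``hidden_dim`` rather than ``in_channel``: in ``_calc_embedding`` the normalization is applied to the [CLS] hidden state before the optional linear head, so its feature dimension is always the encoder's hidden size. The dimensions below are assumed for illustration.

    import torch
    from torch import nn

    hidden_dim, embedding_size = 768, 256            # assumed dims mirroring the constructor defaults
    norm = nn.LayerNorm(normalized_shape=hidden_dim)  # normalizes the raw encoder output
    linear = nn.Linear(hidden_dim, embedding_size)    # optional head added when add_linear=True

    cls_embedding = torch.randn(1, hidden_dim)        # stand-in for the [CLS] hidden state
    out = linear(norm(cls_embedding))
    assert out.shape == (1, embedding_size)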