Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
‘whl’ committed Sep 19, 2024
1 parent 8a47348 commit 10f6626
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 21 deletions.
31 changes: 22 additions & 9 deletions ding/model/template/language_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict
from typing import List, Dict, Optional
import torch
from torch import nn

Expand All @@ -18,13 +18,16 @@ class LanguageTransformer(nn.Module):
Interfaces:
``__init__``, ``forward``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True
freeze_encoder: bool = True,
hidden_dim: int = 768,
norm_embedding: bool = False
) -> None:
"""
Overview:
Expand All @@ -36,12 +39,16 @@ def __init__(
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
defaults to be ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
correspond to the model you use. For bert-base-uncased, this value is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to ``False``.
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)
in_channel = 768 if not add_linear else embedding_size
in_channel = hidden_dim if not add_linear else embedding_size
self.value_head = nn.Linear(in_channel, 1)
self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(normalized_shape=in_channel)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
Expand All @@ -51,9 +58,7 @@ def __init__(
if add_linear:
# Add a small, adjustable linear layer on top of language model tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
else:
self.linear = None

Expand All @@ -68,19 +73,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
sentence_embedding = self.norm(sentence_embedding)

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: List[str], candidate_samples: List[str], mode='compute_actor') -> Dict:
def forward(
self,
train_samples: List[str],
candidate_samples: Optional[List[str]] = None,
mode: str = 'compute_actor'
) -> Dict:
"""
Overview:
LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
Different ``mode`` will forward with different network modules to get different outputs and save computation.
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate the matching scores.
- mode (:obj:`str`): The forward mode; all valid modes are defined at the beginning of this class.
Returns:
- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
corresponding ``torch.distributions.Categorical`` object.
Expand All @@ -98,7 +111,7 @@ def forward(self, train_samples: List[str], candidate_samples: List[str], mode='
>>> scores = model(ctxt_list, cands_list)
>>> assert scores.shape == (1, 3)
"""
assert mode in ['compute_actor', 'compute_critic', 'compute_actor_critic']
assert mode in self.mode
prompt_embedding = self._calc_embedding(train_samples)

res_dict = {}
Expand Down
34 changes: 29 additions & 5 deletions ding/model/template/tests/test_language_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,33 @@ def check_model(self):
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)
21 changes: 14 additions & 7 deletions ding/policy/prompt_awr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import List, Dict, Any, Tuple
from typing import List, Dict, Any, Tuple, Union

import torch

Expand All @@ -17,6 +17,7 @@ class PromptAWRPolicy(Policy):
Overview:
Policy class of AWR (Advantage Weighted Regression) algorithm, proposed in https://arxiv.org/abs/1910.00177.
Especially, this policy is designed for training a language model policy.
In this policy, the environment's observation includes the current context and a list of optional actions (strings). The final output of the policy is a set of these optional actions with a size of ``shot_number``.
"""
config = dict(
# (str) Name of the registered RL policy (refer to the "register_policy" function).
Expand All @@ -31,6 +32,8 @@ class PromptAWRPolicy(Policy):
priority_IS_weight=False,
# (str) Type of action space used in the policy, with valid options ['discrete', 'continuous'].
action_space='discrete',
# (int) The number of actions that can be done simultaneously in one timestep.
shot_number=1,
# learn_mode configuration
learn=dict(
# (int) Number of updates per data collection. A2C requires this to be set to 1.
Expand All @@ -41,17 +44,18 @@ class PromptAWRPolicy(Policy):
learning_rate=0.001,
# (Tuple[float, float]) Coefficients used for computing running averages of gradient and its square.
betas=(0.9, 0.999),
beta=1.0,
# (float) Term added to the denominator to improve numerical stability in optimizer.
eps=1e-8,
# (float) Maximum norm for gradients.
grad_norm=0.5,
# (float) Scaling factor for value network loss relative to policy network loss.
value_weight=0.5,
# (float) Coefficient that controls the exp scale in the AWR algorithm.
beta=1.0,
# (float) Weight of entropy regularization in the loss function.
entropy_weight=0.01,
# (bool) Flag to enable normalization of advantages.
adv_norm=False,
# (Tuple[float, float]) The range of the advantage. Values outside this range will be clipped.
adv_range=(-0.5, 0.5),
# (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
# limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks
# that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
Expand Down Expand Up @@ -149,10 +153,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
for ii in range(self._cfg.shot_number):
log_prob = output['dist'].log_prob(real_act[:, ii])
policy_loss = -(log_prob * torch.exp((return_ - batch['value']) / self._cfg.learn.beta)).mean()
value_loss = ((return_ - output['value']) ** 2).mean()
adv = torch.clamp(
return_ - batch['value'], min=self._cfg.learn.norm_range[0], max=self._cfg.learn.norm_range[1]
)
policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
total_policy_loss += policy_loss
total_value_loss += value_loss
value_loss = ((return_ - output['value']) ** 2).mean()
total_value_loss += value_loss
total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
total_loss = total_entropy_loss + total_policy_loss + total_value_loss

Expand Down
2 changes: 2 additions & 0 deletions ding/policy/prompt_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class PromptPGPolicy(Policy):
on_policy=True, # for pg strictly on policy algorithm, this line should not be modified by users
# (bool) whether to use deterministic action for evaluation.
deterministic_eval=True,
# (int) The number of actions that can be done simultaneously in one timestep.
shot_number=1,
learn=dict(
# (int) the number of samples for one update.
batch_size=64,
Expand Down

0 comments on commit 10f6626

Please sign in to comment.