polish
‘whl’ committed Sep 23, 2024
1 parent 7c3913d commit b9c7df7
Showing 3 changed files with 27 additions and 18 deletions.
8 changes: 5 additions & 3 deletions ding/model/template/language_transformer.py
@@ -15,6 +15,8 @@ class LanguageTransformer(nn.Module):
"""
Overview:
The LanguageTransformer network. Download a pre-trained language model and add head on it.
In the default case, we use BERT model as the text encoder, whose bi-directional character is good
for obtaining the embedding of the whole sentence.
Interfaces:
``__init__``, ``forward``
"""
@@ -35,12 +35,12 @@ def __init__(
Arguments:
    - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
    - add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model, defaults to be \
-       ``False``.
+       ``False``.
    - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
    - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
-       defaults to be ``True``.
+       defaults to be ``True``.
    - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
-       correspond to the model you use. For bert-base-uncased, this value is 768.
+       correspond to the model you use. For bert-base-uncased, this value is 768.
    - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to be ``False``.
"""
super().__init__()
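For orientation, here is a minimal usage sketch of the encoder described above, built only from the constructor arguments documented in this docstring. The import path follows the file's location in this diff; the commented forward call mirrors the `compute_actor_critic` invocation that appears in `prompt_awr.py` below. Treat all of it as an illustration, not as the module's authoritative API.

```python
# Minimal sketch (not part of this commit): instantiating the BERT-based encoder with the
# constructor arguments documented above. All values are illustrative.
from ding.model.template.language_transformer import LanguageTransformer

model = LanguageTransformer(
    model_name="bert-base-uncased",  # HuggingFace base model to download
    add_linear=True,                 # add a linear head on top of the encoder
    embedding_size=128,              # output size of the added linear head
    freeze_encoder=True,             # keep the encoder weights fixed during training
    hidden_dim=768,                  # must match the encoder; 768 for bert-base-uncased
    norm_embedding=False,            # optionally normalize the output embeddings
)

# The AWR policy below calls the model roughly like this (question + candidate prompts):
# output = model(train_samples, candidate_samples, mode='compute_actor_critic')
```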
27 changes: 17 additions & 10 deletions ding/policy/prompt_awr.py
@@ -54,7 +54,7 @@ class PromptAWRPolicy(Policy):
# (float) Coefficient that controls the exp scale in the AWR algorithm.
beta=1.0,
# (float) Weight of entropy regularization in the loss function.
-entropy_weight=0.01,
+entropy_weight=0.001,
# (Tuple[float, float]) The allowed range of the advantage; values outside this range are clipped.
adv_range=(-0.5, 0.5),
# (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
@@ -82,7 +82,7 @@ class PromptAWRPolicy(Policy):
def default_model(self) -> Tuple[str, List[str]]:
    """
    Overview:
-       Returns the default model configuration used by the A2C algorithm. ``__init__`` method will \
+       Returns the default model configuration used by the AWR algorithm. ``__init__`` method will \
        automatically call this method to get the default model setting and create model.
    Returns:
@@ -94,7 +94,7 @@ def default_model(self) -> Tuple[str, List[str]]:
def _init_learn(self) -> None:
    """
    Overview:
-       Initialize the learn mode of policy, including related attributes and modules. For A2C, it mainly \
+       Initialize the learn mode of policy, including related attributes and modules. For AWR, it mainly \
        contains optimizer, algorithm-specific arguments such as value_weight, entropy_weight, adv_norm
        and grad_norm, and main model. \
        This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
@@ -141,26 +141,33 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:

# Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
-for ii in range(len(cand_samples)):
-    cand_samples[ii] = cand_samples[ii][0]
+for cand_n in range(len(cand_samples)):
+    cand_samples[cand_n] = cand_samples[cand_n][0]
output = self._learn_model.forward(train_samples, cand_samples, mode='compute_actor_critic')
return_ = batch['return']

-# calculate PG loss
-real_act = batch['action'] # shape: (B, shot_number)
+# Calculate AWR loss
+real_act = batch['action']
+
+# Ensure the shape of real_act is: (B, shot_number)
if len(real_act.shape) == 1:
    real_act = real_act.unsqueeze(-1)
-# Calculate loss.
+
+# Calculate different parts of loss.
total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
-for ii in range(self._cfg.shot_number):
-    log_prob = output['dist'].log_prob(real_act[:, ii])
+for shot_n in range(self._cfg.shot_number):
+    log_prob = output['dist'].log_prob(real_act[:, shot_n])
    # Clamp the adv for better stability.
    adv = torch.clamp(
        return_ - batch['value'], min=self._cfg.learn.norm_range[0], max=self._cfg.learn.norm_range[1]
    )
    # The policy loss for AWR algorithm.
    policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
    total_policy_loss += policy_loss
    # The value loss for AWR algorithm.
    value_loss = ((return_ - output['value']) ** 2).mean()
    total_value_loss += value_loss
    # The entropy loss for AWR algorithm.
    total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
total_loss = total_entropy_loss + total_policy_loss + total_value_loss

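The `_forward_learn` hunk above contains the heart of the update. As a reference, the per-shot loss it accumulates can be distilled into a small standalone function; the helper name and tensor arguments are illustrative, while the clamping, the `exp(adv / beta)` weighting, and the value regression mirror the diff and the `adv_range` / `beta` defaults shown earlier.

```python
# Standalone sketch of the AWR loss terms built in the hunk above (illustrative helper, not repo code).
from typing import Tuple

import torch


def awr_losses(
    log_prob: torch.Tensor,         # log pi(a|s) of the chosen prompt, shape (B,)
    value_collected: torch.Tensor,  # value recorded at collect time (batch['value']), shape (B,)
    value_pred: torch.Tensor,       # current critic prediction (output['value']), shape (B,)
    return_: torch.Tensor,          # observed return, shape (B,)
    beta: float = 1.0,
    adv_range: Tuple[float, float] = (-0.5, 0.5),
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Advantage estimate, clamped for stability as in the diff.
    adv = torch.clamp(return_ - value_collected, min=adv_range[0], max=adv_range[1])
    # Advantage-weighted regression: log-likelihood weighted by exp(adv / beta).
    policy_loss = -(log_prob * torch.exp(adv / beta)).mean()
    # Critic regression towards the observed return.
    value_loss = ((return_ - value_pred) ** 2).mean()
    return policy_loss, value_loss
```

In the policy these two terms are summed over `shot_number` selected prompts and combined with an entropy bonus weighted by `entropy_weight`, whose default this commit lowers from 0.01 to 0.001.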
10 changes: 5 additions & 5 deletions dizoo/tabmwp/config/tabmwp_awr_config.py
@@ -1,7 +1,7 @@
from easydict import EasyDict

-tabmwp_prompt_pg_config = dict(
-    exp_name='tabmwp_prompt_pg_seed0',
+tabmwp_prompt_awr_config = dict(
+    exp_name='tabmwp_prompt_awr_seed0',
    env=dict(
        collector_env_num=1,
        evaluator_env_num=1,
@@ -48,9 +48,9 @@
        eval=dict(evaluator=dict(eval_freq=500, )),
    ),
)
-main_config = EasyDict(tabmwp_prompt_pg_config)
+main_config = EasyDict(tabmwp_prompt_awr_config)

-tabmwp_prompt_pg_config = dict(
+tabmwp_prompt_awr_config = dict(
    env=dict(
        type='tabmwp',
        import_names=['dizoo.tabmwp.envs.tabmwp_env'],
@@ -59,7 +59,7 @@
    policy=dict(type='prompt_awr'),
    replay_buffer=dict(type='naive'),
)
-create_config = EasyDict(tabmwp_prompt_pg_config)
+create_config = EasyDict(tabmwp_prompt_awr_config)

if __name__ == '__main__':
    from ding.entry import serial_pipeline_onpolicy
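The body of the `__main__` block is cut off by the hunk boundary. Based on the `serial_pipeline_onpolicy` import shown above, a typical DI-engine launch for such a config looks like the following; treat it as an assumption about usage, not as the file's literal contents.

```python
# Hypothetical launch snippet; the actual body of the __main__ block is not shown in this diff.
# It would sit at the bottom of the config file, after main_config and create_config are defined.
from ding.entry import serial_pipeline_onpolicy

if __name__ == '__main__':
    serial_pipeline_onpolicy((main_config, create_config), seed=0)
```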
