Skip to content

Commit

Permalink
feature(whl): add AWR algorithm (#828)
Browse files Browse the repository at this point in the history
* init commit

* reformat

* polish

* polish readme

* reformat

* polish
  • Loading branch information
kxzxvbk authored Sep 26, 2024
1 parent 6ae1396 commit 3898386
Show file tree
Hide file tree
Showing 8 changed files with 421 additions and 18 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
- Exploration algorithms: HER, RND, ICM, NGU
- LLM + RL Algorithms: PPO-max, DPO, PromptPG
- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
- Other algorithms: such as PER, PLR, PCGrad
- MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
- Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
Expand Down Expand Up @@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

</details>

Expand Down
51 changes: 39 additions & 12 deletions ding/model/template/language_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict
from typing import List, Dict, Optional
import torch
from torch import nn

Expand All @@ -15,31 +15,44 @@ class LanguageTransformer(nn.Module):
"""
Overview:
The LanguageTransformer network. Download a pre-trained language model and add head on it.
In the default case, we use BERT model as the text encoder, whose bi-directional character is good
for obtaining the embedding of the whole sentence.
Interfaces:
``__init__``, ``forward``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True
freeze_encoder: bool = True,
hidden_dim: int = 768,
norm_embedding: bool = False
) -> None:
"""
Overview:
Init the LanguageTransformer Model according to input arguments.
Arguments:
- model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
- add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
``False``.
``False``.
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
defaults to be ``True``.
defaults to be ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
correspond to the model you use. For bert-base-uncased, this value is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Default to be ``False``.
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)
in_channel = hidden_dim if not add_linear else embedding_size
self.value_head = nn.Linear(in_channel, 1)
self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
normalized_shape=in_channel, elementwise_affine=False
)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
Expand All @@ -49,9 +62,7 @@ def __init__(
if add_linear:
# Add a small, adjustable linear layer on top of language model tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
else:
self.linear = None

Expand All @@ -66,19 +77,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
sentence_embedding = self.norm(sentence_embedding)

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
def forward(
self,
train_samples: List[str],
candidate_samples: Optional[List[str]] = None,
mode: str = 'compute_actor'
) -> Dict:
"""
Overview:
LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
Different ``mode`` will forward with different network modules to get different outputs.
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
- - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
Returns:
- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
corresponding ``torch.distributions.Categorical`` object.
Expand All @@ -96,7 +115,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dic
>>> scores = model(ctxt_list, cands_list)
>>> assert scores.shape == (1, 3)
"""
assert mode in self.mode
prompt_embedding = self._calc_embedding(train_samples)
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}

res_dict = {}
if mode in ['compute_actor', 'compute_actor_critic']:
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
if mode in ['compute_critic', 'compute_actor_critic']:
value = self.value_head(prompt_embedding)
res_dict.update({'value': value})
return res_dict
34 changes: 29 additions & 5 deletions ding/model/template/tests/test_language_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,33 @@ def check_model(self):
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)
1 change: 1 addition & 0 deletions ding/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@
# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
from .prompt_awr import PromptAWRPolicy
from .happo import HAPPOPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy
from .prompt_awr import PromptAWRPolicy


class EpsCommandModePolicy(CommandModePolicy):
Expand Down Expand Up @@ -455,3 +456,8 @@ def _get_setting_eval(self, command_info: dict) -> dict:
@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('prompt_awr_command')
class PromptAWRCommandModePolicy(PromptAWRPolicy, DummyCommandModePolicy):
pass
Loading

0 comments on commit 3898386

Please sign in to comment.