From 3898386a56e315d5144e5b42909d8bff5d2fc337 Mon Sep 17 00:00:00 2001
From: Wang hl <59834623+kxzxvbk@users.noreply.github.com>
Date: Thu, 26 Sep 2024 09:38:22 +0800
Subject: [PATCH] feature(whl): add AWR algorithm (#828)

* init commit
* reformat
* polish
* polish readme
* reformat
* polish
---
 README.md                                   |   3 +-
 ding/model/template/language_transformer.py |  51 +++-
 .../tests/test_language_transformer.py      |  34 ++-
 ding/policy/__init__.py                     |   1 +
 ding/policy/command_mode_policy_instance.py |   6 +
 ding/policy/prompt_awr.py                   | 274 ++++++++++++++++++
 ding/policy/prompt_pg.py                    |   4 +
 dizoo/tabmwp/config/tabmwp_awr_config.py    |  66 +++++
 8 files changed, 421 insertions(+), 18 deletions(-)
 create mode 100644 ding/policy/prompt_awr.py
 create mode 100644 dizoo/tabmwp/config/tabmwp_awr_config.py

diff --git a/README.md b/README.md
index 14d222ee66..73f4dfdf79 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
 - Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
 - Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
 - Exploration algorithms: HER, RND, ICM, NGU
-- LLM + RL Algorithms: PPO-max, DPO, PromptPG
+- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
 - Other algorithms: such as PER, PLR, PCGrad
 - MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
 - Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
@@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
 | 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
 | 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_awr](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |
diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py
index cac2d69adf..796b2ba0a7 100644
--- a/ding/model/template/language_transformer.py
+++ b/ding/model/template/language_transformer.py
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import List, Dict, Optional
 import torch
 from torch import nn
 
@@ -15,16 +15,21 @@ class LanguageTransformer(nn.Module):
     """
     Overview:
         The LanguageTransformer network. Download a pre-trained language model and add head on it.
+        In the default case, we use the BERT model as the text encoder, whose bidirectional nature is well suited
+        to obtaining the embedding of the whole sentence.
     Interfaces:
         ``__init__``, ``forward``
     """
+    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

     def __init__(
             self,
             model_name: str = "bert-base-uncased",
             add_linear: bool = False,
             embedding_size: int = 128,
-            freeze_encoder: bool = True
+            freeze_encoder: bool = True,
+            hidden_dim: int = 768,
+            norm_embedding: bool = False
     ) -> None:
         """
         Overview:
@@ -32,14 +37,22 @@ def __init__(
         Arguments:
             - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
             - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
-            ``False``.
+                ``False``.
             - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
             - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
-            defaults to be ``True``.
+                defaults to be ``True``.
+            - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
+                correspond to the model you use. For bert-base-uncased, this value is 768.
+            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to ``False``.
         """
         super().__init__()
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForTokenClassification.from_pretrained(model_name)
+        in_channel = hidden_dim if not add_linear else embedding_size
+        self.value_head = nn.Linear(in_channel, 1)
+        self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
+            normalized_shape=in_channel, elementwise_affine=False
+        )

         # Freeze transformer encoder and only train the linear layer
         if freeze_encoder:
@@ -49,9 +62,7 @@ def __init__(
         if add_linear:
             # Add a small, adjustable linear layer on top of language model tuned through RL
             self.embedding_size = embedding_size
-            self.linear = nn.Linear(
-                self.model.config.hidden_size, embedding_size
-            )  # 768 for bert-base-uncased, distilbert-base-uncased
+            self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
         else:
             self.linear = None

@@ -66,19 +77,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
         last_hidden_states = output.hidden_states[-1]
         # Get [CLS] hidden states
         sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
+        sentence_embedding = self.norm(sentence_embedding)
         if self.linear:
             sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

         return sentence_embedding

-    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
+    def forward(
+            self,
+            train_samples: List[str],
+            candidate_samples: Optional[List[str]] = None,
+            mode: str = 'compute_actor'
+    ) -> Dict:
         """
         Overview:
             LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
+            Different ``mode`` values forward through different network modules to produce different outputs.
         Arguments:
             - train_samples (:obj:`List[str]`): One list of strings.
-            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
+            - candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
+            - mode (:obj:`str`): The forward mode; all valid modes are defined at the beginning of this class.
         Returns:
             - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                 corresponding ``torch.distributions.Categorical`` object.
@@ -96,7 +115,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dic
             >>> scores = model(ctxt_list, cands_list)
             >>> assert scores.shape == (1, 3)
         """
+        assert mode in self.mode
         prompt_embedding = self._calc_embedding(train_samples)
-        cands_embedding = self._calc_embedding(candidate_samples)
-        scores = torch.mm(prompt_embedding, cands_embedding.t())
-        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
+
+        res_dict = {}
+        if mode in ['compute_actor', 'compute_actor_critic']:
+            cands_embedding = self._calc_embedding(candidate_samples)
+            scores = torch.mm(prompt_embedding, cands_embedding.t())
+            res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
+        if mode in ['compute_critic', 'compute_actor_critic']:
+            value = self.value_head(prompt_embedding)
+            res_dict.update({'value': value})
+        return res_dict
diff --git a/ding/model/template/tests/test_language_transformer.py b/ding/model/template/tests/test_language_transformer.py
index 40095c2ab2..eaaaae5a84 100644
--- a/ding/model/template/tests/test_language_transformer.py
+++ b/ding/model/template/tests/test_language_transformer.py
@@ -17,9 +17,33 @@ def check_model(self):
         cands_list = [problems[pid] for pid in cand_pids]

         model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)

-        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)
+
+        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)
+
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)
diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py
index 1f202da3bb..5015823d71 100755
--- a/ding/policy/__init__.py
+++ b/ding/policy/__init__.py
@@ -56,4 +56,5 @@
 # new-type policy
 from .ppof import PPOFPolicy
 from .prompt_pg import PromptPGPolicy
+from .prompt_awr import PromptAWRPolicy
 from .happo import HAPPOPolicy
diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py
index 2e817ead4b..1289d8e6ca 100644
--- a/ding/policy/command_mode_policy_instance.py
+++ b/ding/policy/command_mode_policy_instance.py
@@ -52,6 +52,7 @@ from .prompt_pg import PromptPGPolicy
 from .plan_diffuser import PDPolicy
 from .happo import HAPPOPolicy
+from .prompt_awr import PromptAWRPolicy


 class EpsCommandModePolicy(CommandModePolicy):
@@ -455,3 +456,8 @@ def _get_setting_eval(self, command_info: dict) -> dict:
 @POLICY_REGISTRY.register('prompt_pg_command')
 class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
     pass
+
+
+@POLICY_REGISTRY.register('prompt_awr_command')
+class PromptAWRCommandModePolicy(PromptAWRPolicy, DummyCommandModePolicy):
+    pass
diff --git a/ding/policy/prompt_awr.py b/ding/policy/prompt_awr.py
new file mode 100644
index 0000000000..4b39057d22
--- /dev/null
+++ b/ding/policy/prompt_awr.py
@@ -0,0 +1,274 @@
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple, Union
+
+import torch
+
+from ding.model import model_wrap
+from ding.rl_utils import get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('prompt_awr')
+class PromptAWRPolicy(Policy):
+    """
+    Overview:
+        Policy class of AWR (Advantage Weighted Regression) algorithm, proposed in https://arxiv.org/abs/1910.00177.
+        In particular, this policy is designed for training a language model policy.
+        In this policy, the environment's observation includes the current context and a list of candidate actions
+        (strings). The final output of the policy is a set of selected candidate actions of size ``shot_number``.
+    """
+    config = dict(
+        # (str) Name of the registered RL policy (refer to the "register_policy" function).
+        type='prompt_awr',
+        # (bool) Flag to enable CUDA for model computation.
+        cuda=False,
+        # (bool) Flag for using on-policy training (training policy is the same as the behavior policy).
+        on_policy=False,
+        # (bool) Flag for enabling priority experience replay. Must be False when priority_IS_weight is False.
+        priority=False,
+        # (bool) Flag for using Importance Sampling weights to correct updates. Requires `priority` to be True.
+        priority_IS_weight=False,
+        # (str) Type of action space used in the policy, with valid options ['discrete', 'continuous'].
+        action_space='discrete',
+        # (int) The number of actions that can be done simultaneously in one timestep.
+        shot_number=1,
+        # learn_mode configuration
+        learn=dict(
+            # (int) Number of updates per data collection.
+            update_per_collect=1,
+            # (int) Batch size for learning.
+            batch_size=64,
+            # (float) Learning rate for optimizer.
+            learning_rate=0.001,
+            # (Tuple[float, float]) Coefficients used for computing running averages of gradient and its square.
+            betas=(0.9, 0.999),
+            # (float) Term added to the denominator to improve numerical stability in optimizer.
+            eps=1e-8,
+            # (float) Maximum norm for gradients.
+            grad_norm=0.5,
+            # (float) Scaling factor for value network loss relative to policy network loss.
+            value_weight=0.5,
+            # (float) Coefficient that controls the exponential scale in the AWR weighting.
+            beta=1.0,
+            # (float) Weight of entropy regularization in the loss function.
+            entropy_weight=0.001,
+            # (Tuple[float, float]) The clipping range of the advantage. Values outside this range will be clipped.
+            adv_range=(-0.5, 0.5),
+            # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
+            # limits are disregarded. By default, this is set to False.
+            # This setting is particularly useful for tasks
+            # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments,
+            # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching
+            # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the
+            # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`,
+            # even when the episode surpasses the predefined step limit.
+            ignore_done=False,
+        ),
+        # collect_mode configuration
+        collect=dict(
+            # (int) The length of rollout for data collection.
+            unroll_len=1,
+            # (float) Discount factor for calculating future rewards, typically in the range [0, 1].
+            discount_factor=0.9,
+            # (float) Trade-off parameter for balancing TD-error and Monte Carlo error in GAE.
+            gae_lambda=0.95,
+        ),
+        # eval_mode configuration (kept empty for compatibility purposes)
+        eval=dict(),
+    )
+
+    def default_model(self) -> Tuple[str, List[str]]:
+        """
+        Overview:
+            Returns the default model configuration used by the AWR algorithm. The ``__init__`` method will \
+            automatically call this method to get the default model setting and create the model.
+
+        Returns:
+            - model_info (:obj:`Tuple[str, List[str]]`): \
+                Tuple containing the registered model name and model's import_names.
+        """
+        return 'language_transformer', ['ding.model.template.language_transformer']
+
+    def _init_learn(self) -> None:
+        """
+        Overview:
+            Initialize the learn mode of policy, including related attributes and modules. For AWR, it mainly \
+            contains the optimizer, algorithm-specific arguments such as value_weight, entropy_weight \
+            and grad_norm, and the main model. \
+            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+        .. note::
+            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+            and ``_load_state_dict_learn`` methods.
+
+        .. note::
+            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+        .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+        """
+        assert self._cfg.action_space == "discrete"
+        # Optimizer
+        self._optimizer = Adam(
+            self._model.parameters(),
+            lr=self._cfg.learn.learning_rate,
+            betas=self._cfg.learn.betas,
+            eps=self._cfg.learn.eps
+        )
+
+        # Algorithm config
+        self._priority = self._cfg.priority
+        self._priority_IS_weight = self._cfg.priority_IS_weight
+        self._value_weight = self._cfg.learn.value_weight
+        self._entropy_weight = self._cfg.learn.entropy_weight
+        self._grad_norm = self._cfg.learn.grad_norm
+
+        # Main and target models
+        self._learn_model = self._model
+
+    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        # Data preprocessing operations, such as stack data, cpu to cuda device
+        self._learn_model.train()
+
+        for i in range(0, len(data), self._cfg.learn.batch_size):
+            batch = default_collate(data[i:i + self._cfg.learn.batch_size])
+            if self._cuda:
+                batch = to_device(batch, self._device)
+
+            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+            train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
+            for cand_n in range(len(cand_samples)):
+                cand_samples[cand_n] = cand_samples[cand_n][0]
+            output = self._learn_model.forward(train_samples, cand_samples, mode='compute_actor_critic')
+            return_ = batch['return']
+
+            # Calculate AWR loss
+            real_act = batch['action']
+
+            # Ensure the shape of real_act is: (B, shot_number)
+            if len(real_act.shape) == 1:
+                real_act = real_act.unsqueeze(-1)
+
+            # Calculate different parts of loss.
+            total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
+            for shot_n in range(self._cfg.shot_number):
+                log_prob = output['dist'].log_prob(real_act[:, shot_n])
+                # Clamp the adv for better stability.
+                adv = torch.clamp(
+                    return_ - batch['value'], min=self._cfg.learn.adv_range[0], max=self._cfg.learn.adv_range[1]
+                )
+                # The policy loss for AWR algorithm.
+                policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
+                total_policy_loss += policy_loss
+                # The value loss for AWR algorithm.
+                value_loss = ((return_ - output['value']) ** 2).mean()
+                total_value_loss += value_loss
+                # The entropy loss for AWR algorithm.
+                total_entropy_loss += -self._entropy_weight * output['dist'].entropy().mean()
+            total_loss = total_entropy_loss + total_policy_loss + total_value_loss
+
+            self._optimizer.zero_grad()
+            total_loss.backward()
+
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                list(self._learn_model.parameters()),
+                max_norm=self._grad_norm,
+            )
+            self._optimizer.step()
+
+        return {
+            'cur_lr': self._optimizer.param_groups[0]['lr'],
+            'total_loss': total_loss.item(),
+            'policy_loss': total_policy_loss.item(),
+            'entropy_loss': total_entropy_loss.item(),
+            'value_loss': total_value_loss.item(),
+            'return_abs_max': return_.abs().max().item(),
+            'grad_norm': grad_norm,
+        }
+
+    def _init_collect(self) -> None:
+        self._unroll_len = self._cfg.collect.unroll_len
+        self._gamma = self._cfg.collect.discount_factor
+        self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')
+
+    def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
+        """
+        Overview:
+            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+            data, such as the action to interact with the envs.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+                key of the dict is environment id and the value is the corresponding data of the env.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+                dict is the same as the input data, i.e. environment id.
+        """
+        data_id = list(data.keys())
+        data = default_collate(list(data.values()))
+        self._model.eval()
+        with torch.no_grad():
+            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+            for ii in range(len(data['candidate_samples'])):
+                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+            output = self._collect_model.forward(
+                self._cfg.shot_number, data['train_sample'], data['candidate_samples'], mode="compute_actor_critic"
+            )
+            if self._cuda:
+                output = to_device(output, 'cpu')
+            output = default_decollate(output)
+        return {i: d for i, d in zip(data_id, output)}
+
+    def _process_transition(self, obs: Any, policy_output: Dict[str, torch.Tensor],
+                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
+        return {
+            'obs': obs,
+            'action': policy_output['action'],
+            'value': policy_output['value'],
+            'reward': timestep.reward,
+            'done': timestep.done,
+        }
+
+    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+        r"""
+        Overview:
+            Compute the discounted Monte Carlo return for each transition in the collected trajectory.
+        Arguments:
+            - data (:obj:`list`): The trajectory data (a list of transitions).
+        Returns:
+            - samples (:obj:`dict`): The training samples with the computed ``return`` field.
+        """
+        if self._cfg.learn.ignore_done:
+            raise NotImplementedError
+
+        R = 0.
+        for i in reversed(range(len(data))):
+            R = self._gamma * R + data[i]['reward']
+            data[i]['return'] = R
+        return get_train_sample(data, self._unroll_len)
+
+    def _init_eval(self) -> None:
+        self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')
+
+    def _forward_eval(self, data: dict) -> dict:
+        data_id = list(data.keys())
+        data = default_collate(list(data.values()))
+        self._model.eval()
+        with torch.no_grad():
+            # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+            for ii in range(len(data['candidate_samples'])):
+                data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+            output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
+            if self._cuda:
+                output = to_device(output, 'cpu')
+            output = default_decollate(output)
+        return {i: d for i, d in zip(data_id, output)}
+
+    def _monitor_vars_learn(self) -> List[str]:
+        return super()._monitor_vars_learn() + \
+            ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm', 'value_loss']
diff --git a/ding/policy/prompt_pg.py b/ding/policy/prompt_pg.py
index ebccadb8a3..a76e0e5faf 100644
--- a/ding/policy/prompt_pg.py
+++ b/ding/policy/prompt_pg.py
@@ -26,6 +26,8 @@ class PromptPGPolicy(Policy):
         on_policy=True,  # for pg strictly on policy algorithm, this line should not be modified by users
         # (bool) whether to use deterministic action for evaluation.
         deterministic_eval=True,
+        # (int) The number of actions that can be done simultaneously in one timestep.
+        shot_number=1,
         learn=dict(
             # (int) the number of samples for one update.
             batch_size=64,
@@ -98,6 +100,8 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
             # calculate PG loss
             real_act = batch['action']  # shape: (B, shot_number)
+            if len(real_act.shape) == 1:
+                real_act = real_act.unsqueeze(-1)
             # Calculate loss.
             total_policy_loss, total_entropy_loss = 0, 0
             for ii in range(self._cfg.shot_number):
diff --git a/dizoo/tabmwp/config/tabmwp_awr_config.py b/dizoo/tabmwp/config/tabmwp_awr_config.py
new file mode 100644
index 0000000000..7e3f22865f
--- /dev/null
+++ b/dizoo/tabmwp/config/tabmwp_awr_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+tabmwp_prompt_awr_config = dict(
+    exp_name='tabmwp_prompt_awr_seed0',
+    env=dict(
+        collector_env_num=1,
+        evaluator_env_num=1,
+        n_evaluator_episode=1,
+        stop_value=1,
+        cand_number=16,
+        train_number=80,
+        engine='text-davinci-002',
+        temperature=0.,
+        max_tokens=512,
+        top_p=1.,
+        frequency_penalty=0.,
+        presence_penalty=0.,
+        option_inds=["A", "B", "C", "D", "E", "F"],
+        # The API key of OpenAI. You can get your key at this website: https://platform.openai.com/
+        api_key='',
+        enable_replay=True,
+        prompt_format='TQ-A',
+        seed=0,
+    ),
+    policy=dict(
+        cuda=True,
+        shot_number=2,
+        model=dict(
+            model_name="bert-base-uncased",
+            add_linear=True,
+            freeze_encoder=True,
+            embedding_size=128,
+        ),
+        learn=dict(
+            batch_size=10,
+            # (float) Learning rate of the optimizer.
+            learning_rate=0.001,
+            # (float) Weight of entropy regularization in the loss function.
+            entropy_weight=0.001,
+            weight_decay=5e-3,
+            grad_norm=0.5,
+        ),
+        collect=dict(
+            # (int) Collect n_sample data, then train the model.
+            n_sample=20,
+            discount_factor=0.,
+        ),
+        eval=dict(evaluator=dict(eval_freq=500, )),
+    ),
+)
+main_config = EasyDict(tabmwp_prompt_awr_config)
+
+tabmwp_prompt_awr_create_config = dict(
+    env=dict(
+        type='tabmwp',
+        import_names=['dizoo.tabmwp.envs.tabmwp_env'],
+    ),
+    env_manager=dict(type='base'),
+    policy=dict(type='prompt_awr'),
+    replay_buffer=dict(type='naive'),
+)
+create_config = EasyDict(tabmwp_prompt_awr_create_config)
+
+if __name__ == '__main__':
+    from ding.entry import serial_pipeline_onpolicy
+    serial_pipeline_onpolicy((main_config, create_config), seed=0)
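
Note: the snippet below is an illustrative, self-contained sketch of the AWR objective that `PromptAWRPolicy._forward_learn` builds for each shot, useful for checking the math in isolation. It is not part of the patch itself: the helper name `awr_losses` and the tensor names (`logit`, `value`, `action`, `returns`) are assumptions introduced here, while `beta` and `adv_range` mirror the fields in the policy's default config.

import torch


def awr_losses(logit, value, action, returns, beta=1.0, adv_range=(-0.5, 0.5)):
    # Advantage = Monte Carlo return minus the critic's value estimate,
    # clipped to adv_range for numerical stability (cf. the config above).
    adv = torch.clamp(returns - value.detach(), min=adv_range[0], max=adv_range[1])
    dist = torch.distributions.Categorical(logits=logit)
    # AWR policy loss: negative log-likelihood weighted by exp(advantage / beta).
    policy_loss = -(dist.log_prob(action) * torch.exp(adv / beta)).mean()
    # Value loss: regress the critic towards the Monte Carlo return.
    value_loss = ((returns - value) ** 2).mean()
    return policy_loss, value_loss


# Usage example: a batch of 4 prompts, each with 3 candidate actions.
logit = torch.randn(4, 3, requires_grad=True)
value = torch.randn(4, requires_grad=True)
action = torch.randint(0, 3, (4, ))
returns = torch.randn(4)
policy_loss, value_loss = awr_losses(logit, value, action, returns)
(policy_loss + value_loss).backward()

A larger `beta` flattens the exponential weights towards uniform (closer to plain behaviour cloning on the collected actions), while a smaller `beta` concentrates the regression on high-advantage actions.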