From f5fed7c8c300184a9375526e8a8ee89d5b7afeac Mon Sep 17 00:00:00 2001
From: "Yinmin.Zhang" <40760801+YinminZhang@users.noreply.github.com>
Date: Thu, 13 Jun 2024 10:50:09 +1000
Subject: [PATCH] polish(zym): optimize ppo continuous act (#801)

* fix (zym): update hidden size of head in VAC module

* feat (zym): update ppo config to support continuous action space
---
 ding/model/template/vac.py                    | 29 +++++++++++--------
 ding/policy/ppo.py                            | 26 ++++++++++++++++-
 dizoo/mujoco/config/ant_onppo_config.py       | 14 +++++++--
 .../mujoco/config/halfcheetah_onppo_config.py | 22 ++++++++++----
 dizoo/mujoco/config/hopper_onppo_config.py    | 18 +++++++++---
 dizoo/mujoco/config/walker2d_onppo_config.py  | 22 ++++++++++----
 6 files changed, 100 insertions(+), 31 deletions(-)

diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 47d5cb1bd6..00ef4162b9 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -54,12 +54,12 @@ def __init__(
                 ``ReparameterizationHead``, and hybrid heads.
             - share_encoder (:obj:`bool`): Whether to share observation encoders between actor and decoder.
             - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
-                the last element must match ``head_hidden_size``.
+                the last element is used as the input size of ``actor_head`` and ``critic_head``.
             - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
-                to 64, it must match the last element of ``encoder_hidden_size_list``.
+                to 64, it is the hidden size of the last layer of the ``actor_head`` network.
             - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
             - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
-                to 64, it must match the last element of ``encoder_hidden_size_list``.
+                to 64, it is the hidden size of the last layer of the ``critic_head`` network.
             - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
             - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
                 if ``None`` then default set it to ``nn.ReLU()``.
@@ -108,15 +108,13 @@ def new_encoder(outsize, activation):
                 )
 
         if self.share_encoder:
-            assert actor_head_hidden_size == critic_head_hidden_size, \
-                "actor and critic network head should have same size."
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
                     self.encoder = encoder
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.encoder = new_encoder(actor_head_hidden_size, activation)
+                self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
         else:
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
@@ -125,25 +123,31 @@ def new_encoder(outsize, activation):
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
-                self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
+                self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
+                self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
 
         # Head Type
         self.critic_head = RegressionHead(
-            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+            encoder_hidden_size_list[-1],
+            1,
+            critic_head_layer_num,
+            activation=activation,
+            norm_type=norm_type,
+            hidden_size=critic_head_hidden_size
         )
         self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
         if self.action_space == 'continuous':
             self.multi_head = False
             self.actor_head = ReparameterizationHead(
-                actor_head_hidden_size,
+                encoder_hidden_size_list[-1],
                 action_shape,
                 actor_head_layer_num,
                 sigma_type=sigma_type,
                 activation=activation,
                 norm_type=norm_type,
-                bound_type=bound_type
+                bound_type=bound_type,
+                hidden_size=actor_head_hidden_size,
             )
         elif self.action_space == 'discrete':
             actor_head_cls = DiscreteHead
@@ -172,7 +176,7 @@ def new_encoder(outsize, activation):
             action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
             action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
             actor_action_args = ReparameterizationHead(
-                actor_head_hidden_size,
+                encoder_hidden_size_list[-1],
                 action_shape.action_args_shape,
                 actor_head_layer_num,
                 sigma_type=sigma_type,
@@ -180,6 +184,7 @@ def new_encoder(outsize, activation):
                 activation=activation,
                 norm_type=norm_type,
                 bound_type=bound_type,
+                hidden_size=actor_head_hidden_size,
             )
             actor_action_type = DiscreteHead(
                 actor_head_hidden_size,
diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py
index 289bc72c44..958cba1d83 100644
--- a/ding/policy/ppo.py
+++ b/ding/policy/ppo.py
@@ -52,6 +52,11 @@ class PPOPolicy(Policy):
             batch_size=64,
             # (float) The step size of gradient descent.
             learning_rate=3e-4,
+            # (dict or None) The learning rate decay config.
+            # If not None, it should contain keys 'epoch_num' and 'min_lr_lambda',
+            # where 'epoch_num' is the total number of epochs over which the learning rate decays,
+            # and 'min_lr_lambda' is the final multiplicative factor applied to the initial learning rate.
+            lr_scheduler=None,
             # (float) The loss weight of value network, policy network weight is set to 1.
             value_weight=0.5,
             # (float) The loss weight of entropy regularization, policy network weight is set to 1.
@@ -169,6 +174,16 @@ def _init_learn(self) -> None:
             clip_value=self._cfg.learn.grad_clip_value
         )
 
+        # Define linear lr scheduler
+        if self._cfg.learn.lr_scheduler is not None:
+            epoch_num = self._cfg.learn.lr_scheduler['epoch_num']
+            min_lr_lambda = self._cfg.learn.lr_scheduler['min_lr_lambda']
+
+            self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self._optimizer,
+                lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
+            )
+
         self._learn_model = model_wrap(self._model, wrapper_name='base')
 
         # Algorithm config
@@ -314,8 +329,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                 total_loss.backward()
                 self._optimizer.step()
 
+                if self._cfg.learn.lr_scheduler is not None:
+                    cur_lr = sum(self._lr_scheduler.get_last_lr()) / len(self._lr_scheduler.get_last_lr())
+                else:
+                    cur_lr = self._optimizer.defaults['lr']
+
                 return_info = {
-                    'cur_lr': self._optimizer.defaults['lr'],
+                    'cur_lr': cur_lr,
                     'total_loss': total_loss.item(),
                     'policy_loss': ppo_loss.policy_loss.item(),
                     'value_loss': ppo_loss.value_loss.item(),
@@ -336,6 +356,10 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                         }
                     )
                 return_infos.append(return_info)
+
+        if self._cfg.learn.lr_scheduler is not None:
+            self._lr_scheduler.step()
+
         return return_infos
 
     def _init_collect(self) -> None:
diff --git a/dizoo/mujoco/config/ant_onppo_config.py b/dizoo/mujoco/config/ant_onppo_config.py
index 73d5ba344b..32793ffecc 100644
--- a/dizoo/mujoco/config/ant_onppo_config.py
+++ b/dizoo/mujoco/config/ant_onppo_config.py
@@ -1,4 +1,5 @@
 from easydict import EasyDict
+import torch.nn as nn
 
 ant_ppo_config = dict(
     exp_name="ant_onppo_seed0",
@@ -17,15 +18,24 @@
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
             obs_shape=111,
             action_shape=8,
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -39,7 +49,7 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
diff --git a/dizoo/mujoco/config/halfcheetah_onppo_config.py b/dizoo/mujoco/config/halfcheetah_onppo_config.py
index 87046ff6f5..f63f296167 100644
--- a/dizoo/mujoco/config/halfcheetah_onppo_config.py
+++ b/dizoo/mujoco/config/halfcheetah_onppo_config.py
@@ -1,7 +1,8 @@
 from easydict import EasyDict
+import torch.nn as nn
 
-collector_env_num = 1
-evaluator_env_num = 1
+collector_env_num = 8
+evaluator_env_num = 8
 halfcheetah_ppo_config = dict(
     exp_name='halfcheetah_onppo_seed0',
     env=dict(
@@ -10,7 +11,7 @@
         norm_reward=dict(use_norm=False, ),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
-        n_evaluator_episode=1,
+        n_evaluator_episode=8,
         stop_value=12000,
     ),
     policy=dict(
@@ -18,15 +19,24 @@
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
             obs_shape=17,
             action_shape=6,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -42,12 +52,12 @@
         ),
         collect=dict(
             collector_env_num=collector_env_num,
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 halfcheetah_ppo_config = EasyDict(halfcheetah_ppo_config)
diff --git a/dizoo/mujoco/config/hopper_onppo_config.py b/dizoo/mujoco/config/hopper_onppo_config.py
index 2cbf05a553..0853aa4abb 100644
--- a/dizoo/mujoco/config/hopper_onppo_config.py
+++ b/dizoo/mujoco/config/hopper_onppo_config.py
@@ -1,4 +1,5 @@
 from easydict import EasyDict
+import torch.nn as nn
 
 hopper_onppo_config = dict(
     exp_name='hopper_onppo_seed0',
@@ -12,19 +13,28 @@
         stop_value=4000,
     ),
     policy=dict(
-        cuda=True,
+        cuda=False,
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             obs_shape=11,
             action_shape=3,
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -39,12 +49,12 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 hopper_onppo_config = EasyDict(hopper_onppo_config)
diff --git a/dizoo/mujoco/config/walker2d_onppo_config.py b/dizoo/mujoco/config/walker2d_onppo_config.py
index 035a998286..2437d62e43 100644
--- a/dizoo/mujoco/config/walker2d_onppo_config.py
+++ b/dizoo/mujoco/config/walker2d_onppo_config.py
@@ -1,7 +1,8 @@
 from easydict import EasyDict
+import torch.nn as nn
 
-collector_env_num = 1
-evaluator_env_num = 1
+collector_env_num = 8
+evaluator_env_num = 8
 walker2d_onppo_config = dict(
     exp_name='walker2d_onppo_seed0',
     env=dict(
@@ -10,7 +11,7 @@
         norm_reward=dict(use_norm=False, ),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
-        n_evaluator_episode=10,
+        n_evaluator_episode=8,
         stop_value=6000,
     ),
     policy=dict(
@@ -18,15 +19,24 @@
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
             obs_shape=17,
             action_shape=6,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -43,12 +53,12 @@
         ),
         collect=dict(
             collector_env_num=collector_env_num,
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 walker2d_onppo_config = EasyDict(walker2d_onppo_config)
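
Note on the new lr_scheduler option (not part of the patch itself): with lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0), as used in the MuJoCo configs above, the LambdaLR added in _init_learn decays the learning rate linearly from learning_rate down to learning_rate * min_lr_lambda, stepping once at the end of each _forward_learn call. The standalone sketch below reproduces only that schedule; the dummy parameter and Adam optimizer are assumptions made for illustration and are not DI-engine code.

# Standalone sketch of the linear decay configured by lr_scheduler (illustrative only).
import torch

epoch_num, min_lr_lambda = 1500, 0.0  # values from the updated MuJoCo configs
base_lr = 3e-4                        # matches learning_rate in the configs

param = torch.nn.Parameter(torch.zeros(1))          # dummy parameter, assumption for the sketch
optimizer = torch.optim.Adam([param], lr=base_lr)   # stand-in for the policy optimizer
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # Same lambda as in ding/policy/ppo.py: linear decay clipped at min_lr_lambda.
    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
)

for step in range(epoch_num + 1):
    if step in (0, 750, 1500):
        # prints 3e-04, 1.5e-04 and 0.0 respectively
        print(step, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()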