polish(zym): optimize ppo continuous act (#801)
* fix (zym): update hidden size of head in VAC module

* feat (zym): update ppo config to support continuous action space
YinminZhang authored Jun 13, 2024
1 parent d919fa5 commit f5fed7c
Showing 6 changed files with 100 additions and 31 deletions.
29 changes: 17 additions & 12 deletions ding/model/template/vac.py
@@ -54,12 +54,12 @@ def __init__(
``ReparameterizationHead``, and hybrid heads.
- share_encoder (:obj:`bool`): Whether to share observation encoders between actor and critic.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
the last element is used as the input size of ``actor_head`` and ``critic_head``.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
to 64, it must match the last element of ``encoder_hidden_size_list``.
to 64, it is the hidden size of the last layer of the ``actor_head`` network.
- actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
to 64, it must match the last element of ``encoder_hidden_size_list``.
to 64, it is the hidden size of the last layer of the ``critic_head`` network.
- critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``.
@@ -108,15 +108,13 @@ def new_encoder(outsize, activation):
)

if self.share_encoder:
assert actor_head_hidden_size == critic_head_hidden_size, \
"actor and critic network head should have same size."
if encoder:
if isinstance(encoder, torch.nn.Module):
self.encoder = encoder
else:
raise ValueError("illegal encoder instance.")
else:
self.encoder = new_encoder(actor_head_hidden_size, activation)
self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
else:
if encoder:
if isinstance(encoder, torch.nn.Module):
@@ -125,25 +123,31 @@ def new_encoder(outsize, activation):
else:
raise ValueError("illegal encoder instance.")
else:
self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)

# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
encoder_hidden_size_list[-1],
1,
critic_head_layer_num,
activation=activation,
norm_type=norm_type,
hidden_size=critic_head_hidden_size
)
self.action_space = action_space
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
if self.action_space == 'continuous':
self.multi_head = False
self.actor_head = ReparameterizationHead(
actor_head_hidden_size,
encoder_hidden_size_list[-1],
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
bound_type=bound_type,
hidden_size=actor_head_hidden_size,
)
elif self.action_space == 'discrete':
actor_head_cls = DiscreteHead
@@ -172,14 +176,15 @@ def new_encoder(outsize, activation):
action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
actor_action_args = ReparameterizationHead(
actor_head_hidden_size,
encoder_hidden_size_list[-1],
action_shape.action_args_shape,
actor_head_layer_num,
sigma_type=sigma_type,
fixed_sigma_value=fixed_sigma_value,
activation=activation,
norm_type=norm_type,
bound_type=bound_type,
hidden_size=actor_head_hidden_size,
)
actor_action_type = DiscreteHead(
actor_head_hidden_size,
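The net effect of the vac.py change: both heads now take their input size from encoder_hidden_size_list[-1], while actor_head_hidden_size and critic_head_hidden_size only size the hidden layers inside each head, so the two no longer have to match. The sketch below is plain PyTorch, not the actual DI-engine RegressionHead/ReparameterizationHead, and the exact layer layout is an assumption; it only illustrates the resulting shape flow with the Ant config values used further down.

import torch
import torch.nn as nn

obs_shape, action_shape = 111, 8              # Ant-v3 sizes, as in ant_onppo_config.py
encoder_hidden_size_list = [128, 128]         # encoder output size = last element = 128
actor_head_hidden_size = 128                  # hidden width inside the actor head only
critic_head_hidden_size = 256                 # hidden width inside the critic head only

encoder = nn.Sequential(
    nn.Linear(obs_shape, encoder_hidden_size_list[0]), nn.Tanh(),
    nn.Linear(encoder_hidden_size_list[0], encoder_hidden_size_list[-1]), nn.Tanh(),
)
# Critic head: takes the encoder output, uses critic_head_hidden_size internally.
critic_head = nn.Sequential(
    nn.Linear(encoder_hidden_size_list[-1], critic_head_hidden_size), nn.Tanh(),
    nn.Linear(critic_head_hidden_size, critic_head_hidden_size), nn.Tanh(),
    nn.Linear(critic_head_hidden_size, 1),
)
# Actor head: with actor_head_layer_num=0 it is assumed to map the encoder
# output directly to the action mean; sigma is handled separately by the head.
actor_mu = nn.Linear(encoder_hidden_size_list[-1], action_shape)

obs = torch.randn(4, obs_shape)
feat = encoder(obs)
print(critic_head(feat).shape, actor_mu(feat).shape)   # torch.Size([4, 1]) torch.Size([4, 8])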
26 changes: 25 additions & 1 deletion ding/policy/ppo.py
@@ -52,6 +52,11 @@ class PPOPolicy(Policy):
batch_size=64,
# (float) The step size of gradient descent.
learning_rate=3e-4,
# (dict or None) The learning rate decay settings.
# If not None, it should contain the keys 'epoch_num' and 'min_lr_lambda',
# where 'epoch_num' is the number of scheduler steps over which the learning rate decays linearly,
# and 'min_lr_lambda' is the final multiplier applied to the initial learning rate.
lr_scheduler=None,
# (float) The loss weight of value network, policy network weight is set to 1.
value_weight=0.5,
# (float) The loss weight of entropy regularization, policy network weight is set to 1.
@@ -169,6 +174,16 @@ def _init_learn(self) -> None:
clip_value=self._cfg.learn.grad_clip_value
)

# Define linear lr scheduler
if self._cfg.learn.lr_scheduler is not None:
epoch_num = self._cfg.learn.lr_scheduler['epoch_num']
min_lr_lambda = self._cfg.learn.lr_scheduler['min_lr_lambda']

self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self._optimizer,
lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
)

self._learn_model = model_wrap(self._model, wrapper_name='base')

# Algorithm config
@@ -314,8 +329,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
total_loss.backward()
self._optimizer.step()

if self._cfg.learn.lr_scheduler is not None:
cur_lr = sum(self._lr_scheduler.get_last_lr()) / len(self._lr_scheduler.get_last_lr())
else:
cur_lr = self._optimizer.defaults['lr']

return_info = {
'cur_lr': self._optimizer.defaults['lr'],
'cur_lr': cur_lr,
'total_loss': total_loss.item(),
'policy_loss': ppo_loss.policy_loss.item(),
'value_loss': ppo_loss.value_loss.item(),
@@ -336,6 +356,10 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
}
)
return_infos.append(return_info)

if self._cfg.learn.lr_scheduler is not None:
self._lr_scheduler.step()

return return_infos

def _init_collect(self) -> None:
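The scheduler added in _init_learn is a standard torch.optim.lr_scheduler.LambdaLR whose multiplier decays linearly from 1.0 to min_lr_lambda over epoch_num steps, and _forward_learn steps it once per call. A standalone sketch of just that schedule, using a dummy parameter and optimizer and the values from the Mujoco configs below:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=3e-4)          # learning_rate from the configs
epoch_num, min_lr_lambda = 1500, 0.0                    # lr_scheduler values from the configs
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda),
)
for it in range(epoch_num):
    optimizer.step()          # the PPO minibatch updates happen before each scheduler step
    scheduler.step()
    if it in (0, 749, 1499):
        print(it, scheduler.get_last_lr()[0])
# prints roughly: 0 2.998e-04, 749 1.5e-04, 1499 0.0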
14 changes: 12 additions & 2 deletions dizoo/mujoco/config/ant_onppo_config.py
@@ -1,4 +1,5 @@
from easydict import EasyDict
import torch.nn as nn

ant_ppo_config = dict(
exp_name="ant_onppo_seed0",
@@ -17,15 +18,24 @@
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
obs_shape=111,
action_shape=8,
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -39,7 +49,7 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
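The diff above only shows the env/policy dicts; a typical way to launch the config is sketched below. This is a hypothetical snippet: serial_pipeline_onpolicy is the usual DI-engine entry point for on-policy algorithms, and ant_ppo_create_config is assumed to be defined further down in the same file, as in the other Mujoco configs.

if __name__ == "__main__":
    # Hypothetical launcher; both config objects are assumed to be EasyDicts
    # defined in this file (main config above, create config below it).
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy((ant_ppo_config, ant_ppo_create_config), seed=0)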
22 changes: 16 additions & 6 deletions dizoo/mujoco/config/halfcheetah_onppo_config.py
@@ -1,7 +1,8 @@
from easydict import EasyDict
import torch.nn as nn

collector_env_num = 1
evaluator_env_num = 1
collector_env_num = 8
evaluator_env_num = 8
halfcheetah_ppo_config = dict(
exp_name='halfcheetah_onppo_seed0',
env=dict(
@@ -10,23 +11,32 @@
norm_reward=dict(use_norm=False, ),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=1,
n_evaluator_episode=8,
stop_value=12000,
),
policy=dict(
cuda=True,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
obs_shape=17,
action_shape=6,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -42,12 +52,12 @@
),
collect=dict(
collector_env_num=collector_env_num,
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
halfcheetah_ppo_config = EasyDict(halfcheetah_ppo_config)
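A quick back-of-the-envelope on what these numbers imply. This is a sketch that assumes _forward_learn is invoked once per collection (as in the standard serial pipeline), so the lr scheduler also steps once per collection:

n_sample = 2048            # transitions gathered per collection
batch_size = 128
epoch_per_collect = 10
epoch_num = 1500           # lr_scheduler horizon, one scheduler step per collection

minibatch_updates_per_collect = epoch_per_collect * (n_sample // batch_size)
env_steps_until_min_lr = epoch_num * n_sample
print(minibatch_updates_per_collect)   # 160 gradient updates per collection
print(env_steps_until_min_lr)          # 3072000 env steps before the lr reaches its minimum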
18 changes: 14 additions & 4 deletions dizoo/mujoco/config/hopper_onppo_config.py
@@ -1,4 +1,5 @@
from easydict import EasyDict
import torch.nn as nn

hopper_onppo_config = dict(
exp_name='hopper_onppo_seed0',
@@ -12,19 +13,28 @@
stop_value=4000,
),
policy=dict(
cuda=True,
cuda=False,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
obs_shape=11,
action_shape=3,
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -39,12 +49,12 @@
grad_clip_value=0.5,
),
collect=dict(
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
hopper_onppo_config = EasyDict(hopper_onppo_config)
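Besides the Tanh activation, these configs set bound_type='tanh'; the assumption here is that the reparameterization head squashes the predicted Gaussian mean into (-1, 1) with tanh so it matches the Mujoco action range, while sigma stays an unbounded learnable quantity. A small illustration of that bounding in plain torch (not the DI-engine head itself):

import torch

mu_raw = torch.randn(4, 3) * 3.0        # unbounded head output, batch of 4, action_shape=3
mu = torch.tanh(mu_raw)                 # bounded mean in (-1, 1)
log_sigma = torch.zeros(3, requires_grad=True)
dist = torch.distributions.Normal(mu, log_sigma.exp())
action = dist.sample()
print(mu.abs().max().item() < 1.0, action.shape)   # True torch.Size([4, 3])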
22 changes: 16 additions & 6 deletions dizoo/mujoco/config/walker2d_onppo_config.py
@@ -1,7 +1,8 @@
from easydict import EasyDict
import torch.nn as nn

collector_env_num = 1
evaluator_env_num = 1
collector_env_num = 8
evaluator_env_num = 8
walker2d_onppo_config = dict(
exp_name='walker2d_onppo_seed0',
env=dict(
@@ -10,23 +11,32 @@
norm_reward=dict(use_norm=False, ),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=10,
n_evaluator_episode=8,
stop_value=6000,
),
policy=dict(
cuda=True,
recompute_adv=True,
action_space='continuous',
model=dict(
encoder_hidden_size_list=[128, 128],
action_space='continuous',
share_encoder=False,
actor_head_layer_num=0,
critic_head_layer_num=2,
critic_head_hidden_size=256,
actor_head_hidden_size=128,
obs_shape=17,
action_shape=6,
activation=nn.Tanh(),
bound_type='tanh',
),
learn=dict(
epoch_per_collect=10,
update_per_collect=1,
batch_size=320,
batch_size=128,
learning_rate=3e-4,
lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
value_weight=0.5,
entropy_weight=0.001,
clip_ratio=0.2,
@@ -43,12 +53,12 @@
),
collect=dict(
collector_env_num=collector_env_num,
n_sample=3200,
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=500, )),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
walker2d_onppo_config = EasyDict(walker2d_onppo_config)
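All four configs switch to share_encoder=False, which instantiates separate actor and critic encoders of the same [obs_shape -> 128 -> 128] shape. A rough parameter count for that choice, using a plain MLP stand-in rather than the exact DI-engine encoder, so the numbers are indicative only:

import torch.nn as nn

def mlp_encoder(obs_shape, hidden_size_list):
    layers, in_size = [], obs_shape
    for size in hidden_size_list:
        layers += [nn.Linear(in_size, size), nn.Tanh()]
        in_size = size
    return nn.Sequential(*layers)

obs_shape, hidden_size_list = 17, [128, 128]     # Walker2d-v3 observation size
one_encoder = sum(p.numel() for p in mlp_encoder(obs_shape, hidden_size_list).parameters())
print(one_encoder, 2 * one_encoder)              # 18816 shared vs. 37632 with separate encoders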
