Commit 3d106e5
Improve PPO algorithm (#312)
* [IBR-2068] Modify standard deviation of gaussian action in ppo
* [IBR-2068] Add ppo algorithm for discrete action
* [IBR-2068] Add shared backbone for actor critic
* [IBR-2068] Fix gpu oom bug
* [IBR-2068] Tuning hyper-parameters for ppo
* [IBR-2068] Modify multi env
* [IBR-2068] Modify learner for shared actor critic
* [IBR-2068] Rollback ppo config
* [IBR-2068] Add ppo with discrete action
* [IBR-2068] Remove retain_graph option
* docs: add isk03276 as a contributor for code (#314)
* docs: update README.md [skip ci]
* docs: update .all-contributorsrc [skip ci]

Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com>
1 parent b3df31e commit 3d106e5

File tree: 19 files changed, +280 −70 lines

.all-contributorsrc
Lines changed: 9 additions & 0 deletions

```diff
@@ -85,6 +85,15 @@
       "contributions": [
         "maintenance"
       ]
+    },
+    {
+      "login": "isk03276",
+      "name": "eunjin",
+      "avatar_url": "https://avatars.githubusercontent.com/u/23740495?v=4",
+      "profile": "https://github.com/isk03276",
+      "contributions": [
+        "code"
+      ]
     }
   ],
   "contributorsPerLine": 7,
```

README.md
Lines changed: 2 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/medipixel/rl_algorithms.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/medipixel/rl_algorithms/context:python)
 [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
-[![All Contributors](https://img.shields.io/badge/all_contributors-9-orange.svg?style=flat-square)](#contributors-)
+[![All Contributors](https://img.shields.io/badge/all_contributors-10-orange.svg?style=flat-square)](#contributors-)
 <!-- ALL-CONTRIBUTORS-BADGE:END -->

 </p>
@@ -47,6 +47,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
   <tr>
     <td align="center"><a href="https://jiseonghan.github.io/"><img src="https://avatars2.githubusercontent.com/u/48741026?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Jiseong Han</b></sub></a><br /><a href="https://github.com/medipixel/rl_algorithms/commits?author=jiseongHAN" title="Code">💻</a></td>
     <td align="center"><a href="https://github.com/sehyun-hwang"><img src="https://avatars3.githubusercontent.com/u/23437715?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Sehyun Hwang</b></sub></a><br /><a href="#maintenance-sehyun-hwang" title="Maintenance">🚧</a></td>
+    <td align="center"><a href="https://github.com/isk03276"><img src="https://avatars.githubusercontent.com/u/23740495?v=4?s=100" width="100px;" alt=""/><br /><sub><b>eunjin</b></sub></a><br /><a href="https://github.com/medipixel/rl_algorithms/commits?author=isk03276" title="Code">💻</a></td>
   </tr>
 </table>
```

configs/lunarlander_continuous_v2/a2c.yaml
Lines changed: 2 additions & 0 deletions

```diff
@@ -10,12 +10,14 @@ learner_cfg:
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "GaussianDist"
       configs:
         hidden_sizes: [256, 256]
         output_activation: "identity"
+        fixed_logstd: True
     critic:
       type: "MLP"
       configs:
```

configs/lunarlander_continuous_v2/bc_ddpg.yaml
Lines changed: 1 addition & 0 deletions

```diff
@@ -24,6 +24,7 @@ learner_cfg:
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "MLP"
```

configs/lunarlander_continuous_v2/bc_sac.yaml
Lines changed: 2 additions & 0 deletions

```diff
@@ -29,12 +29,14 @@ learner_cfg:
     actor:
     critic_vf:
     critic_qf:
+    shared_actor_critic:
   head:
     actor:
       type: "TanhGaussianDistParams"
       configs:
         hidden_sizes: [256, 256]
         output_activation: "identity"
+        fixed_logstd: False
     critic_vf:
       type: "MLP"
       configs:
```

configs/lunarlander_continuous_v2/ddpg.yaml
Lines changed: 1 addition & 0 deletions

```diff
@@ -14,6 +14,7 @@ learner_cfg:
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "MLP"
```

configs/lunarlander_continuous_v2/ddpgfd.yaml
Lines changed: 1 addition & 0 deletions

```diff
@@ -25,6 +25,7 @@ learner_cfg:
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "MLP"
```

configs/lunarlander_continuous_v2/ppo.yaml
Lines changed: 6 additions & 5 deletions

```diff
@@ -8,26 +8,27 @@ hyper_params:
   epsilon_decay_period: 1500
   w_value: 1.0
   w_entropy: 0.001
-  gradient_clip_ac: 1.0
-  gradient_clip_cr: 0.5
+  gradient_clip_ac: 0.5
+  gradient_clip_cr: 1.0
   epoch: 16
   rollout_len: 256
   n_workers: 12
-  use_clipped_value_loss: True
+  use_clipped_value_loss: False
   standardize_advantage: True
-  is_discrete: False

 learner_cfg:
   type: "PPOLearner"
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "GaussianDist"
       configs:
         hidden_sizes: [256, 256]
-        output_activation: "tanh"
+        output_activation: "identity"
+        fixed_logstd: True
     critic:
       type: "MLP"
       configs:
```
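Two things worth noting here. First, `is_discrete` disappears from `hyper_params`; the agent now infers it from the environment's action space (see the `agent.py` change below). Second, `use_clipped_value_loss` flips from `True` to `False`. As a hedged sketch of what such a flag typically toggles between (illustrative names, not this repo's exact learner code), the clipped variant bounds how far the value prediction may move from the rollout-time prediction:

```python
import torch

# Illustrative sketch only: function and argument names are hypothetical.
def value_loss(values, old_values, returns, clip_eps=0.2, use_clipped_value_loss=False):
    if use_clipped_value_loss:
        # PPO-style: clip the value update around the old prediction and
        # take the pessimistic (larger) of the two squared errors.
        clipped = old_values + torch.clamp(values - old_values, -clip_eps, clip_eps)
        return torch.max((values - returns).pow(2), (clipped - returns).pow(2)).mean()
    # Plain MSE against the discounted returns (the setting chosen above).
    return (values - returns).pow(2).mean()
```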

configs/lunarlander_continuous_v2/sac.yaml
Lines changed: 2 additions & 0 deletions

```diff
@@ -19,12 +19,14 @@ learner_cfg:
     actor:
     critic_vf:
     critic_qf:
+    shared_actor_critic:
   head:
     actor:
       type: "TanhGaussianDistParams"
       configs:
         hidden_sizes: [256, 256]
         output_activation: "identity"
+        fixed_logstd: False
     critic_vf:
       type: "MLP"
       configs:
```

configs/lunarlander_continuous_v2/sacfd.yaml
Lines changed: 2 additions & 0 deletions

```diff
@@ -30,12 +30,14 @@ learner_cfg:
     actor:
     critic_vf:
     critic_qf:
+    shared_actor_critic:
   head:
     actor:
       type: "TanhGaussianDistParams"
       configs:
         hidden_sizes: [256, 256]
         output_activation: "identity"
+        fixed_logstd: False
     critic_vf:
       type: "MLP"
       configs:
```

configs/lunarlander_continuous_v2/td3.yaml
Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@ learner_cfg:
   backbone:
     actor:
     critic:
+    shared_actor_critic:
   head:
     actor:
       type: "MLP"
```

configs/lunarlander_v2/ppo.yaml
Lines changed: 40 additions & 0 deletions (new file)

```yaml
type: "PPOAgent"
hyper_params:
  gamma: 0.99
  tau: 0.95
  batch_size: 32
  max_epsilon: 0.2
  min_epsilon: 0.2
  epsilon_decay_period: 1500
  w_value: 1.0
  w_entropy: 0.001
  gradient_clip_ac: 0.5
  gradient_clip_cr: 1.0
  epoch: 16
  rollout_len: 256
  n_workers: 12
  use_clipped_value_loss: False
  standardize_advantage: True

learner_cfg:
  type: "PPOLearner"
  backbone:
    actor:
    critic:
    shared_actor_critic:
  head:
    actor:
      type: "CategoricalDist"
      configs:
        hidden_sizes: [256, 256]
        output_activation: "identity"
    critic:
      type: "MLP"
      configs:
        hidden_sizes: [256, 256]
        output_size: 1
        output_activation: "identity"
  optim_cfg:
    lr_actor: 0.0003
    lr_critic: 0.001
    weight_decay: 0.0
```
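This new config is the discrete-action counterpart of the continuous one: a `CategoricalDist` actor head in place of `GaussianDist`. Since `max_epsilon` equals `min_epsilon`, the `epsilon_decay_period` leaves the clip range effectively constant at 0.2. For reference, a hedged sketch of the clipped surrogate objective that this epsilon parameterizes (illustrative names, not the repo's exact code):

```python
import torch

# Standard PPO clipped surrogate with epsilon = 0.2, matching the config above.
def ppo_actor_loss(log_prob, old_log_prob, advantage, epsilon=0.2):
    ratio = (log_prob - old_log_prob).exp()  # pi_new(a|s) / pi_old(a|s)
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return -torch.min(surr1, surr2).mean()   # maximize, hence the minus sign
```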

configs/pong_no_frameskip_v4/ppo.yaml
Lines changed: 47 additions & 0 deletions (new file)

```yaml
type: "PPOAgent"
hyper_params:
  gamma: 0.99
  tau: 0.95
  batch_size: 32
  max_epsilon: 0.2
  min_epsilon: 0.2
  epsilon_decay_period: 1500
  w_value: 1.0
  w_entropy: 0.001
  gradient_clip_ac: 0.5
  gradient_clip_cr: 1.0
  epoch: 16
  rollout_len: 256
  n_workers: 4
  use_clipped_value_loss: False
  standardize_advantage: True

learner_cfg:
  type: "PPOLearner"
  backbone:
    actor:
    critic:
    shared_actor_critic:
      type: "CNN"
      configs:
        input_sizes: [4, 32, 64]
        output_sizes: [32, 64, 64]
        kernel_sizes: [8, 4, 3]
        strides: [4, 2, 1]
        paddings: [1, 0, 0]
  head:
    actor:
      type: "CategoricalDist"
      configs:
        hidden_sizes: [512]
        output_activation: "identity"
    critic:
      type: "MLP"
      configs:
        hidden_sizes: [512]
        output_size: 1
        output_activation: "identity"
  optim_cfg:
    lr_actor: 0.0003
    lr_critic: 0.001
    weight_decay: 0.0
```
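Here the `shared_actor_critic` slot is actually populated: one CNN backbone feeds both the actor and critic heads, which is what the "Add shared backbone for actor critic" commit item refers to. A quick sketch to sanity-check the flattened feature size this backbone produces; the 84x84 stacked-frame input is an assumption based on conventional Atari preprocessing, not something stated in this diff:

```python
import torch
import torch.nn as nn

# CNN described by the config above; 84x84 input frames are an assumption.
cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), nn.ReLU(),
)
features = cnn(torch.zeros(1, 4, 84, 84)).flatten(1)
print(features.shape)  # torch.Size([1, 3136]) under that assumption -> fed to the 512-unit heads
```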

rl_algorithms/common/abstract/agent.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -66,6 +66,8 @@ def __init__(
         self.total_step = 0
         self.learner = None

+        self.is_discrete = isinstance(self.env_info.action_space, gym.spaces.Discrete)
+
     @abstractmethod
     def select_action(self, state: np.ndarray) -> Union[torch.Tensor, np.ndarray]:
         pass
```
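This replaces the `is_discrete: False` config flag with a check against the environment itself, so discrete and continuous PPO share one config schema. A minimal standalone illustration of the same check using the standard gym API (environment names here are just the ones these configs target):

```python
import gym

# Discrete action space -> CategoricalDist head; Box -> GaussianDist head.
env = gym.make("LunarLander-v2")
print(isinstance(env.action_space, gym.spaces.Discrete))  # True

env = gym.make("LunarLanderContinuous-v2")
print(isinstance(env.action_space, gym.spaces.Discrete))  # False (Box space)
```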

rl_algorithms/common/networks/brain.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -27,10 +27,16 @@ def __init__(
         self,
         backbone_cfg: ConfigDict,
         head_cfg: ConfigDict,
+        shared_backbone: nn.Module = None,
     ):
         """Initialize."""
         nn.Module.__init__(self)
-        if not backbone_cfg:
+        if shared_backbone is not None:
+            self.backbone = shared_backbone
+            head_cfg.configs.input_size = self.calculate_fc_input_size(
+                head_cfg.configs.state_size
+            )
+        elif not backbone_cfg:
             self.backbone = identity
             head_cfg.configs.input_size = head_cfg.configs.state_size[0]
         else:
```
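Passing `shared_backbone` lets two `Brain` instances hold the same feature extractor by reference, so actor and critic losses both backpropagate into one set of backbone parameters; that shared graph is presumably why the learner and the `retain_graph` handling were reworked in the same commit. A minimal sketch of the idea outside the repo's classes (all names here are illustrative):

```python
import torch.nn as nn

class Net(nn.Module):
    """Tiny stand-in for Brain: a backbone followed by a task head."""

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone  # stored by reference, not copied
        self.head = head

    def forward(self, x):
        return self.head(self.backbone(x))

# One backbone object, two heads: gradients from either loss update it.
shared = nn.Sequential(nn.Linear(8, 256), nn.ReLU())
actor = Net(shared, nn.Linear(256, 4))
critic = Net(shared, nn.Linear(256, 1))
assert actor.backbone is critic.backbone  # identical parameters underneath
```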

rl_algorithms/common/networks/heads.py
Lines changed: 45 additions & 8 deletions

```diff
@@ -8,7 +8,7 @@
 from typing import Callable, Tuple

 import torch
-from torch.distributions import Normal
+from torch.distributions import Categorical, Normal
 import torch.nn as nn
 import torch.nn.functional as F

@@ -127,10 +127,15 @@ def __init__(
         self.log_std_min = log_std_min
         self.log_std_max = log_std_max
         in_size = configs.hidden_sizes[-1]
+        self.fixed_logstd = configs.fixed_logstd

-        # set log_std layer
-        self.log_std_layer = nn.Linear(in_size, configs.output_size)
-        self.log_std_layer = init_fn(self.log_std_layer)
+        # set log_std
+        if self.fixed_logstd:
+            log_std = -0.5 * torch.ones(self.output_size, dtype=torch.float32)
+            self.log_std = torch.nn.Parameter(log_std)
+        else:
+            self.log_std_layer = nn.Linear(in_size, configs.output_size)
+            self.log_std_layer = init_fn(self.log_std_layer)

         # set mean layer
         self.mu_layer = nn.Linear(in_size, configs.output_size)
@@ -144,10 +149,13 @@ def get_dist_params(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         mu = self.mu_activation(self.mu_layer(hidden))

         # get std
-        log_std = torch.tanh(self.log_std_layer(hidden))
-        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (
-            log_std + 1
-        )
+        if self.fixed_logstd:
+            log_std = self.log_std
+        else:
+            log_std = torch.tanh(self.log_std_layer(hidden))
+            log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (
+                log_std + 1
+            )
         std = torch.exp(log_std)

         return mu, log_std, std
```
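With `fixed_logstd: True`, the log standard deviation becomes a single learnable, state-independent parameter vector (initialized at -0.5) instead of a tanh-squashed linear output of the hidden features. A small sketch of what the fixed branch produces downstream, using only `torch.distributions` (shapes and values here are illustrative):

```python
import torch
from torch.distributions import Normal

output_size = 2  # e.g., action dimension
log_std = torch.nn.Parameter(-0.5 * torch.ones(output_size))  # as in the diff above

mu = torch.zeros(3, output_size)       # batch of 3 action means from mu_layer
dist = Normal(mu, log_std.exp())       # state-independent std broadcasts over the batch
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1, keepdim=True)  # per-sample log-probability
```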
```diff
@@ -190,3 +198,32 @@ def forward(
         log_prob = log_prob.sum(-1, keepdim=True)

         return action, log_prob, z, mu, std
+
+
+# TODO: Remove it when upgrade torch>=1.7
+# pylint: disable=abstract-method
+@HEADS.register_module
+class CategoricalDist(MLP):
+    """Multilayer perceptron with Categorical distribution output."""
+
+    def __init__(
+        self,
+        configs: ConfigDict,
+        hidden_activation: Callable = F.relu,
+    ):
+        """Initialize."""
+        super().__init__(
+            configs=configs,
+            hidden_activation=hidden_activation,
+            use_output_layer=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """Forward method implementation."""
+        ac_logits = super().forward(x)
+
+        # get categorical distribution and action
+        dist = Categorical(logits=ac_logits)
+        action = dist.sample()
+
+        return action, dist
```
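The new `CategoricalDist` head returns both a sampled action and the distribution object, so the caller can compute the log-probabilities and entropy that PPO's loss needs. The `torch.distributions.Categorical` part works like this (shapes are illustrative):

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 6)        # e.g., batch of 4 states, 6 discrete actions
dist = Categorical(logits=logits)
action = dist.sample()            # shape (4,): one action index per state
log_prob = dist.log_prob(action)  # feeds the PPO probability ratio
entropy = dist.entropy()          # weighted by w_entropy in the loss
```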
