
Commit 8cc7d05

[IBR-2068] Add ppo with discrete action
1 parent d05cc0a commit 8cc7d05

5 files changed: 106 additions, 4 deletions

configs/lunarlander_v2/ppo.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
type: "PPOAgent"
hyper_params:
  gamma: 0.99
  tau: 0.95
  batch_size: 32
  max_epsilon: 0.2
  min_epsilon: 0.2
  epsilon_decay_period: 1500
  w_value: 1.0
  w_entropy: 0.001
  gradient_clip_ac: 0.5
  gradient_clip_cr: 1.0
  epoch: 16
  rollout_len: 256
  n_workers: 12
  use_clipped_value_loss: False
  standardize_advantage: True

learner_cfg:
  type: "PPOLearner"
  backbone:
    actor:
    critic:
    shared_actor_critic:
  head:
    actor:
      type: "CategoricalDist"
      configs:
        hidden_sizes: [256, 256]
        output_activation: "identity"
    critic:
      type: "MLP"
      configs:
        hidden_sizes: [256, 256]
        output_size: 1
        output_activation: "identity"
  optim_cfg:
    lr_actor: 0.0003
    lr_critic: 0.001
    weight_decay: 0.0
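The hyper_params above feed PPO's clipped surrogate objective: max_epsilon and min_epsilon bound the policy-ratio clipping range (both 0.2 here, so the range stays fixed despite epsilon_decay_period), w_value weights the critic loss, and w_entropy weights the entropy bonus. A minimal sketch of how such a loss is typically assembled; the function and variable names are illustrative, not this repository's exact code:

import torch

def ppo_loss(log_prob, old_log_prob, advantage, value, ret, entropy,
             epsilon=0.2, w_value=1.0, w_entropy=0.001):
    """Clipped surrogate policy loss plus weighted value and entropy terms."""
    ratio = (log_prob - old_log_prob).exp()
    surr = ratio * advantage
    surr_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    policy_loss = -torch.min(surr, surr_clipped).mean()
    value_loss = (ret - value).pow(2).mean()
    # total objective: policy loss + w_value * critic loss - w_entropy * entropy bonus
    return policy_loss + w_value * value_loss - w_entropy * entropy.mean()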

configs/pong_no_frameskip_v4/ppo.yaml

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
type: "PPOAgent"
hyper_params:
  gamma: 0.99
  tau: 0.95
  batch_size: 32
  max_epsilon: 0.2
  min_epsilon: 0.2
  epsilon_decay_period: 1500
  w_value: 1.0
  w_entropy: 0.001
  gradient_clip_ac: 0.5
  gradient_clip_cr: 1.0
  epoch: 16
  rollout_len: 256
  n_workers: 4
  use_clipped_value_loss: False
  standardize_advantage: True

learner_cfg:
  type: "PPOLearner"
  backbone:
    actor:
    critic:
    shared_actor_critic:
      type: "CNN"
      configs:
        input_sizes: [4, 32, 64]
        output_sizes: [32, 64, 64]
        kernel_sizes: [8, 4, 3]
        strides: [4, 2, 1]
        paddings: [1, 0, 0]
  head:
    actor:
      type: "CategoricalDist"
      configs:
        hidden_sizes: [512]
        output_activation: "identity"
    critic:
      type: "MLP"
      configs:
        hidden_sizes: [512]
        output_size: 1
        output_activation: "identity"
  optim_cfg:
    lr_actor: 0.0003
    lr_critic: 0.001
    weight_decay: 0.0
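The shared CNN backbone's kernel_sizes, strides, and paddings determine how many features reach the 512-unit actor and critic heads. A rough shape check, assuming the usual 84x84 Atari frames with 4 stacked channels (the frame size is an assumption, not stated in this commit):

def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard conv output-size formula for square inputs."""
    return (size + 2 * padding - kernel) // stride + 1

size = 84  # assumed input resolution after Atari preprocessing
for kernel, stride, padding in [(8, 4, 1), (4, 2, 0), (3, 1, 0)]:
    size = conv_out(size, kernel, stride, padding)

print(size, size * size * 64)  # 7 -> 7*7*64 = 3136 features feeding the heads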

rl_algorithms/common/networks/heads.py

Lines changed: 0 additions & 1 deletion
@@ -221,7 +221,6 @@ def __init__(
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Forward method implementation."""
        ac_logits = super().forward(x)
-       # ac_probs = F.softmax(ac_logits, dim=-1)

        # get categorical distribution and action
        dist = Categorical(logits=ac_logits)
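The deleted softmax line was redundant: torch.distributions.Categorical normalizes logits internally, so constructing the distribution straight from ac_logits is equivalent. A quick check:

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

ac_logits = torch.tensor([1.0, 2.0, 0.5])
dist_from_logits = Categorical(logits=ac_logits)
dist_from_probs = Categorical(probs=F.softmax(ac_logits, dim=-1))

# both constructions yield the same action probabilities
assert torch.allclose(dist_from_logits.probs, dist_from_probs.probs)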

rl_algorithms/ppo/agent.py

Lines changed: 7 additions & 2 deletions
@@ -123,7 +123,10 @@ def __init__(
        self.learner = build_learner(self.learner_cfg, build_args)

    def make_parallel_env(self, max_episode_steps, n_workers):
-       env_gen = env_generator(self.env.spec.id, max_episode_steps)
+       if "env_generator" in self.env_info.keys():
+           env_gen = self.env_info.env_generator
+       else:
+           env_gen = env_generator(self.env.spec.id, max_episode_steps)
        env_multi = make_envs(env_gen, n_envs=n_workers)
        return env_multi
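With this change make_parallel_env prefers an env_generator passed in through env_info (needed for Atari environments that carry a wrapper stack) and only falls back to rebuilding envs from the registered spec id. A hedged sketch of the rank-indexed thunk pattern it relies on, assuming the classic gym API used elsewhere in this commit; make_envs itself is left to the surrounding codebase:

import gym

def default_env_generator(env_id: str):
    """Fallback generator: one independently seeded env per worker rank."""
    def _thunk(rank: int):
        env = gym.make(env_id)
        env.seed(777 + rank + 1)  # per-worker seeding, mirroring run_pong below
        return env
    return _thunk

# the parallel-env helper would call the thunk once per worker, e.g.:
env_gen = default_env_generator("LunarLander-v2")
workers = [env_gen(rank) for rank in range(4)]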

@@ -135,7 +138,9 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
        log_prob = dist.log_prob(selected_action)

        if self.is_test:
-           selected_action = dist.mean
+           selected_action = (
+               dist.logits.argmax() if self.is_discrete else dist.mean
+           )

        else:
            _selected_action = (
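At evaluation time the agent now takes the greedy discrete action (argmax over the categorical logits) instead of dist.mean, which is only meaningful for continuous distributions. A small sketch of the two branches, with an is_discrete flag standing in for the agent attribute:

import torch
from torch.distributions import Categorical, Normal

def greedy_action(dist, is_discrete: bool) -> torch.Tensor:
    """Deterministic test-time action: argmax for discrete, mean for continuous."""
    return dist.logits.argmax() if is_discrete else dist.mean

print(greedy_action(Categorical(logits=torch.tensor([0.1, 2.0, -1.0])), True))  # tensor(1)
print(greedy_action(Normal(torch.tensor([0.3]), torch.tensor([1.0])), False))   # tensor([0.3000])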

run_pong_no_frameskip_v4.py

Lines changed: 12 additions & 1 deletion
@@ -84,15 +84,25 @@ def parse_args() -> argparse.Namespace:
    return parser.parse_args()


+def env_generator(env_name, max_episode_steps, frame_stack):
+    def _thunk(rank: int):
+        env = atari_env_generator(env_name, max_episode_steps, frame_stack=frame_stack)
+        env.seed(777 + rank + 1)
+        return env
+
+    return _thunk
+
+
def main():
    """Main."""
    args = parse_args()

    # env initialization
    env_name = "PongNoFrameskip-v4"
-   env = atari_env_generator(
+   env_gen = env_generator(
        env_name, args.max_episode_steps, frame_stack=args.framestack
    )
+   env = env_gen(0)

    # set a random seed
    common_utils.set_random_seed(args.seed, env)
@@ -112,6 +122,7 @@ def main():
        observation_space=env.observation_space,
        action_space=env.action_space,
        is_atari=True,
+       env_generator=env_gen,
    )
    log_cfg = dict(agent=cfg.agent.type, curr_time=curr_time, cfg_path=args.cfg_path)
    build_args = dict(
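Passing env_generator=env_gen through env_info is what lets PPOAgent.make_parallel_env (see the agent.py change above) rebuild properly wrapped, per-rank-seeded Atari workers instead of calling gym.make on the bare env id. A hedged sketch of how the agent side might consume it; the helper name is illustrative and the real parallelization goes through make_envs:

def make_parallel_envs_sketch(env_generator_thunk, n_workers: int):
    """Build one wrapped, seeded env per worker by calling the thunk with its rank."""
    return [env_generator_thunk(rank) for rank in range(n_workers)]

# e.g. inside the agent: make_parallel_envs_sketch(self.env_info.env_generator, 4)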
