5 files changed: +106 -4 lines changed

New file 1 of 5 — PPO config, MLP actor/critic heads (n_workers: 12):
+type: "PPOAgent"
+hyper_params:
+  gamma: 0.99
+  tau: 0.95
+  batch_size: 32
+  max_epsilon: 0.2
+  min_epsilon: 0.2
+  epsilon_decay_period: 1500
+  w_value: 1.0
+  w_entropy: 0.001
+  gradient_clip_ac: 0.5
+  gradient_clip_cr: 1.0
+  epoch: 16
+  rollout_len: 256
+  n_workers: 12
+  use_clipped_value_loss: False
+  standardize_advantage: True
+
+learner_cfg:
+  type: "PPOLearner"
+  backbone:
+    actor:
+    critic:
+    shared_actor_critic:
+  head:
+    actor:
+      type: "CategoricalDist"
+      configs:
+        hidden_sizes: [256, 256]
+        output_activation: "identity"
+    critic:
+      type: "MLP"
+      configs:
+        hidden_sizes: [256, 256]
+        output_size: 1
+        output_activation: "identity"
+  optim_cfg:
+    lr_actor: 0.0003
+    lr_critic: 0.001
+    weight_decay: 0.0
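Note: the config above is plain YAML, so its structure can be inspected with PyYAML. The sketch below is only an assumption about how the file parses (the repo's own config loader is not part of this diff, and the file path is hypothetical); the Atari config that follows differs mainly in n_workers and the shared CNN backbone.

import yaml

# Hypothetical path; the actual filename is not shown in this capture.
with open("ppo_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["type"])                                  # "PPOAgent"
print(cfg["hyper_params"]["rollout_len"])           # 256
print(cfg["learner_cfg"]["head"]["actor"]["type"])  # "CategoricalDist"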
New file 2 of 5 — PPO config for Atari, shared CNN backbone (n_workers: 4):
+type: "PPOAgent"
+hyper_params:
+  gamma: 0.99
+  tau: 0.95
+  batch_size: 32
+  max_epsilon: 0.2
+  min_epsilon: 0.2
+  epsilon_decay_period: 1500
+  w_value: 1.0
+  w_entropy: 0.001
+  gradient_clip_ac: 0.5
+  gradient_clip_cr: 1.0
+  epoch: 16
+  rollout_len: 256
+  n_workers: 4
+  use_clipped_value_loss: False
+  standardize_advantage: True
+
+learner_cfg:
+  type: "PPOLearner"
+  backbone:
+    actor:
+    critic:
+    shared_actor_critic:
+      type: "CNN"
+      configs:
+        input_sizes: [4, 32, 64]
+        output_sizes: [32, 64, 64]
+        kernel_sizes: [8, 4, 3]
+        strides: [4, 2, 1]
+        paddings: [1, 0, 0]
+  head:
+    actor:
+      type: "CategoricalDist"
+      configs:
+        hidden_sizes: [512]
+        output_activation: "identity"
+    critic:
+      type: "MLP"
+      configs:
+        hidden_sizes: [512]
+        output_size: 1
+        output_activation: "identity"
+  optim_cfg:
+    lr_actor: 0.0003
+    lr_critic: 0.001
+    weight_decay: 0.0
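As a sanity check of the shared CNN dimensions: assuming the standard 84x84 Atari preprocessing (the frame size itself is not part of this diff), the kernel/stride/padding settings above reduce 84x84 to 7x7, so the flattened feature vector feeding the [512]-unit heads has 64 * 7 * 7 = 3136 entries.

# Sketch: spatial size after each conv layer of the shared CNN backbone above,
# assuming 84x84 input frames (an assumption, not stated in this diff).
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    return (size + 2 * padding - kernel) // stride + 1

size = 84
for kernel, stride, padding in zip([8, 4, 3], [4, 2, 1], [1, 0, 0]):
    size = conv_out(size, kernel, stride, padding)  # 84 -> 20 -> 9 -> 7

print(64 * size * size)  # 3136 flattened features feeding the [512] heads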
Changed file 3 of 5 — categorical distribution head:
@@ -221,7 +221,6 @@ def __init__(
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """Forward method implementation."""
         ac_logits = super().forward(x)
-        # ac_probs = F.softmax(ac_logits, dim=-1)
 
         # get categorical distribution and action
         dist = Categorical(logits=ac_logits)
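The commented-out softmax is safe to delete because torch.distributions.Categorical normalizes the logits internally. A minimal standalone check:

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

logits = torch.randn(6)
dist = Categorical(logits=logits)
# Categorical applies the normalization itself, so an explicit softmax is redundant.
assert torch.allclose(dist.probs, F.softmax(logits, dim=-1))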
Changed file 4 of 5 — PPO agent:
@@ -123,7 +123,10 @@ def __init__(
         self.learner = build_learner(self.learner_cfg, build_args)
 
     def make_parallel_env(self, max_episode_steps, n_workers):
-        env_gen = env_generator(self.env.spec.id, max_episode_steps)
+        if "env_generator" in self.env_info.keys():
+            env_gen = self.env_info.env_generator
+        else:
+            env_gen = env_generator(self.env.spec.id, max_episode_steps)
         env_multi = make_envs(env_gen, n_envs=n_workers)
         return env_multi
 
@@ -135,7 +138,9 @@ def select_action(self, state: np.ndarray) -> torch.Tensor:
         log_prob = dist.log_prob(selected_action)
 
         if self.is_test:
-            selected_action = dist.mean
+            selected_action = (
+                dist.logits.argmax() if self.is_discrete else dist.mean
+            )
 
         else:
             _selected_action = (
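The select_action change is needed because a Categorical distribution has no usable mean: at test time the greedy action for a discrete policy is the argmax of the logits, while dist.mean stays the deterministic choice for continuous policies. A small illustration (standalone, not the agent's actual code path):

import torch
from torch.distributions import Categorical, Normal

# Discrete policy: the greedy test-time action is the index of the largest logit.
cat = Categorical(logits=torch.tensor([0.1, 2.0, -1.0]))
print(cat.logits.argmax())  # tensor(1); Categorical.mean is not a usable action

# Continuous policy: the deterministic test-time action is the distribution mean.
normal = Normal(loc=torch.tensor([0.3]), scale=torch.tensor([0.5]))
print(normal.mean)          # tensor([0.3000])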
Changed file 5 of 5 — Atari run script:
@@ -84,15 +84,25 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def env_generator(env_name, max_episode_steps, frame_stack):
+    def _thunk(rank: int):
+        env = atari_env_generator(env_name, max_episode_steps, frame_stack=frame_stack)
+        env.seed(777 + rank + 1)
+        return env
+
+    return _thunk
+
+
 def main():
     """Main."""
     args = parse_args()
 
     # env initialization
     env_name = "PongNoFrameskip-v4"
-    env = atari_env_generator(
+    env_gen = env_generator(
         env_name, args.max_episode_steps, frame_stack=args.framestack
     )
+    env = env_gen(0)
 
     # set a random seed
     common_utils.set_random_seed(args.seed, env)
 
@@ -112,6 +122,7 @@ def main():
         observation_space=env.observation_space,
         action_space=env.action_space,
         is_atari=True,
+        env_generator=env_gen,
     )
     log_cfg = dict(agent=cfg.agent.type, curr_time=curr_time, cfg_path=args.cfg_path)
     build_args = dict(
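The new env_generator closure lets each parallel worker build its own environment with a distinct, reproducible seed (777 + rank + 1). Below is a minimal stand-in using plain gym, since atari_env_generator is repo-specific and not shown in this diff:

import gym

# Stand-in for the per-rank thunk pattern above; gym.make replaces the repo's
# atari_env_generator purely for illustration.
def env_generator(env_name):
    def _thunk(rank: int):
        env = gym.make(env_name)
        env.seed(777 + rank + 1)  # distinct, reproducible seed per worker
        return env
    return _thunk

env_gen = env_generator("PongNoFrameskip-v4")
envs = [env_gen(rank) for rank in range(4)]  # e.g. n_workers: 4 in the Atari config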