@@ -15,28 +15,30 @@ using namespace cpprl;
1515
1616// Algorithm hyperparameters
1717const std::string algorithm = " PPO" ;
18- const int batch_size = 512 ;
18+ const int batch_size = 256 ;
1919const float clip_param = 0.2 ;
2020const float discount_factor = 0.99 ;
21- const float entropy_coef = 1e-5 ;
21+ const float entropy_coef = 0.0 ;
2222const float gae = 0.95 ;
2323const float learning_rate = 3e-4 ;
2424const int log_interval = 1 ;
25- const int num_epoch = 3 ;
26- const int num_mini_batch = 32 ;
25+ const int max_frames = 1e+7 ;
26+ const int num_epoch = 10 ;
27+ const int num_mini_batch = 8 ;
2728const int reward_average_window_size = 10 ;
2829const bool use_gae = true ;
30+ const bool use_lr_decay = true ;
2931const float value_loss_coef = 0.5 ;
3032
3133// Environment hyperparameters
3234const float env_gamma = discount_factor; // Set to -1 to disable
33- const std::string env_name = " BipedalWalker -v2" ;
35+ const std::string env_name = " BipedalWalkerHardcore -v2" ;
3436const int num_envs = 8 ;
35- const float render_reward_threshold = 300 ;
37+ const float render_reward_threshold = 250 ;
3638
3739// Model hyperparameters
3840const int hidden_size = 64 ;
39- const bool recurrent = false ;
41+ const bool recurrent = true ;
4042const bool use_cuda = false ;
4143
4244std::vector<float > flatten_vector (std::vector<float > const &input)
@@ -143,7 +145,8 @@ int main(int argc, char *argv[])
143145
144146 auto start_time = std::chrono::high_resolution_clock::now ();
145147
146- for (int update = 0 ; update < 100000 ; ++update)
148+ int num_updates = max_frames / (batch_size * num_envs);
149+ for (int update = 0 ; update < num_updates; ++update)
147150 {
148151 for (int step = 0 ; step < batch_size; ++step)
149152 {
@@ -227,7 +230,16 @@ int main(int argc, char *argv[])
227230 }
228231 storage.compute_returns (next_value, use_gae, discount_factor, gae);
229232
230- auto update_data = algo->update (storage);
233+ float decay_level;
234+ if (use_lr_decay)
235+ {
236+ decay_level = 1 . - static_cast <float >(update) / num_updates;
237+ }
238+ else
239+ {
240+ decay_level = 1 ;
241+ }
242+ auto update_data = algo->update (storage, decay_level);
231243 storage.after_update ();
232244
233245 if (update % log_interval == 0 && update > 0 )
@@ -237,7 +249,7 @@ int main(int argc, char *argv[])
237249 auto run_time_secs = std::chrono::duration_cast<std::chrono::seconds>(run_time);
238250 auto fps = total_steps / (run_time_secs.count () + 1e-9 );
239251 spdlog::info (" ---" );
240- spdlog::info (" Update: {}" , update);
252+ spdlog::info (" Update: {}/{} " , update, num_updates );
241253 spdlog::info (" Total frames: {}" , total_steps);
242254 spdlog::info (" FPS: {}" , fps);
243255 for (const auto &datum : update_data)
0 commit comments