This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 82f7395
Author: Isaac Poulton
Implement learning rate decay (#9)
1 parent 62c9176

6 files changed: 43 additions, 21 deletions

example/gym_client.cpp

Lines changed: 22 additions & 10 deletions
@@ -15,28 +15,30 @@ using namespace cpprl;
 
 // Algorithm hyperparameters
 const std::string algorithm = "PPO";
-const int batch_size = 512;
+const int batch_size = 256;
 const float clip_param = 0.2;
 const float discount_factor = 0.99;
-const float entropy_coef = 1e-5;
+const float entropy_coef = 0.0;
 const float gae = 0.95;
 const float learning_rate = 3e-4;
 const int log_interval = 1;
-const int num_epoch = 3;
-const int num_mini_batch = 32;
+const int max_frames = 1e+7;
+const int num_epoch = 10;
+const int num_mini_batch = 8;
 const int reward_average_window_size = 10;
 const bool use_gae = true;
+const bool use_lr_decay = true;
 const float value_loss_coef = 0.5;
 
 // Environment hyperparameters
 const float env_gamma = discount_factor; // Set to -1 to disable
-const std::string env_name = "BipedalWalker-v2";
+const std::string env_name = "BipedalWalkerHardcore-v2";
 const int num_envs = 8;
-const float render_reward_threshold = 300;
+const float render_reward_threshold = 250;
 
 // Model hyperparameters
 const int hidden_size = 64;
-const bool recurrent = false;
+const bool recurrent = true;
 const bool use_cuda = false;
 
 std::vector<float> flatten_vector(std::vector<float> const &input)

@@ -143,7 +145,8 @@ int main(int argc, char *argv[])
 
     auto start_time = std::chrono::high_resolution_clock::now();
 
-    for (int update = 0; update < 100000; ++update)
+    int num_updates = max_frames / (batch_size * num_envs);
+    for (int update = 0; update < num_updates; ++update)
     {
         for (int step = 0; step < batch_size; ++step)
         {

@@ -227,7 +230,16 @@ int main(int argc, char *argv[])
         }
         storage.compute_returns(next_value, use_gae, discount_factor, gae);
 
-        auto update_data = algo->update(storage);
+        float decay_level;
+        if (use_lr_decay)
+        {
+            decay_level = 1. - static_cast<float>(update) / num_updates;
+        }
+        else
+        {
+            decay_level = 1;
+        }
+        auto update_data = algo->update(storage, decay_level);
         storage.after_update();
 
         if (update % log_interval == 0 && update > 0)

@@ -237,7 +249,7 @@ int main(int argc, char *argv[])
             auto run_time_secs = std::chrono::duration_cast<std::chrono::seconds>(run_time);
             auto fps = total_steps / (run_time_secs.count() + 1e-9);
             spdlog::info("---");
-            spdlog::info("Update: {}", update);
+            spdlog::info("Update: {}/{}", update, num_updates);
             spdlog::info("Total frames: {}", total_steps);
             spdlog::info("FPS: {}", fps);
             for (const auto &datum : update_data)
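For orientation, the schedule wired in above is a plain linear decay over the whole run: num_updates is derived from the new max_frames constant, and decay_level falls from 1 towards 0 as training progresses. A minimal standalone sketch of the same arithmetic (illustrative only, not part of the commit):

// decay_schedule_sketch.cpp -- mirrors the constants and the schedule in gym_client.cpp.
#include <cstdio>

int main()
{
    const int max_frames = 1e+7; // same constants as gym_client.cpp
    const int batch_size = 256;
    const int num_envs = 8;
    const int num_updates = max_frames / (batch_size * num_envs); // 4882 updates

    // decay_level starts at 1.0 for the first update and approaches 0.0 at the last one.
    for (int update : {0, num_updates / 4, num_updates / 2, num_updates - 1})
    {
        float decay_level = 1.f - static_cast<float>(update) / num_updates;
        std::printf("update %4d/%d -> decay_level %.3f\n", update, num_updates, decay_level);
    }
    return 0;
}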

include/cpprl/algorithms/a2c.h

Lines changed: 3 additions & 3 deletions
@@ -16,8 +16,8 @@ class A2C : public Algorithm
 {
   private:
     Policy &policy;
-    float value_loss_coef, entropy_coef, max_grad_norm;
-    std::unique_ptr<torch::optim::Optimizer> optimizer;
+    float value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate;
+    std::unique_ptr<torch::optim::RMSprop> optimizer;
 
   public:
     A2C(Policy &policy,

@@ -28,6 +28,6 @@ class A2C : public Algorithm
         float alpha = 0.99,
         float max_grad_norm = 0.5);
 
-    std::vector<UpdateDatum> update(RolloutStorage &rollouts);
+    std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1);
 };
 }
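Note on the member-type change: the decayed learning rate is written through the optimizer's options struct in the .cpp files below, and in the libtorch release this project targets that struct is only reachable through the concrete optimizer classes. That is why the member goes from the abstract torch::optim::Optimizer to torch::optim::RMSprop here, and to torch::optim::Adam in ppo.h further down.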

include/cpprl/algorithms/algorithm.h

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ class Algorithm
   public:
     virtual ~Algorithm() = 0;
 
-    virtual std::vector<UpdateDatum> update(RolloutStorage &rollouts) = 0;
+    virtual std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1) = 0;
 };
 
 inline Algorithm::~Algorithm() {}
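The decay_level = 1 default is repeated on this virtual declaration and on the overrides in a2c.h and ppo.h. That repetition matters in C++: default arguments are bound to the static type of the call expression rather than resolved through the vtable, so dropping the default from either side would change what a call without the argument means. A small illustrative example (not project code):

#include <iostream>

// Default arguments on virtual functions come from the static type the call is
// made through, while the function body is still chosen dynamically.
struct Base
{
    virtual ~Base() = default;
    virtual void update(int decay = 1) { std::cout << "Base " << decay << "\n"; }
};

struct Derived : Base
{
    void update(int decay = 2) override { std::cout << "Derived " << decay << "\n"; }
};

int main()
{
    Derived d;
    Base &b = d;
    b.update(); // prints "Derived 1": Derived's body, Base's default argument
    d.update(); // prints "Derived 2": Derived's body, Derived's default argument
    return 0;
}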

include/cpprl/algorithms/ppo.h

Lines changed: 3 additions & 3 deletions
@@ -16,9 +16,9 @@ class PPO : public Algorithm
 {
   private:
     Policy &policy;
-    float clip_param, value_loss_coef, entropy_coef, max_grad_norm;
+    float value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate, original_clip_param;
     int num_epoch, num_mini_batch;
-    std::unique_ptr<torch::optim::Optimizer> optimizer;
+    std::unique_ptr<torch::optim::Adam> optimizer;
 
   public:
     PPO(Policy &policy,

@@ -31,6 +31,6 @@ class PPO : public Algorithm
         float epsilon = 1e-8,
         float max_grad_norm = 0.5);
 
-    std::vector<UpdateDatum> update(RolloutStorage &rollouts);
+    std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1);
 };
 }

src/algorithms/a2c.cpp

Lines changed: 7 additions & 2 deletions
@@ -24,14 +24,18 @@ A2C::A2C(Policy &policy,
       value_loss_coef(value_loss_coef),
       entropy_coef(entropy_coef),
       max_grad_norm(max_grad_norm),
+      original_learning_rate(learning_rate),
       optimizer(std::make_unique<torch::optim::RMSprop>(
           policy->parameters(),
           torch::optim::RMSpropOptions(learning_rate)
               .eps(epsilon)
               .alpha(alpha))) {}
 
-std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)
+std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts, float decay_level)
 {
+    // Decay learning rate
+    optimizer->options.learning_rate_ = original_learning_rate * decay_level;
+
     // Prep work
     auto full_obs_shape = rollouts.get_observations().sizes();
     std::vector<int64_t> obs_shape(full_obs_shape.begin() + 2,

@@ -79,9 +83,9 @@ std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)
 
 TEST_CASE("A2C")
 {
-    torch::manual_seed(0);
     SUBCASE("update() learns basic pattern")
     {
+        torch::manual_seed(0);
         auto base = std::make_shared<MlpBase>(1, false, 5);
         ActionSpace space{"Discrete", {2}};
         Policy policy(space, base);

@@ -151,6 +155,7 @@ TEST_CASE("A2C")
 
     SUBCASE("update() learns basic game")
     {
+        torch::manual_seed(0);
        auto base = std::make_shared<MlpBase>(1, false, 5);
         ActionSpace space{"Discrete", {2}};
         Policy policy(space, base);
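To see the A2C change in isolation: every call to update() now overwrites the optimizer's learning rate before any gradient work is done. A standalone sketch of the same idea (assuming the older libtorch options API this project builds against, where RMSpropOptions exposes a public learning_rate_ field; newer libtorch releases route this through param_groups() instead):

#include <torch/torch.h>
#include <cstdio>

int main()
{
    torch::nn::Linear net(4, 2);
    const double original_learning_rate = 3e-4;
    torch::optim::RMSprop optimizer(net->parameters(),
                                    torch::optim::RMSpropOptions(original_learning_rate));

    const int num_updates = 1000;
    for (int update = 0; update < num_updates; ++update)
    {
        // Same linear decay applied by A2C::update() above.
        float decay_level = 1.f - static_cast<float>(update) / num_updates;
        optimizer.options.learning_rate_ = original_learning_rate * decay_level;

        // ... compute the A2C losses here, then:
        // optimizer.zero_grad(); loss.backward(); optimizer.step();
    }
    std::printf("final learning rate: %g\n", optimizer.options.learning_rate_);
    return 0;
}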

src/algorithms/ppo.cpp

Lines changed: 7 additions & 2 deletions
@@ -26,19 +26,24 @@ PPO::PPO(Policy &policy,
          float epsilon,
          float max_grad_norm)
     : policy(policy),
-      clip_param(clip_param),
       value_loss_coef(value_loss_coef),
       entropy_coef(entropy_coef),
       max_grad_norm(max_grad_norm),
+      original_learning_rate(learning_rate),
+      original_clip_param(clip_param),
       num_epoch(num_epoch),
       num_mini_batch(num_mini_batch),
       optimizer(std::make_unique<torch::optim::Adam>(
           policy->parameters(),
           torch::optim::AdamOptions(learning_rate)
               .eps(epsilon))) {}
 
-std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts)
+std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level)
 {
+    // Decay lr and clip parameter
+    float clip_param = original_clip_param * decay_level;
+    optimizer->options.learning_rate_ = original_learning_rate * decay_level;
+
     // Calculate advantages
     auto returns = rollouts.get_returns();
     auto value_preds = rollouts.get_value_predictions();
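For PPO the decay touches two things at once: the Adam learning rate and the clip range. The local clip_param computed at the top of update() feeds the usual clipped-surrogate objective, so later updates take smaller optimizer steps and are clamped to a tighter ratio band. A minimal sketch of that objective with a decayed clip range (illustrative, not the project's exact tensor code):

#include <torch/torch.h>
#include <iostream>

int main()
{
    // Late in training: original_clip_param * decay_level with decay_level = 0.5.
    float clip_param = 0.2f * 0.5f;

    // Toy probability ratios and advantages for three actions.
    auto ratio = torch::tensor({0.7f, 1.0f, 1.4f});
    auto advantages = torch::tensor({1.0f, -0.5f, 2.0f});

    // Standard PPO clipped surrogate: take the pessimistic (minimum) of the
    // unclipped and clipped terms, then negate to get a loss to minimise.
    auto surr1 = ratio * advantages;
    auto surr2 = torch::clamp(ratio, 1.0f - clip_param, 1.0f + clip_param) * advantages;
    auto action_loss = -torch::min(surr1, surr2).mean();

    std::cout << "action loss with decayed clip range: " << action_loss.item<float>() << "\n";
    return 0;
}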
