32 changes: 22 additions & 10 deletions example/gym_client.cpp
@@ -15,28 +15,30 @@ using namespace cpprl;

// Algorithm hyperparameters
const std::string algorithm = "PPO";
-const int batch_size = 512;
+const int batch_size = 256;
const float clip_param = 0.2;
const float discount_factor = 0.99;
-const float entropy_coef = 1e-5;
+const float entropy_coef = 0.0;
const float gae = 0.95;
const float learning_rate = 3e-4;
const int log_interval = 1;
-const int num_epoch = 3;
-const int num_mini_batch = 32;
+const int max_frames = 1e+7;
+const int num_epoch = 10;
+const int num_mini_batch = 8;
const int reward_average_window_size = 10;
const bool use_gae = true;
+const bool use_lr_decay = true;
const float value_loss_coef = 0.5;

// Environment hyperparameters
const float env_gamma = discount_factor; // Set to -1 to disable
-const std::string env_name = "BipedalWalker-v2";
+const std::string env_name = "BipedalWalkerHardcore-v2";
const int num_envs = 8;
-const float render_reward_threshold = 300;
+const float render_reward_threshold = 250;

// Model hyperparameters
const int hidden_size = 64;
-const bool recurrent = false;
+const bool recurrent = true;
const bool use_cuda = false;

std::vector<float> flatten_vector(std::vector<float> const &input)
@@ -143,7 +145,8 @@ int main(int argc, char *argv[])

auto start_time = std::chrono::high_resolution_clock::now();

-for (int update = 0; update < 100000; ++update)
+int num_updates = max_frames / (batch_size * num_envs);
+for (int update = 0; update < num_updates; ++update)
{
for (int step = 0; step < batch_size; ++step)
{
Expand Down Expand Up @@ -227,7 +230,16 @@ int main(int argc, char *argv[])
}
storage.compute_returns(next_value, use_gae, discount_factor, gae);

-auto update_data = algo->update(storage);
+float decay_level;
+if (use_lr_decay)
+{
+    decay_level = 1. - static_cast<float>(update) / num_updates;
+}
+else
+{
+    decay_level = 1;
+}
+auto update_data = algo->update(storage, decay_level);
storage.after_update();

if (update % log_interval == 0 && update > 0)
@@ -237,7 +249,7 @@
auto run_time_secs = std::chrono::duration_cast<std::chrono::seconds>(run_time);
auto fps = total_steps / (run_time_secs.count() + 1e-9);
spdlog::info("---");
-spdlog::info("Update: {}", update);
+spdlog::info("Update: {}/{}", update, num_updates);
spdlog::info("Total frames: {}", total_steps);
spdlog::info("FPS: {}", fps);
for (const auto &datum : update_data)
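To make the new schedule concrete, here is a minimal, self-contained sketch of the update-count and decay_level arithmetic the revised loop performs. The constants mirror the hyperparameters above; the linear_decay helper is an illustration, not a function from this PR.

#include <iostream>

// Illustrative helper (not part of the PR): the linear schedule the loop
// computes inline, falling from 1 at the first update towards 0 at the last.
float linear_decay(int update, int num_updates)
{
    return 1.f - static_cast<float>(update) / num_updates;
}

int main()
{
    const int max_frames = 1e+7;
    const int batch_size = 256;
    const int num_envs = 8;
    const int num_updates = max_frames / (batch_size * num_envs); // 4882 updates

    std::cout << linear_decay(0, num_updates) << "\n";               // 1
    std::cout << linear_decay(num_updates / 2, num_updates) << "\n"; // ~0.5
    std::cout << linear_decay(num_updates - 1, num_updates) << "\n"; // ~0.0002
}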
6 changes: 3 additions & 3 deletions include/cpprl/algorithms/a2c.h
@@ -16,8 +16,8 @@ class A2C : public Algorithm
{
private:
Policy &policy;
-float value_loss_coef, entropy_coef, max_grad_norm;
-std::unique_ptr<torch::optim::Optimizer> optimizer;
+float value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate;
+std::unique_ptr<torch::optim::RMSprop> optimizer;

public:
A2C(Policy &policy,
@@ -28,6 +28,6 @@
float alpha = 0.99,
float max_grad_norm = 0.5);

-std::vector<UpdateDatum> update(RolloutStorage &rollouts);
+std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1);
};
}
2 changes: 1 addition & 1 deletion include/cpprl/algorithms/algorithm.h
@@ -18,7 +18,7 @@ class Algorithm
public:
virtual ~Algorithm() = 0;

-virtual std::vector<UpdateDatum> update(RolloutStorage &rollouts) = 0;
+virtual std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1) = 0;
};

inline Algorithm::~Algorithm() {}
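One editorial note on this interface change (not something discussed in the PR): default arguments on virtual functions are bound by the static type used at the call site, which is presumably why the decay_level = 1 default is repeated in the A2C and PPO headers rather than only here. A small sketch with a hypothetical ConstantLR subclass; the include path and using-directive are assumptions based on the repository layout, and UpdateDatum/RolloutStorage are assumed to be visible through that header.

#include <cpprl/algorithms/algorithm.h>

#include <vector>

using namespace cpprl;

// Hypothetical subclass, for illustration only.
class ConstantLR : public Algorithm
{
public:
    std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1) override
    {
        // A no-decay algorithm can simply ignore decay_level.
        (void)rollouts;
        (void)decay_level;
        return {};
    }
};

// Usage sketch:
//   std::unique_ptr<Algorithm> algo = std::make_unique<ConstantLR>();
//   algo->update(storage);              // decay_level defaults to 1
//   algo->update(storage, decay_level); // explicit schedule, as in gym_client.cpp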
6 changes: 3 additions & 3 deletions include/cpprl/algorithms/ppo.h
@@ -16,9 +16,9 @@ class PPO : public Algorithm
{
private:
Policy &policy;
-float clip_param, value_loss_coef, entropy_coef, max_grad_norm;
+float value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate, original_clip_param;
int num_epoch, num_mini_batch;
-std::unique_ptr<torch::optim::Optimizer> optimizer;
+std::unique_ptr<torch::optim::Adam> optimizer;

public:
PPO(Policy &policy,
@@ -31,6 +31,6 @@
float epsilon = 1e-8,
float max_grad_norm = 0.5);

-std::vector<UpdateDatum> update(RolloutStorage &rollouts);
+std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1);
};
}
9 changes: 7 additions & 2 deletions src/algorithms/a2c.cpp
@@ -24,14 +24,18 @@ A2C::A2C(Policy &policy,
value_loss_coef(value_loss_coef),
entropy_coef(entropy_coef),
max_grad_norm(max_grad_norm),
+original_learning_rate(learning_rate),
optimizer(std::make_unique<torch::optim::RMSprop>(
policy->parameters(),
torch::optim::RMSpropOptions(learning_rate)
.eps(epsilon)
.alpha(alpha))) {}

-std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)
+std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts, float decay_level)
{
+// Decay learning rate
+optimizer->options.learning_rate_ = original_learning_rate * decay_level;

// Prep work
auto full_obs_shape = rollouts.get_observations().sizes();
std::vector<int64_t> obs_shape(full_obs_shape.begin() + 2,
@@ -79,9 +83,9 @@ std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)

TEST_CASE("A2C")
{
-torch::manual_seed(0);
SUBCASE("update() learns basic pattern")
{
+torch::manual_seed(0);
auto base = std::make_shared<MlpBase>(1, false, 5);
ActionSpace space{"Discrete", {2}};
Policy policy(space, base);
@@ -151,6 +155,7 @@

SUBCASE("update() learns basic game")
{
+torch::manual_seed(0);
auto base = std::make_shared<MlpBase>(1, false, 5);
ActionSpace space{"Discrete", {2}};
Policy policy(space, base);
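The A2C change mirrors the PPO change below: the constructor-time learning rate is cached in original_learning_rate, and each update() rescales the optimizer's rate by decay_level. A minimal standalone sketch of that pattern, assuming the older libtorch API used by this PR, where the optimizer's options struct is a public member (newer libtorch releases expose the learning rate through param_groups() instead); the tiny Linear module stands in for the policy network.

#include <torch/torch.h>

int main()
{
    torch::manual_seed(0);

    // Stand-in for the policy network.
    torch::nn::Linear policy(4, 2);
    const float original_learning_rate = 7e-4;

    torch::optim::RMSprop optimizer(
        policy->parameters(),
        torch::optim::RMSpropOptions(original_learning_rate).eps(1e-8).alpha(0.99));

    const int num_updates = 100;
    for (int update = 0; update < num_updates; ++update)
    {
        float decay_level = 1.f - static_cast<float>(update) / num_updates;

        // Same idea as A2C::update(): overwrite the optimizer's rate with the
        // cached base rate scaled by the schedule. (Older libtorch API, as in
        // this PR; on newer versions set the lr through optimizer.param_groups().)
        optimizer.options.learning_rate_ = original_learning_rate * decay_level;

        // ... compute losses here, then optimizer.zero_grad(), backward(), step() ...
    }
}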
9 changes: 7 additions & 2 deletions src/algorithms/ppo.cpp
@@ -26,19 +26,24 @@ PPO::PPO(Policy &policy,
float epsilon,
float max_grad_norm)
: policy(policy),
-clip_param(clip_param),
value_loss_coef(value_loss_coef),
entropy_coef(entropy_coef),
max_grad_norm(max_grad_norm),
+original_learning_rate(learning_rate),
+original_clip_param(clip_param),
num_epoch(num_epoch),
num_mini_batch(num_mini_batch),
optimizer(std::make_unique<torch::optim::Adam>(
policy->parameters(),
torch::optim::AdamOptions(learning_rate)
.eps(epsilon))) {}

-std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts)
+std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level)
{
+// Decay lr and clip parameter
+float clip_param = original_clip_param * decay_level;
+optimizer->options.learning_rate_ = original_learning_rate * decay_level;

// Calculate advantages
auto returns = rollouts.get_returns();
auto value_preds = rollouts.get_value_predictions();
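On the PPO side the clip range is decayed in lockstep with the learning rate: the clip_param member is removed from the class and replaced inside update() by a local computed from original_clip_param * decay_level, so the surrogate objective is clipped progressively more tightly as training proceeds. For illustration, a sketch of the standard PPO clipped surrogate using such a decayed clip range; the function name and tensor arguments are placeholders, not code from this diff.

#include <torch/torch.h>

#include <iostream>

// Standard PPO clipped surrogate loss with a decayed clip range.
// `ratio` stands for exp(new_log_probs - old_log_probs) and `advantages`
// for the (normalized) GAE advantages.
torch::Tensor clipped_surrogate_loss(const torch::Tensor &ratio,
                                     const torch::Tensor &advantages,
                                     float original_clip_param,
                                     float decay_level)
{
    // Mirrors the decay in PPO::update(): a tighter clip later in training.
    float clip_param = original_clip_param * decay_level;

    auto surr1 = ratio * advantages;
    auto surr2 = torch::clamp(ratio, 1.f - clip_param, 1.f + clip_param) * advantages;
    return -torch::min(surr1, surr2).mean();
}

int main()
{
    auto ratio = torch::ones({8, 1}) * 1.3;
    auto advantages = torch::randn({8, 1});
    auto loss = clipped_surrogate_loss(ratio, advantages, 0.2f, 0.5f);
    std::cout << loss << "\n";
}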