This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 9a0cfb0

Author: Omegastick

Add actor loss coefficient hyperparameter to A2C

1 parent f51c086
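In A2C, this coefficient weights the policy-gradient (actor) term relative to the value loss and entropy bonus when the combined objective is formed; a sketch of that combination follows the diffs below.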

File tree

4 files changed: +8 −5 lines

example/gym_client.cpp

Lines changed: 1 addition & 1 deletion

@@ -131,7 +131,7 @@ int main(int argc, char *argv[])
     std::unique_ptr<Algorithm> algo;
     if (algorithm == "A2C")
     {
-        algo = std::make_unique<A2C>(policy, value_loss_coef, entropy_coef, learning_rate);
+        algo = std::make_unique<A2C>(policy, actor_loss_coef, value_loss_coef, entropy_coef, learning_rate);
     }
     else if (algorithm == "PPO")
     {

include/cpprl/algorithms/a2c.h

Lines changed: 2 additions & 1 deletion

@@ -16,11 +16,12 @@ class A2C : public Algorithm
 {
   private:
     Policy &policy;
-    float value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate;
+    float actor_loss_coef, value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate;
     std::unique_ptr<torch::optim::RMSprop> optimizer;

   public:
     A2C(Policy &policy,
+        float actor_loss_coef,
         float value_loss_coef,
         float entropy_coef,
         float learning_rate,

src/algorithms/a2c.cpp

Lines changed: 4 additions & 2 deletions

@@ -14,13 +14,15 @@
 namespace cpprl
 {
 A2C::A2C(Policy &policy,
+         float actor_loss_coef,
          float value_loss_coef,
          float entropy_coef,
          float learning_rate,
          float epsilon,
          float alpha,
          float max_grad_norm)
     : policy(policy),
+      actor_loss_coef(actor_loss_coef),
       value_loss_coef(value_loss_coef),
       entropy_coef(entropy_coef),
       max_grad_norm(max_grad_norm),

@@ -90,7 +92,7 @@ TEST_CASE("A2C")
     ActionSpace space{"Discrete", {2}};
     Policy policy(space, base);
     RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
-    A2C a2c(policy, 0.5, 1e-3, 0.001);
+    A2C a2c(policy, 1, 0.5, 1e-3, 0.001);

     // The reward is the action
     auto pre_game_probs = policy->get_probs(

@@ -160,7 +162,7 @@ TEST_CASE("A2C")
     ActionSpace space{"Discrete", {2}};
     Policy policy(space, base);
     RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
-    A2C a2c(policy, 0.5, 1e-7, 0.0001);
+    A2C a2c(policy, 1, 0.5, 1e-7, 0.0001);

     // The game is: If the action matches the input, give a reward of 1, otherwise -1
     auto pre_game_probs = policy->get_probs(
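
The hunks above only thread the new hyperparameter through the constructor, header, and tests; the change that applies it to the loss is not shown here (a fourth changed file is listed but not included). For context, the conventional way such a coefficient enters an A2C update is as a weight on the policy-gradient term when the three loss components are combined. A minimal sketch under that assumption, using LibTorch; the function and tensor names below are illustrative stand-ins, not taken from this commit:

// Sketch only (not from this commit): how an actor loss coefficient
// conventionally enters the combined A2C objective. All tensor arguments
// are hypothetical stand-ins for quantities computed during a rollout.
#include <torch/torch.h>

torch::Tensor a2c_loss(const torch::Tensor &action_log_probs, // log pi(a|s)
                       const torch::Tensor &advantages,       // A(s, a)
                       const torch::Tensor &values,           // V(s)
                       const torch::Tensor &returns,          // bootstrapped returns
                       const torch::Tensor &dist_entropy,     // policy entropy
                       float actor_loss_coef,
                       float value_loss_coef,
                       float entropy_coef)
{
    // Policy-gradient term: raise log-probs of advantageous actions.
    // Advantages are detached so the critic is not trained through this term.
    auto action_loss = -(advantages.detach() * action_log_probs).mean();
    // Critic term: mean squared error between predicted values and returns.
    auto value_loss = (returns - values).pow(2).mean();
    // Entropy bonus encourages exploration, hence the subtraction.
    return actor_loss_coef * action_loss +
           value_loss_coef * value_loss -
           entropy_coef * dist_entropy.mean();
}

Note that the updated tests pass 1 for actor_loss_coef, which under this formulation leaves the policy-gradient term unscaled and so reproduces the pre-commit behaviour.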
