
Commit 0cf3556

Author: Omegastick
Add KL divergence cap
1 parent 3bbd5be commit 0cf3556
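
The cap added here works on an approximate KL divergence between the policy that collected the rollouts and the policy currently being optimised. As the diff to src/algorithms/ppo.cpp below shows, the estimate is simply the mean difference between the stored and re-evaluated action log-probabilities, and the update loop stops early once it exceeds 1.5 times the configured kl_target (0.01 by default). Written out:

    D_{KL}(\pi_{old} \,\|\, \pi_{new}) \approx \frac{1}{N} \sum_{i=1}^{N} \left[ \log \pi_{old}(a_i \mid s_i) - \log \pi_{new}(a_i \mid s_i) \right]

    \text{stop the current update early if } \hat{D}_{KL} > 1.5 \times \mathtt{kl\_target}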

File tree

  include/cpprl/algorithms/ppo.h
  src/algorithms/ppo.cpp

2 files changed, 37 insertions(+), 8 deletions(-)


include/cpprl/algorithms/ppo.h

Lines changed: 3 additions & 2 deletions
@@ -16,7 +16,7 @@ class PPO : public Algorithm
 {
   private:
     Policy &policy;
-    float actor_loss_coef, value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate, original_clip_param;
+    float actor_loss_coef, value_loss_coef, entropy_coef, max_grad_norm, original_learning_rate, original_clip_param, kl_target;
     int num_epoch, num_mini_batch;
     std::unique_ptr<torch::optim::Adam> optimizer;
 
@@ -30,7 +30,8 @@ class PPO : public Algorithm
         float entropy_coef,
         float learning_rate,
         float epsilon = 1e-8,
-        float max_grad_norm = 0.5);
+        float max_grad_norm = 0.5,
+        float kl_target = 0.01);
 
     std::vector<UpdateDatum> update(RolloutStorage &rollouts, float decay_level = 1);
 };
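
Because kl_target is appended as a trailing argument with a default, existing constructor calls keep compiling and get the 0.01 cap automatically. A minimal sketch of that default-argument behaviour, using a hypothetical free function rather than the real PPO constructor (whose leading parameters are not shown in this hunk):

#include <iostream>

// Hypothetical stand-in for the tail of the PPO constructor's parameter list;
// only the epsilon, max_grad_norm and kl_target defaults mirror the header above.
void configure(float learning_rate,
               float epsilon = 1e-8f,
               float max_grad_norm = 0.5f,
               float kl_target = 0.01f)
{
    std::cout << "lr=" << learning_rate
              << " eps=" << epsilon
              << " max_grad_norm=" << max_grad_norm
              << " kl_target=" << kl_target << '\n';
}

int main()
{
    configure(1e-3f);                     // existing call sites compile unchanged; kl_target defaults to 0.01
    configure(1e-3f, 1e-8f, 0.5f, 0.05f); // explicitly loosening the KL cap
}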

src/algorithms/ppo.cpp

Lines changed: 34 additions & 6 deletions
@@ -24,14 +24,16 @@ PPO::PPO(Policy &policy,
          float entropy_coef,
          float learning_rate,
          float epsilon,
-         float max_grad_norm)
+         float max_grad_norm,
+         float kl_target)
     : policy(policy),
       actor_loss_coef(actor_loss_coef),
       value_loss_coef(value_loss_coef),
       entropy_coef(entropy_coef),
       max_grad_norm(max_grad_norm),
       original_learning_rate(learning_rate),
       original_clip_param(clip_param),
+      kl_target(kl_target),
       num_epoch(num_epoch),
       num_mini_batch(num_mini_batch),
       optimizer(std::make_unique<torch::optim::Adam>(
@@ -57,6 +59,9 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
     float total_value_loss = 0;
     float total_action_loss = 0;
     float total_entropy = 0;
+    float kl_divergence = 0;
+    float kl_early_stopped = -1;
+    int num_updates = 0;
 
     // Epoch loop
     for (int epoch = 0; epoch < num_epoch; ++epoch)
@@ -86,6 +91,17 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
                 mini_batch.masks,
                 mini_batch.actions);
 
+            // Calculate approximate KL divergence for info and early stopping
+            kl_divergence = (mini_batch.action_log_probs - evaluate_result[1])
+                                .mean()
+                                .item()
+                                .toFloat();
+            if (kl_divergence > kl_target * 1.5)
+            {
+                kl_early_stopped = num_updates;
+                goto finish_update;
+            }
+
             // Calculate difference ratio between old and new action probabilites
             auto ratio = torch::exp(evaluate_result[1] -
                                     mini_batch.action_log_probs);
@@ -114,22 +130,34 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
             loss.backward();
             // TODO: Implement gradient norm clipping
             optimizer->step();
+            num_updates++;
 
             total_value_loss += value_loss.item().toFloat();
             total_action_loss += action_loss.item().toFloat();
             total_entropy += evaluate_result[2].item().toFloat();
         }
     }
 
-    auto num_updates = num_epoch * num_mini_batch;
-
+finish_update:
     total_value_loss /= num_updates;
     total_action_loss /= num_updates;
     total_entropy /= num_updates;
 
-    return {{"Value loss", total_value_loss},
-            {"Action loss", total_action_loss},
-            {"Entropy", total_entropy}};
+    if (kl_early_stopped > -1)
+    {
+        return {{"Value loss", total_value_loss},
+                {"Action loss", total_action_loss},
+                {"Entropy", total_entropy},
+                {"KL divergence", kl_divergence},
+                {"KL divergence early stop update", kl_early_stopped}};
+    }
+    else
+    {
+        return {{"Value loss", total_value_loss},
+                {"Action loss", total_action_loss},
+                {"Entropy", total_entropy},
+                {"KL divergence", kl_divergence}};
+    }
 }
 
 TEST_CASE("PPO")
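
For experimenting with the check in isolation, here is a minimal, self-contained libtorch sketch of the same early-stopping rule. The tensor shape and the fake log-probabilities are invented for illustration; only the mean log-probability-difference estimator and the 1.5 x kl_target threshold come from the diff above.

#include <torch/torch.h>
#include <iostream>

int main()
{
    const float kl_target = 0.01f; // same default as the new constructor argument

    // Stand-ins for mini_batch.action_log_probs (rollout policy) and
    // evaluate_result[1] (current policy); random values for illustration only.
    auto old_log_probs = torch::log(torch::rand({64, 1}) * 0.5 + 0.25);
    auto new_log_probs = torch::log(torch::rand({64, 1}) * 0.5 + 0.25);

    // Approximate KL(pi_old || pi_new) as the mean log-probability difference,
    // as the new code in PPO::update does.
    float kl_divergence = (old_log_probs - new_log_probs).mean().item().toFloat();

    // Abandon the remaining mini-batch updates once the policy has drifted
    // more than 1.5x the target away from the policy that collected the rollouts.
    if (kl_divergence > kl_target * 1.5f)
    {
        std::cout << "Early stop: approximate KL " << kl_divergence
                  << " exceeds " << kl_target * 1.5f << '\n';
    }
    else
    {
        std::cout << "Approximate KL " << kl_divergence << " is within the cap\n";
    }
}

This also explains why the fixed num_epoch * num_mini_batch divisor had to go: since the loop can now exit before finishing every mini-batch, the commit counts optimizer steps in num_updates instead. With, say, 4 epochs of 8 mini-batches stopped after 10 updates, dividing by 32 rather than 10 would understate the averaged losses by roughly a factor of three.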
