@@ -24,14 +24,16 @@ PPO::PPO(Policy &policy,
          float entropy_coef,
          float learning_rate,
          float epsilon,
-         float max_grad_norm)
+         float max_grad_norm,
+         float kl_target)
     : policy(policy),
       actor_loss_coef(actor_loss_coef),
       value_loss_coef(value_loss_coef),
       entropy_coef(entropy_coef),
       max_grad_norm(max_grad_norm),
       original_learning_rate(learning_rate),
       original_clip_param(clip_param),
+      kl_target(kl_target),
       num_epoch(num_epoch),
       num_mini_batch(num_mini_batch),
       optimizer(std::make_unique<torch::optim::Adam>(
@@ -57,6 +59,9 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
     float total_value_loss = 0;
     float total_action_loss = 0;
     float total_entropy = 0;
+    float kl_divergence = 0;
+    float kl_early_stopped = -1;
+    int num_updates = 0;
 
     // Epoch loop
     for (int epoch = 0; epoch < num_epoch; ++epoch)
@@ -86,6 +91,17 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
                 mini_batch.masks,
                 mini_batch.actions);
 
+            // Calculate approximate KL divergence for info and early stopping
+            kl_divergence = (mini_batch.action_log_probs - evaluate_result[1])
+                                .mean()
+                                .item()
+                                .toFloat();
+            if (kl_divergence > kl_target * 1.5)
+            {
+                kl_early_stopped = num_updates;
+                goto finish_update;
+            }
+
             // Calculate difference ratio between old and new action probabilities
             auto ratio = torch::exp(evaluate_result[1] -
                                     mini_batch.action_log_probs);
@@ -114,22 +130,34 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
             loss.backward();
             // TODO: Implement gradient norm clipping
             optimizer->step();
+            num_updates++;
 
             total_value_loss += value_loss.item().toFloat();
             total_action_loss += action_loss.item().toFloat();
             total_entropy += evaluate_result[2].item().toFloat();
         }
     }
 
-    auto num_updates = num_epoch * num_mini_batch;
-
+finish_update:
     total_value_loss /= num_updates;
     total_action_loss /= num_updates;
     total_entropy /= num_updates;
 
-    return {{"Value loss", total_value_loss},
-            {"Action loss", total_action_loss},
-            {"Entropy", total_entropy}};
+    if (kl_early_stopped > -1)
+    {
+        return {{"Value loss", total_value_loss},
+                {"Action loss", total_action_loss},
+                {"Entropy", total_entropy},
+                {"KL divergence", kl_divergence},
+                {"KL divergence early stop update", kl_early_stopped}};
+    }
+    else
+    {
+        return {{"Value loss", total_value_loss},
+                {"Action loss", total_action_loss},
+                {"Entropy", total_entropy},
+                {"KL divergence", kl_divergence}};
+    }
 }
 
 TEST_CASE("PPO")