@@ -61,6 +61,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
6161 float total_entropy = 0 ;
6262 float kl_divergence = 0 ;
6363 float kl_early_stopped = -1 ;
64+ float clip_fraction = 0 ;
6465 int num_updates = 0 ;
6566
6667 // Epoch loop
@@ -107,11 +108,19 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
107108 mini_batch.action_log_probs );
108109
109110 // PPO loss formula
110- auto surr_1 = ratio * mini_batch.advantages ;
111+ auto surr_1 = ratio * mini_batch.advantages . mean () ;
111112 auto surr_2 = (torch::clamp (ratio,
112113 1.0 - clip_param,
113114 1.0 + clip_param) *
114- mini_batch.advantages );
115+ mini_batch.advantages )
116+ .mean ();
117+ clip_fraction += (ratio - 1.0 )
118+ .abs ()
119+ .gt (clip_param)
120+ .to (torch::kFloat )
121+ .mean ()
122+ .item ()
123+ .toFloat ();
115124 auto action_loss = -torch::min (surr_1, surr_2).mean ();
116125
117126 // Value loss
@@ -148,11 +157,13 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
148157 total_value_loss /= num_updates;
149158 total_action_loss /= num_updates;
150159 total_entropy /= num_updates;
160+ clip_fraction /= num_updates;
151161
152162 if (kl_early_stopped > -1 )
153163 {
154164 return {{" Value loss" , total_value_loss},
155165 {" Action loss" , total_action_loss},
166+ {" Clip fraction" , clip_fraction},
156167 {" Entropy" , total_entropy},
157168 {" KL divergence" , kl_divergence},
158169 {" KL divergence early stop update" , kl_early_stopped}};
@@ -161,6 +172,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
161172 {
162173 return {{" Value loss" , total_value_loss},
163174 {" Action loss" , total_action_loss},
175+ {" Clip fraction" , clip_fraction},
164176 {" Entropy" , total_entropy},
165177 {" KL divergence" , kl_divergence}};
166178 }
0 commit comments