@@ -238,10 +238,12 @@ def _generate_and_score_completions(
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards

+        grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+        grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0)
+
         if self.scale_rewards in ["group", "none"]:
             # If self.scale_rewards = "none", we'll still log group level std
-            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
+            std_rewards = grouped_std_rewards.clone()
         elif self.scale_rewards == "batch":
             # Compute global std
             std_rewards = rewards.std().expand_as(rewards)
@@ -261,7 +263,7 @@ def _generate_and_score_completions(
         )
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
-        std_rewards = std_rewards[process_slice]
+        grouped_std_rewards = grouped_std_rewards[process_slice]

         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
@@ -316,7 +318,7 @@ def _generate_and_score_completions(
         )
         outputs_after_sampling_buffer = self.update_with_replay_buffer(
             advantages,
-            std_rewards,
+            grouped_std_rewards,
             prompt_ids,
             prompt_mask,
             completion_ids,
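
For context, here is a minimal standalone sketch (not code from this PR; the flat reward layout and the `num_generations` value are assumptions based on the surrounding diff) of what the new `grouped_std_rewards` lines compute: one standard deviation per prompt group, broadcast back so every completion carries its own group's std. Keeping this tensor separate from `std_rewards` presumably lets `update_with_replay_buffer` always receive per-group statistics, even when `self.scale_rewards` is `"batch"` or `"none"` and `std_rewards` holds something else.

```python
import torch

# Assumed setup: 2 prompts x 4 completions each, rewards flattened group by group.
num_generations = 4
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,   # group 1: mixed scores
                        1.0, 1.0, 1.0, 1.0])  # group 2: identical scores

# One (unbiased) std per group of num_generations completions...
grouped_std_rewards = rewards.view(-1, num_generations).std(dim=1)
# ...repeated so each completion is paired with its group's std.
grouped_std_rewards = grouped_std_rewards.repeat_interleave(num_generations, dim=0)

print(grouped_std_rewards)
# tensor([0.5774, 0.5774, 0.5774, 0.5774, 0.0000, 0.0000, 0.0000, 0.0000])
# A zero entry flags a group whose completions all received the same reward.
```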