Commit 37942bc

Buffer samples based on group level stds. (#4492)
1 parent 66cd02a commit 37942bc

File tree

2 files changed: +9 −5 lines changed


tests/experimental/test_grpo_with_replay_buffer_trainer.py

Lines changed: 3 additions & 1 deletion
@@ -250,8 +250,9 @@ def test_update_with_inputs_different_seq_len(self):
 
 
 @pytest.mark.low_priority
+@pytest.mark.parametrize("scale_rewards", ["batch", "group"])
 class TestGRPOWithReplayBufferTrainer(TrlTestCase):
-    def test_training_with_replay_buffer(self):
+    def test_training_with_replay_buffer(self, scale_rewards):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
 
         # Guarantee that some rewards have 0 std
@@ -269,6 +270,7 @@ def custom_reward_func(completions, **kwargs):
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             replay_buffer_size=8,
             report_to="none",
+            scale_rewards=scale_rewards,
         )
         trainer = GRPOWithReplayBufferTrainer(
             model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
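Note on the test change: applying pytest.mark.parametrize at class level runs every test method in the class once per scale_rewards value, and the value is forwarded to the trainer config. A minimal, self-contained sketch of that pattern (plain pytest, no trl dependency; the class and assertion below are illustrative only, not part of this commit):

# sketch.py — illustrative pytest example, assuming only pytest is installed
import pytest

@pytest.mark.parametrize("scale_rewards", ["batch", "group"])
class TestScaleRewardsParametrization:
    def test_value_is_forwarded(self, scale_rewards):
        # In the real test above, this value is passed to the config as
        # scale_rewards=scale_rewards; here we only check it is an expected mode.
        assert scale_rewards in ("batch", "group")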

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 6 additions & 4 deletions
@@ -238,10 +238,12 @@ def _generate_and_score_completions(
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
 
+        grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+        grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0)
+
         if self.scale_rewards in ["group", "none"]:
             # If self.scale_rewards = "none", we'll still log group level std
-            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
+            std_rewards = grouped_std_rewards.clone()
         elif self.scale_rewards == "batch":
             # Compute global std
             std_rewards = rewards.std().expand_as(rewards)
@@ -261,7 +263,7 @@ def _generate_and_score_completions(
         )
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
-        std_rewards = std_rewards[process_slice]
+        grouped_std_rewards = grouped_std_rewards[process_slice]
 
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
@@ -316,7 +318,7 @@ def _generate_and_score_completions(
         )
         outputs_after_sampling_buffer = self.update_with_replay_buffer(
             advantages,
-            std_rewards,
+            grouped_std_rewards,
             prompt_ids,
             prompt_mask,
             completion_ids,
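
Note on the trainer change: the group-level std is now computed unconditionally and is what update_with_replay_buffer receives, while std_rewards (group or batch, per scale_rewards) is still used for advantage scaling. A standalone sketch of the two std computations, assuming PyTorch; shapes and values below are illustrative, not taken from the trainer:

# sketch of the std computations in the diff above (assumes torch is installed)
import torch

num_generations = 4
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0,   # group 0
                        5.0, 5.0, 5.0, 5.0])  # group 1 (zero std)

# Group-level std, broadcast back to one value per completion. This is what the
# replay buffer now receives regardless of the scale_rewards setting.
grouped_std_rewards = rewards.view(-1, num_generations).std(dim=1)
grouped_std_rewards = grouped_std_rewards.repeat_interleave(num_generations, dim=0)

# Batch-level std, used for advantage scaling when scale_rewards == "batch".
batch_std_rewards = rewards.std().expand_as(rewards)

print(grouped_std_rewards)  # roughly tensor([1.29, 1.29, 1.29, 1.29, 0., 0., 0., 0.])
print(batch_std_rewards)    # a single global std repeated for every completion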
