
Commit 206964c

kashif and qgallouedec authored
🎢 [Callbacks] BEMA (#3855)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 39efa8a commit 206964c


6 files changed, +346 -7 lines changed


docs/source/callbacks.md

Lines changed: 5 additions & 1 deletion
@@ -18,4 +18,8 @@
 
 ## MergeModelCallback
 
-[[autodoc]] MergeModelCallback
+[[autodoc]] MergeModelCallback
+
+## BEMACallback
+
+[[autodoc]] BEMACallback

docs/source/paper_index.md

Lines changed: 16 additions & 1 deletion
@@ -24,4 +24,19 @@ training_args = GRPOConfig(
     gradient_accumulation_steps=1,
     steps_per_generation=4,  # partition rollout batch into 4 mini-batches. GSPO paper (v2), section 5.1. Must be 4 times gradient_accumulation_steps
 )
-```
+```
+
+## EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes
+
+**📜 Paper**: https://huggingface.co/papers/2508.00180
+
+Bias-Corrected Exponential Moving Average (BEMA) improves the stability and efficiency of language model fine-tuning by reducing stochasticity and eliminating bias. To use BEMA with SFT as described in the paper, you can use the [`BEMACallback`]:
+
+```python
+from trl import BEMACallback, SFTTrainer
+
+trainer = SFTTrainer(
+    ...,
+    callbacks=[BEMACallback()],
+)
+```
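As a side note, the callback also exposes scheduling parameters exercised by the tests added in this commit, notably `update_freq` and `update_after`. A minimal sketch of how they might be set, assuming the same `SFTTrainer` setup as above (the specific values are illustrative assumptions, not recommendations from the paper or this commit):

```python
from trl import BEMACallback, SFTTrainer

# Hypothetical schedule: skip the first 100 optimizer steps, then refresh the
# BEMA weights every 5 steps. Parameter names come from this commit's tests;
# the values here are assumptions chosen only for illustration.
bema_callback = BEMACallback(update_after=100, update_freq=5)

trainer = SFTTrainer(
    ...,
    callbacks=[bema_callback],
)
```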

tests/test_callbacks.py

Lines changed: 132 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 import json
 import os
+from unittest.mock import call, patch
 
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
@@ -22,7 +23,15 @@
 from transformers.utils import is_peft_available
 
 from tests.testing_utils import require_comet, require_mergekit
-from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
+from trl import (
+    BasePairwiseJudge,
+    BEMACallback,
+    DPOConfig,
+    DPOTrainer,
+    LogCompletionsCallback,
+    MergeModelCallback,
+    WinRateCallback,
+)
 from trl.mergekit_utils import MergeConfig
 
 from .testing_utils import TrlTestCase
@@ -362,3 +371,125 @@ def test_every_checkpoint(self):
         for checkpoint in checkpoints:
             merged_path = os.path.join(checkpoint, "merged")
             self.assertTrue(os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}.")
+
+
+class BEMACallbackTester(TrlTestCase):
+    def setUp(self):
+        super().setUp()
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
+
+        def tokenize_function(examples, tokenizer):
+            out = tokenizer(examples["text"], padding="max_length", max_length=17)
+            out["labels"] = out["input_ids"].copy()
+            return out
+
+        self.dataset = dataset.map(
+            tokenize_function, fn_kwargs={"tokenizer": self.tokenizer}, remove_columns=["text"], batched=True
+        )
+
+    def test_model_saved(self):
+        """Test that BEMACallback saves the BEMA model."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2)
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
+
+        # Check that the BEMA model was saved and can be loaded
+        bema_path = os.path.join(self.tmp_dir, "bema")
+        self.assertTrue(os.path.isdir(bema_path), "BEMA directory was not created")
+        AutoModelForCausalLM.from_pretrained(bema_path)
+
+    def test_update_frequency_0(self):
+        """Test that BEMA callback respects the update frequency."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8, 3 epochs).
+        # BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 6, 8
+        self.assertEqual(mock_update.call_args_list, [call(2), call(4), call(6), call(8)])
+
+    def test_update_frequency_1(self):
+        """Test that BEMA callback respects the update frequency."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=3)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8, 3 epochs).
+        # BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9
+        self.assertEqual(mock_update.call_args_list, [call(3), call(6), call(9)])
+
+    def test_update_frequency_2(self):
+        """Test that BEMA callback respects the update frequency."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, update_after=3)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8, 3 epochs).
+        # BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9
+        self.assertEqual(mock_update.call_args_list, [call(5), call(7), call(9)])
+
+    def test_no_bema(self):
+        """Test that BEMACallback works without BEMA updates."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, bias_power=0.0)
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
+
+    def test_no_ema(self):
+        """Test that BEMACallback works without EMA updates."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, ema_power=0.0)
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
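A note on the expected call sequences above: the 9 total steps follow from ceil(17 / 8) = 3 optimizer steps per epoch times 3 epochs, and the three assertions are consistent with a schedule of the form "step > update_after and (step − update_after) % update_freq == 0". A small sketch reproducing the expectations under that assumption (the rule is inferred from the assertions, not taken from the BEMACallback implementation):

```python
import math

# 17 samples, per-device batch size 8, 3 epochs → 3 * 3 = 9 optimizer steps.
total_steps = math.ceil(17 / 8) * 3

def expected_update_steps(update_freq, update_after=0):
    # Assumed schedule rule, inferred from the three test assertions above.
    return [
        step
        for step in range(1, total_steps + 1)
        if step > update_after and (step - update_after) % update_freq == 0
    ]

assert expected_update_steps(update_freq=2) == [2, 4, 6, 8]
assert expected_update_steps(update_freq=3) == [3, 6, 9]
assert expected_update_steps(update_freq=2, update_after=3) == [5, 7, 9]
```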

trl/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -69,7 +69,6 @@
         "KTOConfig",
         "KTOTrainer",
         "LogCompletionsCallback",
-        "MergeModelCallback",
         "ModelConfig",
         "NashMDConfig",
         "NashMDTrainer",
@@ -93,7 +92,7 @@
         "XPOConfig",
         "XPOTrainer",
     ],
-    "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
+    "trainer.callbacks": ["BEMACallback", "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
     "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }
 
@@ -163,7 +162,6 @@
         KTOConfig,
         KTOTrainer,
         LogCompletionsCallback,
-        MergeModelCallback,
         ModelConfig,
         NashMDConfig,
         NashMDTrainer,
@@ -187,7 +185,7 @@
         XPOConfig,
         XPOTrainer,
     )
-    from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
+    from .trainer.callbacks import BEMACallback, MergeModelCallback, RichProgressCallback, SyncRefModelCallback
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
 
     try:

trl/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     "bco_config": ["BCOConfig"],
     "bco_trainer": ["BCOTrainer"],
     "callbacks": [
+        "BEMACallback",
         "LogCompletionsCallback",
         "MergeModelCallback",
         "RichProgressCallback",
@@ -93,6 +94,7 @@
     from .bco_config import BCOConfig
     from .bco_trainer import BCOTrainer
     from .callbacks import (
+        BEMACallback,
         LogCompletionsCallback,
        MergeModelCallback,
        RichProgressCallback,
