@@ -14,6 +14,7 @@

 import json
 import os
+from unittest.mock import call, patch

 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
@@ -22,7 +23,15 @@
 from transformers.utils import is_peft_available

 from tests.testing_utils import require_comet, require_mergekit
-from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
+from trl import (
+    BasePairwiseJudge,
+    BEMACallback,
+    DPOConfig,
+    DPOTrainer,
+    LogCompletionsCallback,
+    MergeModelCallback,
+    WinRateCallback,
+)
 from trl.mergekit_utils import MergeConfig

 from .testing_utils import TrlTestCase
@@ -362,3 +371,125 @@ def test_every_checkpoint(self): |
         for checkpoint in checkpoints:
             merged_path = os.path.join(checkpoint, "merged")
             self.assertTrue(os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}.")
+
+
+class BEMACallbackTester(TrlTestCase):
+    def setUp(self):
+        super().setUp()
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
+
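+        # Pad/truncate every example to 17 tokens and copy input_ids into labels so the
+        # Trainer can compute a causal-LM loss on the batch.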
+        def tokenize_function(examples, tokenizer):
+            out = tokenizer(examples["text"], padding="max_length", max_length=17)
+            out["labels"] = out["input_ids"].copy()
+            return out
+
+        self.dataset = dataset.map(
+            tokenize_function, fn_kwargs={"tokenizer": self.tokenizer}, remove_columns=["text"], batched=True
+        )
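+        # The zen train split has 17 examples; with the default per_device_train_batch_size
+        # of 8 that is 3 optimizer steps per epoch, i.e. 9 steps over the default 3 epochs.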
+
+    def test_model_saved(self):
+        """Test that BEMACallback saves the BEMA model."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2)
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
+
+        # Check that the BEMA model was saved and can be loaded
+        bema_path = os.path.join(self.tmp_dir, "bema")
+        self.assertTrue(os.path.isdir(bema_path), "BEMA directory was not created")
+        AutoModelForCausalLM.from_pretrained(bema_path)
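+        # from_pretrained raises if the saved files are missing or malformed, so a clean
+        # load doubles as a smoke test of the exported checkpoint.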
+
+    def test_update_frequency_0(self):
+        """Test that BEMACallback updates every 2 steps with the default update_after."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8 → 3 steps per epoch, 3 epochs).
+        # BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 6, 8
+        self.assertEqual(mock_update.call_args_list, [call(2), call(4), call(6), call(8)])
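+        # Schedule implied by the three frequency tests (inferred from the expected calls,
+        # not necessarily how the callback computes it): update at step s whenever
+        # s > update_after and (s - update_after) % update_freq == 0, e.g.
+        # [s for s in range(1, 10) if s % 2 == 0] == [2, 4, 6, 8].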
+
+    def test_update_frequency_1(self):
+        """Test that BEMACallback updates every 3 steps with the default update_after."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=3)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8 → 3 steps per epoch, 3 epochs).
+        # BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9
+        self.assertEqual(mock_update.call_args_list, [call(3), call(6), call(9)])
+
+    def test_update_frequency_2(self):
+        """Test that BEMACallback delays updates until after update_after steps."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, update_after=3)
+
+        with patch.object(bema_callback, "_update_bema_weights") as mock_update:
+            trainer = Trainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dataset["train"],
+                processing_class=self.tokenizer,
+                callbacks=[bema_callback],
+            )
+
+            trainer.train()
+
+        # Total 9 steps (17 samples, batch size 8 → 3 steps per epoch, 3 epochs).
+        # BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9
+        self.assertEqual(mock_update.call_args_list, [call(5), call(7), call(9)])
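+        # Consistent with the rule inferred above: s > 3 and (s - 3) % 2 == 0 → 5, 7, 9.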
+
+    def test_no_bema(self):
+        """Test that BEMACallback still trains with the bias correction disabled (bias_power=0.0)."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, bias_power=0.0)
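+        # Assumption based on the test name: bias_power=0.0 turns off the bias-correction
+        # term, so the callback should fall back to a plain EMA of the weights.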
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
+
+    def test_no_ema(self):
+        """Test that BEMACallback still trains with the EMA disabled (ema_power=0.0)."""
+        training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
+        bema_callback = BEMACallback(update_freq=2, ema_power=0.0)
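+        # Assumption based on the test name: ema_power=0.0 disables the EMA smoothing;
+        # the run should still complete without the updates erroring.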
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            train_dataset=self.dataset["train"],
+            processing_class=self.tokenizer,
+            callbacks=[bema_callback],
+        )
+        trainer.train()
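
A minimal sketch of the update these tests exercise, assuming the bias-corrected EMA scheme the BEMA name points to: an EMA of the weights plus a correction term that decays over training. The function name, the (t + 1) offsets, the default powers, and the handling of the 0.0 flags are illustrative assumptions; trl's BEMACallback may implement the details differently.

def bema_step(theta_t, theta_0, ema_prev, t, ema_power=0.5, bias_power=0.5):
    """One hypothetical BEMA update at optimizer step t (weights shown as scalars)."""
    beta = (t + 1) ** -ema_power  # EMA mixing weight; ema_power=0.0 gives beta=1, i.e. no smoothing
    ema = (1 - beta) * ema_prev + beta * theta_t
    if bias_power == 0.0:  # assumed convention from test_no_bema:
        return ema  # 0.0 switches the bias correction off, leaving a plain EMA
    alpha = (t + 1) ** -bias_power  # correction weight, shrinking toward the EMA over time
    return ema + alpha * (theta_t - theta_0)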