Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794)

* Preserve token fields when converting TrainingArguments to SFTConfig

TrainingArguments.to_dict() redacts token fields, so we have to
individually copy them over when converting to SFTConfig to avoid
breaking push_to_hub functionality.

Also adds a test.

* run precommit

* one-line args_as_dict definition per suggestion from kashif

* generalize token copying to match TrainingArguments behavior

* unwrap |= on dict, to support python 3.8

* use .update instead of |= or for-loop
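
The last two bullets come down to which dict-merge forms are available on Python 3.8. A minimal sketch of the three options mentioned, using made-up stand-in data (the key names and values below are illustrative, not taken from the library):

```python
import sys

# Stand-in data; keys and values are illustrative only.
redacted = {"output_dir": "out", "hub_token": "<HUB_TOKEN>"}
originals = {"hub_token": "not_a_real_token"}

# 1. Explicit for-loop: works on every Python 3 version.
merged_loop = dict(redacted)
for key, value in originals.items():
    merged_loop[key] = value

# 2. dict.update: also works on Python 3.8 and is a single call.
merged_update = dict(redacted)
merged_update.update(originals)

# 3. In-place merge operator: only available on Python 3.9+ (PEP 584),
#    so it is guarded here to keep the sketch runnable on 3.8.
if sys.version_info >= (3, 9):
    merged_op = dict(redacted)
    merged_op |= originals
    assert merged_op == merged_update
```

All three forms produce the same dictionary; `.update` is the shortest one that still runs on 3.8.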
noahlt authored Jul 3, 2024
1 parent b6af2ed commit 78f8228
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tests/test_sft_trainer.py
@@ -223,6 +223,7 @@ def test_sft_trainer_backward_compatibility(self):
     eval_steps=2,
     save_steps=2,
     per_device_train_batch_size=2,
+    hub_token="not_a_real_token",
 )

 trainer = SFTTrainer(
@@ -232,6 +233,8 @@ def test_sft_trainer_backward_compatibility(self):
     eval_dataset=self.eval_dataset,
 )

+assert trainer.args.hub_token == training_args.hub_token
+
 trainer.train()

 assert trainer.state.log_history[(-1)]["train_loss"] is not None
5 changes: 4 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -146,7 +146,10 @@ def __init__(
     warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
     args = SFTConfig(output_dir=output_dir)
 elif args is not None and args.__class__.__name__ == "TrainingArguments":
-    args = SFTConfig(**args.to_dict())
+    args_as_dict = args.to_dict()
+    # Manually copy token values as TrainingArguments.to_dict() redacts them
+    args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")})
+    args = SFTConfig(**args_as_dict)

 if model_init_kwargs is not None:
     warnings.warn(
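
For readers without the libraries installed, the failure mode and the fix can be reproduced with stand-ins. The sketch below is not the real `TrainingArguments` or `SFTConfig`: `FakeTrainingArguments`, `FakeSFTConfig`, and `convert` are hypothetical names, and the `<HUB_TOKEN>`-style placeholder is an assumption about what the redaction looks like. Only the `.update(...)` line is copied from the change above.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class FakeTrainingArguments:
    """Hypothetical stand-in for TrainingArguments (not the real class)."""

    output_dir: str
    hub_token: Optional[str] = None

    def to_dict(self):
        d = asdict(self)
        # Assumed redaction behavior: every *_token value is replaced
        # by a placeholder so it is not leaked when serialized.
        for k in d:
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


@dataclass
class FakeSFTConfig(FakeTrainingArguments):
    """Hypothetical stand-in for SFTConfig (same fields, no extras)."""


def convert(args):
    """Convert to the config class while keeping the original token values."""
    args_as_dict = args.to_dict()
    # Same expression as in the patch: overwrite redacted values with the real ones.
    args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")})
    return FakeSFTConfig(**args_as_dict)


training_args = FakeTrainingArguments(output_dir="out", hub_token="not_a_real_token")

naive = FakeSFTConfig(**training_args.to_dict())  # pre-fix behavior: token is lost
fixed = convert(training_args)                    # post-fix behavior: token survives

assert naive.hub_token != training_args.hub_token
assert fixed.hub_token == training_args.hub_token
```

The final assertion mirrors the check added to `test_sft_trainer_backward_compatibility`; matching on the `_token` suffix rather than a fixed field name is what the "generalize token copying" bullet refers to.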