-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SFT] add SFT Trainer Config dataclass #1530
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for adding this config @kashif ! LGTM apart from some nits.
Given this is the most used trainer in the lib, would you mind adding a small backwards compatibility integration test where we init the trainer in two ways (setting all args we deprecated in the init) and then checking the results after 1 train step are the same?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much ! In addition to @lewtun comments i have two tiny comments, WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking great thanks ! failing test is unrelated to this PR: https://github.com/huggingface/trl/actions/runs/8786654281/job/24110021581?pr=1530
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, with a minor nit on the dataset split names. Feel free to merge after fixing that!
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Add a SFTConfig training argument dataclass together with deprecation warnings to the SFTTrainer