KL from base trainer #311
base: main
Conversation
xantheocracy
left a comment
in progress
@@ -0,0 +1,3 @@
name: kl_from_base_penalty
kl_from_base_coef: 0.001
this coefficient is completely arbitrary
I have no intuition on how to set it in a sensible way; I suppose this will require a hyperparameter sweep (a sketch of such a sweep follows this thread)
If you can add that as a comment in the code, just so we know it is arbitrary, that would be great!
there is a non-trivial chance this has introduced bugs sorryyyyyyy <3 but I think it would have been insanely confusing to keep the old naming convention
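For the hyperparameter sweep mentioned above, a minimal sketch of a coarse, log-spaced grid over the coefficient; the candidate values and the dictionary-based launch are illustrative assumptions, not part of the PR.

# Illustrative only: coarse, log-spaced candidates for kl_from_base_coef.
# The values and the way runs would be launched are assumptions.
candidate_coefs = [0.0, 1e-4, 1e-3, 1e-2, 1e-1]

for coef in candidate_coefs:
    run_config = {
        "name": f"kl_from_base_penalty_coef_{coef:g}",
        "kl_from_base_coef": coef,
    }
    print(run_config)  # in practice, launch one training run per config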
xantheocracy
left a comment
seems reasonable‽
stefan319
left a comment
Calculation of the KL divergence looks good! I have a question about whether it should be used in the loss or only for logging.
num_valid = valid_mask.sum().clamp(min=1.0)
kl_divergence = kl_per_token.sum() / num_valid
loss = loss + self.kl_from_base_coef * kl_divergence
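A minimal, self-contained sketch of the masked penalty above, assuming per-token log-probs of the sampled tokens under the policy and the frozen base model; the k3-style estimator for kl_per_token is an assumption (it mirrors TRL's reference-KL term), not something confirmed by this diff.

import torch

# Sketch only: random tensors stand in for real per-token log-probs.
policy_logps = torch.randn(2, 5)                    # log pi(token) under current policy
base_logps = torch.randn(2, 5)                      # log pi_base(token) under frozen base model
valid_mask = torch.tensor([[1., 1., 1., 0., 0.],
                           [1., 1., 1., 1., 1.]])   # 0 = padding

delta = base_logps - policy_logps
kl_per_token = (torch.exp(delta) - delta - 1.0) * valid_mask   # k3 estimator, masked

num_valid = valid_mask.sum().clamp(min=1.0)         # avoid division by zero
kl_divergence = kl_per_token.sum() / num_valid      # mean KL over valid tokens

kl_from_base_coef = 0.001                           # arbitrary coefficient from the config
loss = torch.tensor(0.0)                            # stand-in for the base GRPO loss
loss = loss + kl_from_base_coef * kl_divergence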
My understanding of this ablation is that we wish to decompose the KL divergence as D_{KL}(\pi \| \pi_{ref}) = -H(\pi) + \text{CrossEntropy}(\pi, \pi_{ref}). Then, by calculating entropy and KL divergence, we can understand the role of cross-entropy in the PRG. I believe this means we only need to track the KL divergence, not append it to the loss?
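Spelling out that identity (with x ranging over tokens):

D_{KL}(\pi \,\|\, \pi_{ref}) = \sum_x \pi(x) \log \frac{\pi(x)}{\pi_{ref}(x)} = \sum_x \pi(x) \log \pi(x) - \sum_x \pi(x) \log \pi_{ref}(x) = -H(\pi) + \text{CrossEntropy}(\pi, \pi_{ref})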
I believe this is separate from the game. IIRC this is a mitigation strategy where we presume that the pre-trained model, since it hasn't gone through any post-training or RLHF, doesn't have any malicious objectives that could be learned in post-training. Therefore, regularizing the model during post-training towards staying close to the pre-trained distribution could mitigate side objectives. We should check with Jacob though!
aristizabal95
left a comment
Looks good to me!
    return model_organism, tokenizer
finally:
    # Clean up temp directory to avoid storage leaks
    shutil.rmtree(temp_dir, ignore_errors=True)
Nice catch!
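A minimal sketch of the pattern in the diff, assuming artifacts are downloaded into a fresh temp directory and loaded into memory before the directory is removed; the download/load steps below are placeholders, not the PR's actual logic.

import shutil
import tempfile
from pathlib import Path

def load_model_organism():
    temp_dir = tempfile.mkdtemp()
    try:
        weights_path = Path(temp_dir) / "weights.bin"
        weights_path.write_bytes(b"\x00" * 16)       # placeholder for the real download
        model_organism = weights_path.read_bytes()   # placeholder for loading into memory
        tokenizer = None                             # placeholder
        return model_organism, tokenizer
    finally:
        # Clean up temp directory to avoid storage leaks, even if loading raised
        shutil.rmtree(temp_dir, ignore_errors=True)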
# type: ignore
#####################################################################
# THIS FILE IS A COPY OF THE TRL GRPO TRAINER FILE #
# WITH THE ENTROPY BONUS ADDED. #
Minor note to update this to KL from base model
# after super().__init__() so accelerator is available #
# Use FastLanguageModel to match unsloth optimizations #
############################################################
from unsloth import FastLanguageModel
maybe move this import to the top? unless there's a good reason for it being here
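If it does move, a guarded top-level import is one option; whether unsloth is optional here, or must be imported locally for another reason (e.g. patching order), is an assumption, not something confirmed in the PR.

# Hypothetical: module-top import with a guard for environments without unsloth.
try:
    from unsloth import FastLanguageModel
except ImportError:
    FastLanguageModel = None  # fall back or raise a clear error where it is used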