Conversation

@xantheocracy (Collaborator)

No description provided.

@xantheocracy (Collaborator, Author) left a comment

in progress

@@ -0,0 +1,3 @@
name: kl_from_base_penalty
kl_from_base_coef: 0.001
@xantheocracy (Collaborator, Author)

this coefficient is completely arbitrary

I have no intuition on how to set it in a sensible way; I suppose this will require a hyperparameter sweep

@aristizabal95 (Owner)

If you can add that as a comment, just so we know directly from the code that it is arbitrary, that would be great!
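
A possible version with the requested note (a sketch; only the two keys come from the diff, the wording of the comment is a suggestion):

```yaml
name: kl_from_base_penalty
# NOTE: this coefficient is completely arbitrary; there is no principled value yet,
# so expect it to be revisited via a hyperparameter sweep.
kl_from_base_coef: 0.001
```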

@xantheocracy marked this pull request as draft on December 2, 2025 11:16
@xantheocracy marked this pull request as ready for review on December 4, 2025 16:08
@xantheocracy (Collaborator, Author)

  • We previously referred to the instruction-tuned model as base_model throughout the codebase, which was a little inaccurate, since "base model" generally refers to a model with no post-training (as I understand it).
  • Now that the codebase also references the actual base model (for computing the KL term), keeping the old name seemed too confusing, so the code now uses it_model for what we previously called base_model.
  • The it_model config references its base model as base_hf_path (see the sketch below).

There is a non-trivial chance this has introduced bugs, sorryyyyyyy <3, but I think it would have been insanely confusing to keep the old naming convention.
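
For illustration, a minimal sketch of how the it_model config could point at its base model; every key and model name other than base_hf_path is an assumption invented for this example:

```yaml
it_model:
  hf_path: Qwen/Qwen2.5-7B-Instruct   # hypothetical instruction-tuned checkpoint
  base_hf_path: Qwen/Qwen2.5-7B       # pre-trained base model used for the KL-from-base penalty
```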

@xantheocracy (Collaborator, Author) left a comment

seems reasonable‽

@stefan319 (Contributor) left a comment

Calculation of the KL divergence looks good! I have a question about whether it should be used in the loss or only for logging.


num_valid = valid_mask.sum().clamp(min=1.0)
kl_divergence = kl_per_token.sum() / num_valid
loss = loss + self.kl_from_base_coef * kl_divergence
@stefan319 (Contributor)

My understanding of this ablation is that we wish to decompose the KL divergence as $D_{\mathrm{KL}}(\pi \,\|\, \pi_{\text{ref}}) = -H(\pi) + \text{CrossEntropy}(\pi, \pi_{\text{ref}})$. Then, by calculating the entropy and the KL divergence, we can understand the role of cross entropy in the PRG. I believe this means we only need to track the KL divergence rather than add it to the loss?
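
For reference, a minimal PyTorch sketch of that decomposition for logging only; the tensor names (logits, ref_logits, valid_mask) are assumptions, not the PR's actual variables:

```python
import torch
import torch.nn.functional as F

def kl_decomposition(logits, ref_logits, valid_mask):
    """Decompose D_KL(pi || pi_ref) = -H(pi) + CE(pi, pi_ref) per token.

    logits, ref_logits: (batch, seq, vocab); valid_mask: (batch, seq) of 0/1.
    Returns means over valid tokens, intended for logging rather than the loss.
    """
    logp = F.log_softmax(logits, dim=-1)          # log pi
    ref_logp = F.log_softmax(ref_logits, dim=-1)  # log pi_ref
    p = logp.exp()

    entropy = -(p * logp).sum(-1)                 # H(pi) per token
    cross_entropy = -(p * ref_logp).sum(-1)       # CE(pi, pi_ref) per token
    kl = cross_entropy - entropy                  # D_KL = -H + CE

    num_valid = valid_mask.sum().clamp(min=1.0)
    reduce = lambda t: (t * valid_mask).sum() / num_valid
    return {
        "entropy": reduce(entropy),
        "cross_entropy": reduce(cross_entropy),
        "kl_from_base": reduce(kl),
    }
```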

@aristizabal95 (Owner)

I believe this is separate from the game. IIRC this is a mitigation strategy: we presume the pre-trained model, since it hasn't gone through any post-training or RLHF, doesn't have any malicious objectives that could be learned in post-training. Therefore, regularizing the model during post-training towards staying close to the pre-trained distribution could mitigate side objectives. We should check with Jacob though!
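
For context, a minimal sketch of the mitigation as described here: keep a frozen copy of the pre-trained model, compute per-token KL between the policy and the base distributions, and add the mean over valid tokens to the loss scaled by kl_from_base_coef. Apart from the names quoted in the diff (loss, kl_per_token, valid_mask, num_valid, kl_from_base_coef), everything below is an assumption, not the PR's actual implementation:

```python
import torch
import torch.nn.functional as F

def add_kl_from_base_penalty(loss, policy_logits, base_model, input_ids,
                             attention_mask, valid_mask, kl_from_base_coef):
    """Regularize the policy towards the frozen pre-trained (base) model."""
    with torch.no_grad():  # the base model is frozen; no gradients through it
        base_logits = base_model(input_ids=input_ids,
                                 attention_mask=attention_mask).logits

    logp = F.log_softmax(policy_logits, dim=-1)
    base_logp = F.log_softmax(base_logits, dim=-1)
    # Per-token D_KL(pi || pi_base) = sum_v pi(v) * (log pi(v) - log pi_base(v))
    kl_per_token = (logp.exp() * (logp - base_logp)).sum(-1) * valid_mask

    num_valid = valid_mask.sum().clamp(min=1.0)
    kl_divergence = kl_per_token.sum() / num_valid
    return loss + kl_from_base_coef * kl_divergence, kl_divergence
```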

@aristizabal95 (Owner) left a comment

Looks good to me!

return model_organism, tokenizer
finally:
# Clean up temp directory to avoid storage leaks
shutil.rmtree(temp_dir, ignore_errors=True)
@aristizabal95 (Owner)

Nice catch!
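
The pattern being praised, as a generic sketch (the function and its parameters are illustrative, not the PR's code; only the try/finally cleanup with shutil.rmtree comes from the diff):

```python
import shutil
import tempfile

def load_from_temp_dir(materialize, load):
    """Illustrative helper: `materialize(temp_dir)` writes a checkpoint into a
    fresh temp dir, `load(temp_dir)` returns (model, tokenizer). The temp dir
    is always removed, even on failure, so repeated loads don't leak storage."""
    temp_dir = tempfile.mkdtemp()
    try:
        materialize(temp_dir)
        model_organism, tokenizer = load(temp_dir)
        return model_organism, tokenizer
    finally:
        # Clean up temp directory to avoid storage leaks
        shutil.rmtree(temp_dir, ignore_errors=True)
```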

# type: ignore
#####################################################################
# THIS FILE IS A COPY OF THE TRL GRPO TRAINER FILE #
# WITH THE ENTROPY BONUS ADDED. #
@aristizabal95 (Owner)

Minor note: update this header to mention the KL-from-base-model penalty.
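
Suggested wording only (adjust if the entropy bonus is still present alongside the new penalty):

```python
#####################################################################
# THIS FILE IS A COPY OF THE TRL GRPO TRAINER FILE,
# MODIFIED TO ADD THE KL-FROM-BASE-MODEL PENALTY.
#####################################################################
```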

# after super().__init__() so accelerator is available #
# Use FastLanguageModel to match unsloth optimizations #
############################################################
from unsloth import FastLanguageModel
@aristizabal95 (Owner)

Maybe move this import to the top? Unless there are good reasons for it being here.
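
If there is a reason to keep the import lazy (import cost, optional dependency), a common compromise is a guarded module-level import. A sketch only; it assumes nothing about why the import is currently local:

```python
# At the top of the module, next to the other imports.
try:
    from unsloth import FastLanguageModel
    UNSLOTH_AVAILABLE = True
except ImportError:
    FastLanguageModel = None
    UNSLOTH_AVAILABLE = False
```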

