Fix SFTTrainer token accuracy computation with PromptEncoder #3821

Open · wants to merge 1 commit into main

Conversation

zk-quantum

What does this PR do?

This PR fixes issue #3812 where SFTTrainer fails with a RuntimeError when using PEFT's PromptEncoder configuration. The error occurs because PromptEncoder adds virtual tokens to the model input, causing a dimension mismatch between logits and labels during token accuracy computation.

Fixes #3812

Problem

When using PromptEncoder with num_virtual_tokens, the model prepends virtual tokens to the input sequence. This causes:

  • Logits shape: [batch_size, sequence_length + num_virtual_tokens, vocab_size]
  • Labels shape: [batch_size, sequence_length]

The dimension mismatch causes: RuntimeError: The size of tensor a (123) must match the size of tensor b (91) at non-singleton dimension 1
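
For illustration, a minimal sketch that reproduces the mismatch at the model level (the base model, input sentence, and encoder_hidden_size here are illustrative, not taken from the original report):

    # Illustrative reproduction of the shape mismatch with a PEFT prompt encoder.
    from peft import PromptEncoderConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = get_peft_model(
        AutoModelForCausalLM.from_pretrained("gpt2"),
        PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=16, encoder_hidden_size=128),
    )

    inputs = tokenizer("Hello world", return_tensors="pt")
    logits = model(**inputs).logits
    print(inputs["input_ids"].shape)  # e.g. torch.Size([1, 2])            -> matches the labels length
    print(logits.shape)               # e.g. torch.Size([1, 18, 50257])    -> 16 extra virtual-token positions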

Solution

Modified the compute_loss method in SFTTrainer to (a sketch of the resulting logic follows the list):

  1. Detect when logits and labels have different sequence lengths
  2. Calculate the number of virtual tokens as the difference
  3. Skip the virtual tokens in the logits before computing accuracy
  4. Handle edge cases gracefully with warnings
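
A minimal sketch of this logic, written here as a standalone hypothetical helper (in the PR the change is inlined in compute_loss; shift_logits and shift_labels follow the naming in the existing TRL code):

    import warnings

    import torch

    def drop_virtual_token_logits(shift_logits: torch.Tensor, shift_labels: torch.Tensor) -> torch.Tensor:
        """Hypothetical helper: align logits with labels before computing token accuracy."""
        diff = shift_logits.size(1) - shift_labels.size(1)
        if diff > 0:
            # Treat the extra leading positions as PEFT virtual tokens and skip them.
            return shift_logits[:, diff:, :]
        if diff < 0:
            # Unexpected case: warn and let the caller skip the accuracy metric.
            warnings.warn("Logits are shorter than labels; skipping token accuracy computation.")
        return shift_logits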

Testing

Added comprehensive unit tests in tests/test_sft_prompt_encoder.py (a sketch of one such test follows the list):

  • Test basic PromptEncoder functionality with 16 virtual tokens
  • Test multiple virtual token counts (8, 32, 64)
  • Verify that token accuracy is properly computed without errors
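
A rough sketch of what one of these tests can look like (the test name, base model, dataset, and encoder_hidden_size are illustrative, not copied from the PR):

    import pytest
    from datasets import Dataset
    from peft import PromptEncoderConfig
    from trl import SFTConfig, SFTTrainer

    @pytest.mark.parametrize("num_virtual_tokens", [8, 16, 32, 64])
    def test_sft_trainer_with_prompt_encoder(tmp_path, num_virtual_tokens):
        dataset = Dataset.from_dict({"text": ["Hello world, this is a test sentence."] * 8})
        trainer = SFTTrainer(
            model="gpt2",  # illustrative small model; the real tests may use a tiny test checkpoint
            train_dataset=dataset,
            args=SFTConfig(output_dir=str(tmp_path), max_steps=2, report_to="none"),
            peft_config=PromptEncoderConfig(
                task_type="CAUSAL_LM",
                num_virtual_tokens=num_virtual_tokens,
                encoder_hidden_size=32,  # illustrative value
            ),
        )
        trainer.train()  # should complete and log mean_token_accuracy without a shape error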

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

- Handle dimension mismatch when using PEFT PromptEncoder with virtual tokens
- Skip virtual tokens in logits before computing token accuracy
- Add comprehensive tests for PromptEncoder configurations
- Fixes huggingface#3812
@kashif
Collaborator

kashif commented Jul 31, 2025

Thanks @zk-quantum! Your solution looks a bit too generic, though, and it might mask other issues with shape mismatches.

I think we should have a self.num_virtual_tokens = None attribute, and in _prepare_peft_model perhaps we can do something like:

        # Store num_virtual_tokens for token accuracy computation
        if peft_config is not None:
            self.num_virtual_tokens = getattr(peft_config, "num_virtual_tokens", None)
        elif is_peft_available() and isinstance(model, PeftModel):
            # Model is already a PeftModel, extract num_virtual_tokens from config
            if hasattr(model, "peft_config") and model.peft_config:
                active_adapter = getattr(model, "active_adapter", "default")
                if active_adapter in model.peft_config:
                    peft_config_obj = model.peft_config[active_adapter]
                    self.num_virtual_tokens = getattr(peft_config_obj, "num_virtual_tokens", None)
                else:
                    self.num_virtual_tokens = None
            else:
                self.num_virtual_tokens = None
        else:
            self.num_virtual_tokens = None

but perhaps cleaner. Then, in compute_loss, we can do shift_logits = shift_logits[:, self.num_virtual_tokens :, :], but only if self.num_virtual_tokens is not None and equal to the sequence-length difference, to be sure. What do you think?
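
For comparison with the sketch in the PR description above, here is a minimal sketch of that guarded variant as a standalone hypothetical helper (in the trainer it would read the self.num_virtual_tokens attribute set in _prepare_peft_model):

    from typing import Optional

    import torch

    def maybe_strip_virtual_tokens(
        shift_logits: torch.Tensor,
        shift_labels: torch.Tensor,
        num_virtual_tokens: Optional[int],
    ) -> torch.Tensor:
        """Hypothetical helper: slice only when the mismatch equals num_virtual_tokens."""
        diff = shift_logits.size(1) - shift_labels.size(1)
        if num_virtual_tokens is not None and diff == num_virtual_tokens:
            # The mismatch is fully explained by the prompt encoder's virtual tokens.
            return shift_logits[:, num_virtual_tokens:, :]
        # Any other mismatch is left untouched so that genuine shape bugs still surface.
        return shift_logits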


Successfully merging this pull request may close these issues.

Prompttuning/ptuning not working with SFTTrainer due to token accuracy