
Ignore whisper previous tokens during loss calculation #27384

Closed
@DavraYoung

Description


Feature request

I am training Whisper using previous prompts, and according to the paper it is better to ignore the loss on the previous tokens:

We only mask out the training loss over the previous context text, and train the model to predict all other tokens

Which means that for:
<|startofprev|>Here goes previous transcription tokens<|startoftranscript|><|en|><|transcribe|>hello world!<|endoftext|>
the loss must be masked for all tokens that come before <|startoftranscript|>.

Simply putting -100 in the labels is not correct, since the model still has to attend to the real prompt tokens in decoder_input_ids (which the stock forward builds by shifting the labels right).
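
To make the intended behaviour concrete, here is a minimal sketch of the label masking (the helper name mask_prompt_labels is hypothetical; the ID 50258 is <|startoftranscript|> in the multilingual tokenizer):

import torch

def mask_prompt_labels(labels: torch.LongTensor, sot_token_id: int = 50258) -> torch.LongTensor:
    # Return a copy of `labels` where every position before <|startoftranscript|>
    # is set to -100, so CrossEntropyLoss ignores the previous-context tokens.
    masked = labels.clone()
    for row in range(labels.size(0)):
        sot_positions = (labels[row] == sot_token_id).nonzero(as_tuple=True)[0]
        if len(sot_positions) > 0:
            masked[row, : int(sot_positions[0])] = -100
    return masked

The unmasked sequence (prompt included) still has to reach the decoder as decoder_input_ids, because the stock forward derives decoder_input_ids by shifting the labels, and -100 entries would break the embedding lookup. The subclass below therefore masks inside forward instead.
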
Training without masking gives the following:

  1. Start training Whisper with prompts (the so-called "previous tokens").
  2. Observe Whisper learning to predict the prompts (which it is not supposed to do).
  3. Observe the loss never dropping below 1.0.
  4. Observe a lot of hallucinations.
Environment:

  • transformers version: 4.34.1
  • Platform: Windows-10-10.0.22621-SP0
  • Python version: 3.9.13
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0.dev20230803+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

cc: @sanchit-gandhi

Motivation

Fine-tuning Whisper on audio with text context.

Your contribution

from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss

from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import shift_tokens_right


class ExtendedWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.model = WhisperModel(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            # Find the position of the <|startoftranscript|> token (ID 50258 in the
            # multilingual tokenizer); assumes each row contains exactly one such token.
            # Subtracting 1 lets the strict comparison below keep the token itself.
            transcribe_positions = (labels == 50258).nonzero(as_tuple=True)[1] - 1
            max_position = labels.shape[1]
            mask = (
                torch.arange(max_position)
                .expand(len(labels), max_position)
                .to(labels.device)
            )

            if (
                len(transcribe_positions) > 0
            ):  # Ensure there's at least one <|startoftranscript|> token
                mask = mask > transcribe_positions[:, None]

            # Set labels to -100 (ignored by CrossEntropyLoss) wherever the mask is
            # False, i.e. for the previous-context tokens before <|startoftranscript|>
            labels = torch.where(mask, labels, torch.tensor(-100).to(labels.device))

            loss_fct = CrossEntropyLoss()
            # Compute loss with modified labels
            loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)
            )
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
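
A minimal usage sketch, assuming the subclass above and the openai/whisper-tiny checkpoint (the label IDs shown come from the multilingual tokenizer; the plain text tokens are arbitrary placeholders):

import torch

model = ExtendedWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy batch: random log-mel features plus labels that start with a <|startofprev|> prompt.
input_features = torch.randn(1, model.config.num_mel_bins, 3000)
labels = torch.tensor(
    [[50361, 6455, 1112, 50258, 50259, 50359, 7592, 1002, 50257]]
)  # <|startofprev|>, prompt tokens, <|startoftranscript|>, <|en|>, <|transcribe|>, text tokens, <|endoftext|>

outputs = model(input_features=input_features, labels=labels)
print(outputs.loss)  # the loss ignores everything before <|startoftranscript|>

Training can then proceed with e.g. the usual Seq2SeqTrainer setup; only the loss computation changes.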
