Feature request
I am training Whisper with previous-context prompts, and according to the Whisper paper it is better to ignore the loss for the previous tokens:
"We only mask out the training loss over the previous context text, and train the model to predict all other tokens"
Which means:
For: <|startofprev|>Here goes previous transcription tokens<|startoftranscript|><|transcribe|><|en|>hello world!<|endoftext|>
the loss must be masked for all tokens that come before <|startoftranscript|>.
Simply putting -100 in the labels is not correct, since the model actually attends to the decoder_input_ids.
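One way to read the point above: when decoder_input_ids is not passed, the model derives it from labels with shift_tokens_right, which in transformers replaces every -100 with the pad token, so the decoder would never see the prompt. A workaround at the data-preparation level is to build decoder_input_ids and labels separately and mask only the labels. The following is a minimal sketch of that idea on plain Python lists, before padding and tensor conversion. The helper name build_decoder_inputs_and_labels is hypothetical, and every token id in the usage lines except 50258 (taken from the code in this issue) is a made-up placeholder.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value that torch.nn.CrossEntropyLoss skips by default


def build_decoder_inputs_and_labels(
    token_ids: List[int], sot_id: int, ignore_index: int = IGNORE_INDEX
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split one full sequence into decoder inputs and labels.

    token_ids is the whole sequence, e.g.
    <|startofprev|> prompt... <|startoftranscript|> ... <|endoftext|>
    """
    # Teacher forcing: the decoder reads tokens [0, n-1) and predicts tokens [1, n).
    decoder_input_ids = token_ids[:-1]
    labels = token_ids[1:]

    # Index of <|startoftranscript|> in the full sequence (ValueError if absent).
    sot_index = token_ids.index(sot_id)

    # labels[i] is token_ids[i + 1], so the targets that are prompt tokens
    # occupy labels[0 : sot_index - 1]. Only those are masked; the prompt
    # stays intact in decoder_input_ids so the model can still attend to it.
    num_masked = max(sot_index - 1, 0)
    labels = [ignore_index] * num_masked + labels[num_masked:]
    return decoder_input_ids, labels


# Placeholder ids: 1 = <|startofprev|>, 11 and 12 = prompt tokens,
# 2 and 3 = task/language tokens, 21 and 22 = transcript tokens, 4 = <|endoftext|>.
decoder_input_ids, labels = build_decoder_inputs_and_labels(
    [1, 11, 12, 50258, 2, 3, 21, 22, 4], sot_id=50258
)
print(decoder_input_ids)
print(labels)
```

Like the forward override proposed in this issue, the sketch keeps <|startoftranscript|> itself as a target and masks only what precedes it.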
Training without masking gives this:
- Start training Whisper with prompts (so-called "previous tokens")
- Observe Whisper learning to predict the prompts (which it is not supposed to do)
- Observe loss never dropping below 1.0
- Observe a lot of hallucinations
- transformers version: 4.34.1
- Platform: Windows-10-10.0.22621-SP0
- Python version: 3.9.13
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.1
- Accelerate version: 0.21.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.0.dev20230803+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
cc: @sanchit-gandhi
Motivation
Finetuning whisper on audio with text context.
Your contribution
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import shift_tokens_right


class ExtendedWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.model = WhisperModel(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if labels is not None:
            # Fall back to deriving the decoder inputs from the labels
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.proj_out(outputs[0])
        loss = None
        if labels is not None:
            # Find the position of the <|startoftranscript|> token in labels
            # (50258 is its id in the multilingual Whisper vocabulary). The
            # indexing below assumes exactly one such token in every row.
            transcribe_positions = (labels == 50258).nonzero(as_tuple=True)[1] - 1
            max_position = labels.shape[1]
            mask = (
                torch.arange(max_position)
                .expand(len(labels), max_position)
                .to(labels.device)
            )
            if (
                len(transcribe_positions) > 0
            ):  # Ensure there's at least one <|startoftranscript|> token
                # True from <|startoftranscript|> onwards, False for the prompt before it
                mask = mask > transcribe_positions[:, None]
                # Set the labels to -100 (ignored in CrossEntropyLoss) where the mask is False
                labels = torch.where(mask, labels, torch.tensor(-100).to(labels.device))
            loss_fct = CrossEntropyLoss()
            # Compute loss with modified labels
            loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)
            )
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )