Adapting Whisper to the new loss_function attribute #36119

@yoadsn

@ArthurZucker @muellerzr Following up on #35838, #34191, #34198, #34283

I would like to help bring Whisper into this. I see it was not included in the last round of fixes in #35875 related to the loss-function bug fix (grad acc.) nor wired to the new global `loss_function` attribute. Since Whisper is an encoder-decoder model derived from the Bart code in many places around loss and decoder-input-token handling, I suspect Bart would also benefit from the same attention.

So, I would like to help with the following missing support:

**It does not accept kwargs**

In `forward` (for `WhisperForConditionalGeneration`), it seems straightforward to follow @muellerzr's work and implement this while keeping the tests passing. Anything special to consider there?

**It does not use the global `loss_function` attr (introduced in #34191)**

I find that the closest loss implementation would be `ForMaskedLMLoss`, since judging by how the existing loss is calculated, the labels are expected to arrive already shifted:

```python
if labels is not None:
    loss_fct = CrossEntropyLoss()
    # move labels to correct device to enable PP
    labels = labels.to(lm_logits.device)
    loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
```

**Some background on the above**

I find this was derived from the Bart implementation, which forces the user to either provide `decoder_input_ids` or have them derived from `labels` by shifting them one position to the right (a convention from the denoising pre-training task). This leads to a situation where the labels are expected to be left-shifted relative to the logits, which is exactly what the loss calculation above serves.
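For reference, the Bart-style derivation looks roughly like this (paraphrased from `shift_tokens_right` in `modeling_bart.py`):

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Shift label ids one position to the right and prepend the decoder start token."""
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # replace any -100 (ignored label positions) with pad_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
```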

Whisper inherited that, but has a more involved decoder-prompt prefixing scheme. The model's `forward` is hardly the right place to grab the "decoder start token id" required to accomplish the "shift right" of labels into `decoder_input_ids`; moreover, for Whisper this prefix is critical at inference time for determining the task, and control over it is properly expressed through other arguments (language, task, notimestamps, etc.).
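For illustration, those controls surface through the processor rather than through a single start token:

```python
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
# The decoder prefix encodes language/task/timestamp control, e.g. roughly
# [(1, <|en|>), (2, <|transcribe|>), (3, <|notimestamps|>)]
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
```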

Thus, the collators suggested by @sanchit-gandhi in his great guidance, and those used in the Distil-Whisper work, explicitly specify both `labels` and `decoder_input_ids`, working around the automatic (now unusable) "shift labels right" behaviour. (See code here.)

Others have instead "cooked" the labels to contain everything but the leading "decoder start token id" as a hack (more at #27384); even the collator code in the popular blog post about Whisper fine-tuning does:

```python
# if bos token is appended in previous tokenization step,
# cut bos token here as it's appended later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
    labels = labels[:, 1:]
```

Which is, of course, a workaround to mitigate that Bart heritage; a sketch of the explicit alternative follows.
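To make the contrast concrete, here is a minimal sketch of the explicit variant from the first workaround above (illustrative names and simplified padding handling, not the actual Distil-Whisper collator code):

```python
import torch

def prepare_decoder_inputs(label_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Given tokenized targets that already start with the full decoder prompt,
    split them into decoder_input_ids and labels explicitly, so the model never
    has to shift anything itself."""
    decoder_input_ids = label_ids[:, :-1]  # inputs: prompt + tokens[:-1]
    labels = label_ids[:, 1:].clone()      # targets: tokens shifted left by one
    labels = labels.masked_fill(attention_mask[:, 1:].ne(1), -100)  # ignore padding
    return decoder_input_ids, labels
```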

WDYT @sanchit-gandhi, did I get this right?

Anyway, this is why `ForCausalLMLoss` probably won't be a fit: it shifts the labels left to match the logit positions, which would double-shift Whisper's already-shifted labels.
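For comparison, the problematic step in `ForCausalLMLoss` is roughly this (paraphrased from `loss_utils.py`):

```python
# Shift so that tokens < n predict n; applied on top of Whisper's
# already-shifted labels, this would misalign everything by one position.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
```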

I would like to know whether the proper loss to use is `ForMaskedLMLoss`, or whether a new `ForConditionalGenerationLMLoss` should actually be introduced. Personally, I think a new one should exist that does exactly what `ForMaskedLMLoss` does, with a shared implementation for both.
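A minimal sketch of what such a shared implementation could look like (`ForConditionalGenerationLMLoss` is the hypothetical name proposed above, not an existing symbol; this assumes `fixed_cross_entropy` from `transformers.loss.loss_utils` keeps its current signature):

```python
from transformers.loss.loss_utils import fixed_cross_entropy

def ForConditionalGenerationLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch=None, ignore_index: int = -100, **kwargs
):
    # Unlike ForCausalLMLoss, do NOT shift: the labels are assumed to already be
    # aligned one position left of decoder_input_ids by the collator.
    labels = labels.to(logits.device)
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)
    return fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
```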

Also, as an aside, I would love to see the Bart-derived "derive `decoder_input_ids` from `labels`" logic adapted to Whisper, but I am not sure I have the experience to know how.

**The grad acc loss bug still applies to Whisper**

As it is implemented now, you can (thankfully) customize the loss calculation using `compute_loss_func`, which was introduced in #34198 - and this is mandatory for anyone who wants to avoid the gradient accumulation loss bug described here and fixed in many PRs around the above-mentioned efforts.
This is effectively still an open bug for Whisper, which did not receive the `fixed_cross_entropy` injection that other models got.
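Until then, a hedged workaround sketch for anyone training Whisper today - this assumes `compute_loss_func` still receives `(outputs, labels, num_items_in_batch)` per the #34198 API, and `grad_acc_safe_loss` is just an illustrative name:

```python
import torch.nn.functional as F

def grad_acc_safe_loss(outputs, labels, num_items_in_batch=None):
    logits = outputs.logits
    labels = labels.to(logits.device)
    # sum-then-normalize over the true token count, so gradient accumulation
    # does not skew the effective loss scale
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1),
        ignore_index=-100, reduction="sum",
    )
    if num_items_in_batch is None:
        num_items_in_batch = (labels != -100).sum()
    return loss / num_items_in_batch

# usage: Seq2SeqTrainer(model=model, args=args, ..., compute_loss_func=grad_acc_safe_loss)
```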

Thanks guys for all the great documentation on this - it makes it so much easier to try and contribute back!
