
Remove additional float/clone() for perf #1374

Open
wants to merge 2 commits into base: transformers_future

Conversation

@jiminha (Collaborator) commented Sep 28, 2024

What does this PR do?

Remove the .float() calls introduced in the 4.45 upgrade due to a perf issue: the extra aten::item and cast ops cause performance degradation.

DO NOT MERGE NOW: an accuracy test is still needed to see if there is any accuracy drop.
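A minimal sketch (not part of this PR) of how the extra cast can be spotted with the standard torch.profiler API; the CPU device, bf16 dtype, and tensor shapes are assumptions for illustration only, and the aten::item sync mentioned above comes from the surrounding generation code rather than this isolated slice:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative logits tensor: [batch, sequence, vocab] in the compute dtype.
logits = torch.randn(2, 128, 32000, dtype=torch.bfloat16)

# With the upcast: the trace shows an extra cast (aten::to / aten::_to_copy).
with profile(activities=[ProfilerActivity.CPU]) as prof_cast:
    _ = logits[:, -1, :].float()

# Without the upcast: only the slice, no dtype conversion.
with profile(activities=[ProfilerActivity.CPU]) as prof_nocast:
    _ = logits[:, -1, :]

print(prof_cast.key_averages().table(sort_by="cpu_time_total", row_limit=5))
print(prof_nocast.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```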

Commit: Extra aten::item, cast is causing perf degradation
@@ -3580,8 +3579,7 @@ def _assisted_decoding(

# 2.3. Process the new logits
# .float() is needed to retain precision for later logits manipulations
Contributor review comment:

well, I think you should at least remove this comment

@regisss (Collaborator) commented Oct 2, 2024

@jiminha Have you got time to check if we get the same output sequences removing all these .float()?

@jiminha (Collaborator, Author) commented Oct 2, 2024

> @jiminha Have you got time to check if we get the same output sequences removing all these .float()?

I tested all the text_generation and test_encoder_decoder test cases from pytest and compared the outputs, and they all looked the same. We'd still like to understand why the original Transformers code added this float() to the last-logit computation for all models, and which specific test cases they ran that showed this float() is needed. Would you be able to check with them?

@libinta added the run-test (Run CI for PRs from external contributors) label and removed the synapse1.18 label on Oct 2, 2024
@regisss (Collaborator) commented Oct 2, 2024

@jiminha It's explained in the first message of this issue: huggingface/transformers#30860

Basically, we don't want to cast all the logits to float in the forward of causal LM models, and we only do it:

  • if labels are provided to the forward
  • and in generate for the last logit, which is needed for computing the next token

So we should probably remove this float here and keep the ones in generate. It's not done in this version of Transformers because it's a breaking change (the model outputs won't have the same type anymore), but they'll do it in v4.46.
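A hedged sketch of the casting policy described above (based on this comment and the linked issue, not the exact upstream code; model classes and label shifting are omitted for brevity):

```python
import torch
import torch.nn.functional as F

def causal_lm_forward(hidden_states, lm_head, labels=None):
    # Keep logits in the compute dtype (e.g. bf16) by default.
    logits = lm_head(hidden_states)
    loss = None
    if labels is not None:
        # Upcast only when a loss is computed, for numerical stability.
        loss = F.cross_entropy(
            logits.float().view(-1, logits.size(-1)), labels.view(-1)
        )
    return loss, logits

def pick_next_token(logits):
    # In generate, only the last position is upcast before logits processing.
    next_token_logits = logits[:, -1, :].float()
    return next_token_logits.argmax(dim=-1)
```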

@jiminha (Collaborator, Author) commented Oct 2, 2024

> @jiminha It's explained in the first message of this issue: huggingface/transformers#30860
>
> Basically, we don't want to cast all the logits to float in the forward of causal LM models, and we only do it:
>
> • if labels are provided to the forward
> • and in generate for the last logit, which is needed for computing the next token
>
> So we should probably remove this float here and keep the ones in generate. It's not done in this version of Transformers because it's a breaking change (the model outputs won't have the same type anymore), but they'll do it in v4.46.

Thanks for the link and explanation. So are we actually doing a double float cast in v4.45, since we added float() in generation/utils.py and also kept the same logic in the causal LM model files? Maybe that's why I didn't see any regression even after removing float() in the utils.

The perf regression I saw was with the T5 model (the t5 test_encoder_decoder test), which doesn't use the causal LM path and didn't have a float conversion to begin with, so its duration increased with this change:

https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1768C8-L1768C50
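A small illustration (my own, not from the thread) of the "double cast" point: if the forward already returns float32 logits, the extra .float() in generate is effectively a no-op, whereas for a model whose forward keeps logits in bf16 (as in the T5 code linked above) it introduces a genuine new cast per decoding step:

```python
import torch

# Causal LM in v4.45: the forward already upcasts, so logits arrive as float32.
causal_lm_logits = torch.randn(2, 10, 32000).float()
# T5-style forward: logits stay in the compute dtype (bf16 here).
t5_logits = torch.randn(2, 10, 32000, dtype=torch.bfloat16)

no_op = causal_lm_logits[:, -1, :].float()  # already fp32, nothing to convert
real_cast = t5_logits[:, -1, :].float()     # bf16 -> fp32, an extra op per step

print(no_op.dtype, real_cast.dtype)         # torch.float32 torch.float32
```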

@@ -2370,7 +2369,7 @@ def _sample(
next_token_scores = logits_processor(input_ids, next_token_logits)
else:
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].float()
Collaborator review comment:

consider keeping this

@@ -2814,7 +2813,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
else:
next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
else:
next_token_logits = outputs.logits[:, -1, :].float()
Collaborator review comment:

we don't normally reach this path

@jiminha (Collaborator, Author) commented Oct 2, 2024

For the T5 model, I think it shows a bigger perf difference with smaller batch/sample counts. For bigger batches/samples the difference is very small.

Batch 2, sample 200 (current test)

Original:
predict_runtime = 0:00:17.73
predict_samples = 200
predict_samples_per_second = 11.28

Remove float from _sample:
predict_runtime = 0:00:16.51
predict_samples = 200
predict_samples_per_second = 12.108

Batch 128, sample 1000

Original:
predict_runtime = 0:00:23.91
predict_samples = 1000
predict_samples_per_second = 41.814

Remove float from _sample:
predict_runtime = 0:00:23.86
predict_samples = 1000
predict_samples_per_second = 41.909

@libinta removed the run-test (Run CI for PRs from external contributors) label on Oct 2, 2024
@regisss (Collaborator) commented Oct 3, 2024

For T5 it would be in T5ForConditionalGeneration, but it's true that I don't see any cast to float in its forward.
Let me try adding a check on the model type in generate.
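A hypothetical sketch of what such a model-type check could look like (the is_encoder_decoder condition and the helper name are assumptions, not the change regisss ultimately made):

```python
import torch

def select_next_token_logits(logits: torch.Tensor, is_encoder_decoder: bool) -> torch.Tensor:
    next_token_logits = logits[:, -1, :]
    if not is_encoder_decoder:
        # Causal LMs may rely on this upcast once their forward stops returning fp32;
        # encoder-decoder models like T5 never had it, so skip it to avoid the regression.
        next_token_logits = next_token_logits.float()
    return next_token_logits
```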
