
set all fusedrope inputs to bf16 #140

Merged
merged 4 commits on Apr 10, 2024

Conversation

@skaulintel commented Apr 1, 2024

When passing --bf16 to the train/inference command, all input tensors should be cast to bf16.

This changes the fused RoPE input type to bf16 for llama models.
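
A minimal, self-contained sketch of the idea, assuming the standard llama RoPE formula in place of the fused HPU kernel (names and shapes are illustrative, not the actual diff):

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: rotate the last dimension by halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_customized_rope(q, k, cos, sin):
    # Sketch of the PR's idea: when --bf16 is enabled, cast every input
    # to bf16 up front so the whole RoPE computation runs in bf16.
    # (A plain PyTorch RoPE stands in for the fused HPU kernel here.)
    q, k, cos, sin = (t.to(torch.bfloat16) for t in (q, k, cos, sin))
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```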

@skaulintel requested review from a user and kbinias on April 1, 2024 at 19:29
@mandy-li commented Apr 1, 2024

@skaulintel, what is the reason to hard-cast the data type to bf16? What about when --bf16 is not used in the model's command?

@ghost commented Apr 2, 2024

> @skaulintel, what is the reason to hard-cast the data type to bf16? What about when --bf16 is not used in the model's command?

+1

@skaulintel (Author) commented

@mandy-li In order to compute RoPE in bf16, we have to cast all inputs to bf16. From my analysis, when we pass --bf16 True to the training script for llama-7b (examples/language-modeling/run_lora_clm.py), I see the following dtypes passed to apply_customized_rope:

q: torch.float32, k: torch.bfloat16, cos: torch.float32, sin: torch.float32

Since the query is passed in as float32, it will not be computed in bf16, but the key states will be. @kbinias please review.
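
For anyone wanting to reproduce this, a hypothetical debug helper along these lines (not code from this PR) prints the dtypes reaching apply_customized_rope during a --bf16 run:

```python
import torch

def log_rope_dtypes(q, k, cos, sin):
    # Hypothetical debug helper to confirm which dtypes reach
    # apply_customized_rope under --bf16 True.
    print(f"q: {q.dtype}, k: {k.dtype}, cos: {cos.dtype}, sin: {sin.dtype}")

# Example inputs matching the dtypes reported above:
q = torch.randn(2, 8, 16, 64, dtype=torch.float32)
k = torch.randn(2, 8, 16, 64, dtype=torch.bfloat16)
cos = torch.randn(16, 64, dtype=torch.float32)
sin = torch.randn(16, 64, dtype=torch.float32)
log_rope_dtypes(q, k, cos, sin)
# -> q: torch.float32, k: torch.bfloat16, cos: torch.float32, sin: torch.float32
```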

@skaulintel changed the title from "Update modeling_llama.py" to "set all fusedrope inputs to bf16" on Apr 10, 2024
@astachowiczhabana commented

huggingface#1026
