set all fusedrope inputs to bf16 #140
Conversation
@skaulintel, what is the reason for hard-casting the data type to bf16? What happens when --bf16 is not used in the model's command?
+1
@mandy-li In order to calculate RoPE in bf16, we have to cast all inputs to bf16. From my analysis, when we pass --bf16 True to the training script while running llama-7b (examples/language-modeling/run_lora_clm.py), I see the following dtypes passed to apply_customized_rope: q: torch.float32, k: torch.bfloat16, cos: torch.float32, sin: torch.float32. Since the query is passed in as float32, it would not be computed in bf16, while the key states would be. @kbinias please review.
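For illustration, here is a minimal sketch of the uniform cast being discussed. It uses the standard (unfused) RoPE formulation rather than the repository's actual apply_customized_rope helper, and the function name apply_rope_bf16 is hypothetical:

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_bf16(q, k, cos, sin):
    # Cast every input to bf16 up front so the rotation is computed
    # uniformly in bf16, regardless of the mixed dtypes the caller passes
    # (e.g. q in fp32 but k in bf16, as observed above).
    q, k = q.to(torch.bfloat16), k.to(torch.bfloat16)
    cos, sin = cos.to(torch.bfloat16), sin.to(torch.bfloat16)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```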
When passing --bf16 to the train/inference command, ensure all input tensors are cast to bf16.
Change the fused RoPE input type to bf16 for Llama models.
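As a quick sanity check, one could feed the apply_rope_bf16 sketch above tensors mirroring the mixed dtypes reported in the conversation and confirm that both outputs come back as bf16 (the shapes here are illustrative only, not taken from the model):

```python
import torch

# Illustrative shapes: (batch, heads, seq_len, head_dim) for q/k,
# with cos/sin broadcasting over the batch and head dimensions.
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
cos = torch.randn(16, 64, dtype=torch.float32)
sin = torch.randn(16, 64, dtype=torch.float32)

q_rot, k_rot = apply_rope_bf16(q, k, cos, sin)
assert q_rot.dtype == k_rot.dtype == torch.bfloat16
```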