sasarkar/fusedrope inp bf16 #1026
hsubramony committed May 31, 2024
1 parent 2509c74 commit 7713a65
Showing 1 changed file with 9 additions and 0 deletions.
optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1124,6 +1124,15 @@ def prepare_inputs_for_generation(
 def apply_customized_rope(q, k, cos, sin, position_ids):
     if q.device.type == "hpu" and has_fused_rope:
         # TODO: remove `.clone()` when it is fixed in SynapseAI
+        if k.dtype == torch.bfloat16:
+            return FusedRoPE.apply(
+                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+            ), FusedRoPE.apply(
+                k,
+                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+                position_ids,
+            )
         return FusedRoPE.apply(
             q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
         ), FusedRoPE.apply(
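The change special-cases bfloat16 keys: the cloned cos/sin caches are cast to torch.bfloat16 before FusedRoPE.apply rotates k, so the fused kernel receives inputs of matching dtype. The sketch below is a minimal plain-PyTorch illustration of that cast-to-input-dtype pattern, not the HPU FusedRoPE kernel itself; `rotate_half` and `apply_rope_reference` are hypothetical names introduced for this example.

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: split the last dim in two and rotate (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x, cos, sin):
    # Broadcast cos/sin from [seq_len, head_dim] to [1, 1, seq_len, head_dim]
    # and cast them to the input's dtype, mirroring the commit's
    # `.to(torch.bfloat16)` cast on the cloned cos/sin caches.
    cos = cos.unsqueeze(0).unsqueeze(0).to(x.dtype)
    sin = sin.unsqueeze(0).unsqueeze(0).to(x.dtype)
    return x * cos + rotate_half(x) * sin

# bf16 key tensor rotated with fp32 caches: without the cast, the elementwise
# math silently upcasts to fp32, and a fused kernel may reject the dtype mix.
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)  # [batch, heads, seq, head_dim]
cos, sin = torch.randn(16, 64), torch.randn(16, 64)  # fp32 position caches
out = apply_rope_reference(k, cos, sin)
assert out.dtype == torch.bfloat16
```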
