Fix llama when rope scaling is not None (#9086)
* Fix llama when rope scaling is not None

* fix style

* fix style
yangw1234 committed Oct 6, 2023
1 parent fcb1c61 commit 36dd4af
Showing 1 changed file with 5 additions and 1 deletion.
python/llm/src/bigdl/llm/transformers/models/llama.py (5 additions & 1 deletion)
@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
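
For readers skimming the change: the three assignments fold into a single boolean gate, so the fused XPU rotary-embedding kernel is skipped whenever rope scaling is configured. Below is a minimal standalone sketch of that gate, assuming the enclosing llama_attention_forward_4_31 supplies query_states, self.training, and self.config; the helper name should_use_fuse_rope is hypothetical and not part of this commit.

import torch

def should_use_fuse_rope(query_states: torch.Tensor, training: bool, rope_scaling) -> bool:
    # Mirrors the condition added in this commit: the fused rotary-embedding
    # path (apply_rotary_pos_emb_no_cache_xpu) is only taken when
    #   * the tensors live on an XPU device,
    #   * no gradients need to flow through the rotation, and
    #   * the model config declares no rope scaling.
    return (
        query_states.device.type == "xpu"
        and not (training and query_states.requires_grad)
        and rope_scaling is None
    )

When the gate is False, the attention forward presumably falls back to the non-fused rotary-embedding path (the rest of the hunk is collapsed here), which is what makes models with rope_scaling set work again.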
