using bigdl-llm fused rope for llama (#9066)
* optimize llama xpu rope

* fix bug

* fix style

* refine append cache

* remove check

* do not cache cos sin

* remove unnecessary changes

* clean up

* fix style

* check for training
yangw1234 committed Oct 6, 2023
1 parent 5004464 commit fcb1c61
Showing 2 changed files with 27 additions and 4 deletions.
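
For reference, the unfused path that this commit bypasses on XPU is the standard LLaMA rotate-half rotary embedding, which reads cos/sin values from a cached table indexed by position. Below is a minimal sketch in plain PyTorch; the name apply_rotary_reference is illustrative, and the tensor layout assumes the transformers 4.31 convention of [1, 1, seq_len, head_dim] cos/sin caches, not necessarily the repository's exact implementation.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the last dimension by halves: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_reference(q, k, cos, sin, position_ids):
    # Gather the cached cos/sin rows for each position, then rotate q and k.
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

The fused XPU path in this commit replaces both the table lookup and the elementwise combine with a single kernel call.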
16 changes: 12 additions & 4 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -39,6 +39,7 @@
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -58,7 +59,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


 def llama_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu":
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                          [self.weight.size(0)], self.weight)
     else:
@@ -116,9 +117,16 @@ def llama_attention_forward_4_31(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids, "llama")
+
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "llama")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # reuse k, v, self_attention
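
The else: branch of llama_rms_norm_forward is truncated in the hunk above. As a rough guide to what the new training check guards, here is a minimal sketch of the dispatch pattern, assuming the fallback is the usual eager RMSNorm used by Hugging Face's LlamaRMSNorm; the helper name maybe_fused_rms_norm and the exact fallback body are assumptions, not code from the repository.

import torch

def maybe_fused_rms_norm(module, hidden_states: torch.Tensor) -> torch.Tensor:
    # Use the fused IPEX kernel only for XPU inference; when training with
    # gradients enabled, fall back to the eager formulation so autograd is
    # not routed through the fused op.
    if hidden_states.device.type == "xpu" and not (module.training and hidden_states.requires_grad):
        out, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                               [module.weight.size(0)], module.weight)
        return out
    # Eager RMSNorm; module.variance_epsilon follows Hugging Face's
    # LlamaRMSNorm attribute name (assumption).
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
    return module.weight * hidden_states.to(module.weight.dtype)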
15 changes: 15 additions & 0 deletions python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -97,3 +97,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
     else:
         invalidInputError(False,
                           f"{model_family} is not supported.")
+
+
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
+        linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
