Skip to content

Commit

Permalink
use fused qkv forward in qwen2 (#10185)
Browse files Browse the repository at this point in the history
* use fused qkv forward in qwen2

* support both

* fix style

* fix rope

* remove pring

* fix style

* clean up
  • Loading branch information
yangw1234 committed Mar 1, 2024
1 parent 509e206 commit f4d7dbc
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 67 deletions.
13 changes: 9 additions & 4 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def llama_attention_forward_4_31_quantized(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
0,
self.head_dim)
self.head_dim,
self.rotary_emb.base,)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
Expand Down Expand Up @@ -511,7 +512,8 @@ def llama_attention_forward_4_31_original(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base,)
kv_seq_len += 1

else:
Expand Down Expand Up @@ -762,7 +764,9 @@ def llama_attention_selective_batching_forward_4_31(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base,
)
kv_seq_len += 1
else:
if self.config.pretraining_tp > 1:
Expand Down Expand Up @@ -942,7 +946,8 @@ def llama_attention_forward_4_36(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
if self.layer_idx == 0:
Expand Down
3 changes: 2 additions & 1 deletion python/llm/src/bigdl/llm/transformers/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def mixtral_attention_forward(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
if self.layer_idx == 0:
Expand Down
148 changes: 86 additions & 62 deletions python/llm/src/bigdl/llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ def qwen2_attention_forward_quantized(
attn_weights = None

return attn_output, attn_weights, past_key_value
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]


def qwen2_attention_forward_origin(
Expand All @@ -247,72 +250,93 @@ def qwen2_attention_forward_origin(
device = hidden_states.device

enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5]
decoding_fast_path = (qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
kv_seq_len = cache_k.shape[-2]
import linear_q4_0
args = [hidden_states, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight,
self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, position_ids, cache_k,
cache_v, self.q_proj.weight.qtype, kv_seq_len, self.head_dim, self.rotary_emb.base]
query_states, key_states, value_states = linear_q4_0.forward_qkv_bias(*args)
kv_seq_len += 1
if self.layer_idx == 0:
past_key_value.seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = \
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = \
value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(
False,
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, "
"please make sure to initialize the attention class with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2",
position_ids)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]

if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = \
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = \
value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(
False,
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, "
"please make sure to initialize the attention class with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2",
position_ids)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]

if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)

new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v

key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)

# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]

if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]

if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)

new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v

key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)

# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
Expand Down

0 comments on commit f4d7dbc

Please sign in to comment.