Skip to content

Commit

Permalink
Support Qwen2-7b MLP in int4 and transpose_value_cache=True (#11968)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangw1234 authored Sep 2, 2024
1 parent 65e281b commit c48817b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
7 changes: 5 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
model.llm.config.model_type = "llama"
model = model.llm

if model.config.model_type == "qwen2":
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward
model.apply(split_mlp_down_proj)

# lm_head to cpu optimization
if cpu_lm_head:
# disable the optimization by default
Expand Down Expand Up @@ -134,8 +139,6 @@ def optimize_llm(
intra_pp = 2
if inter_pp is None:
inter_pp = 4 if model.config.intermediate_size == 18944 else 1
if model.config.intermediate_size == 18944:
transpose_value_cache = False

from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
Expand Down
53 changes: 44 additions & 9 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@
from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP


def split_mlp_down_proj(module: torch.nn.Module):
if isinstance(module, Qwen2MLP) and module.down_proj.in_features == 18944:
new_linear_0 = torch.nn.Linear(0, 0, bias=False)
new_weight_0 = torch.nn.Parameter(module.down_proj.weight[:, :9472], requires_grad=False)
new_linear_0.weight = new_weight_0
new_linear_0.in_features = new_weight_0.size(1)
new_linear_0.out_features = new_weight_0.size(0)
module.down_proj_0 = new_linear_0
new_linear_1 = torch.nn.Linear(0, 0, bias=False)
new_weight_1 = torch.nn.Parameter(module.down_proj.weight[:, 9472:], requires_grad=False)
new_linear_1.weight = new_weight_1
new_linear_1.in_features = new_weight_1.size(1)
new_linear_1.out_features = new_weight_1.size(0)
module.down_proj_1 = new_linear_1

del module.down_proj


def split_mlp_forward(self, x):
h = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
return self.down_proj_0(h[:, :, :9472]) + self.down_proj_1(h[:, :, 9472:])


class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
Expand Down Expand Up @@ -201,7 +225,7 @@ def __init__(
self.compile()
print("end compiling")

def mlp(self, hidden_states):
def mlp(self, hidden_states, seq_len):
mm1 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
Expand All @@ -211,9 +235,13 @@ def mlp(self, hidden_states):
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
if self.intermediate_size == 18944:
# for qwen2-7b
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8
)
mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472])
mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944])
hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472,
bias=False, wt_dtype=self.dtype)
hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472,
bias=False, wt_dtype=self.dtype)
hidden_states = hidden_states_0 + hidden_states_1
else:
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
Expand Down Expand Up @@ -257,7 +285,7 @@ def build_decoder(
hidden_states = self.eltwise_add(residual, attn_output)
residual = hidden_states
hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, self.seq_len)
hidden_states = self.eltwise_add(residual, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)

Expand Down Expand Up @@ -343,9 +371,13 @@ def __init__(
)
self.backend_decoders.append(decoder)

offset = 0
for i in range(intra_stages):
start, end = self.layer_ranges[i]
self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7])
curr_linear_ops = len(self.backend_decoders[i].linear_ops)
curr_parameters = self.op_parameters[offset:offset + curr_linear_ops]
self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
offset = offset + curr_linear_ops

def forward(
self,
Expand Down Expand Up @@ -543,7 +575,8 @@ def run_decode(
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
(mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
(mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
]

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
Expand Down Expand Up @@ -814,6 +847,8 @@ def run_prefill(
transpose_value=transpose_value_cache
)
convert_forward(model, Qwen2Attention, qwen2_attention_forward)
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
convert_forward(model, Qwen2MLP, split_mlp_forward)
deocderlayers = model.model.layers

while True:
Expand All @@ -836,7 +871,6 @@ def run_prefill(

hidden_states = layer_outputs[0]
next_decoder_cache = layer_outputs[1]

result_queue.put((hidden_states, next_decoder_cache))


Expand Down Expand Up @@ -1124,10 +1158,11 @@ def qwen2_attention_forward(
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, }

if past_key_value is not None:
if transpose_value:
value_states = value_states.transpose(-1, -2)
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)

Expand Down

0 comments on commit c48817b

Please sign in to comment.