[Inference]qwen2-a8w8c8 support use_fake_parameter (#9109)
* use_fake_parameter

* fixed qwen2 and llama fake data; support qwen2 multi-card inference

* check
ckl117 authored Sep 12, 2024
1 parent 93a9b2c commit db270d9
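
At a high level, use_fake_parameter lets the fused a8w8c8 (8-bit activation/weight/cache-KV) inference path be built and exercised without a fully populated quantized checkpoint: parameters that are missing are stood in for by placeholder tensors of the expected shape and dtype, as the set_state_dict changes below do with paddle.zeros. A minimal sketch of the idea (fake_bias and the sizes are hypothetical illustrations, not the PaddleNLP implementation):

    import paddle

    def fake_bias(shape, dtype=None):
        """Hypothetical helper: return a zero-filled stand-in for a bias that is
        absent from the checkpoint, mirroring what the diff below does inline
        when use_fake_parameter is enabled."""
        dtype = dtype or paddle.get_default_dtype()
        return paddle.zeros(shape=shape, dtype=dtype)

    hidden_size, intermediate_size = 4096, 11008  # illustrative sizes only
    q_proj_bias = fake_bias([hidden_size])
    up_proj_bias = fake_bias([intermediate_size])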
Showing 4 changed files with 165 additions and 68 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/experimental/model_utils.py
@@ -332,7 +332,7 @@ def __init__(
         self.key_map = key_map_dict
         self.scale = {}
         for scale_type, key_template in self.key_map.items():
-            self.scale[scale_type] = np.full([num_of_layers], fill_value=-1.0)
+            self.scale[scale_type] = np.full([num_of_layers], fill_value=-1.0, dtype="float32")
             for i in range(num_of_layers):
                 if key_template.replace("#", str(i)) in self.scale_dict.keys():
                     self.scale[scale_type][i] = 1 / self.scale_dict[key_template.replace("#", str(i))]
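
For context on the one-line fix above: np.full with a float fill value allocates float64 by default, so the dtype of the per-layer scale buffer is now pinned explicitly to float32. A standalone check of that NumPy default:

    import numpy as np

    scales_default = np.full([4], fill_value=-1.0)                # NumPy default dtype: float64
    scales_fp32 = np.full([4], fill_value=-1.0, dtype="float32")  # matches the fp32 scale buffers

    print(scales_default.dtype)  # float64
    print(scales_fp32.dtype)     # float32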
38 changes: 21 additions & 17 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -405,6 +405,8 @@ def __init__(self, config: LlamaConfig):
         self.rope_theta = config.rope_theta
         self.use_neox = True
 
+        self.use_fake_parameter = config.get("use_fake_parameter", False)
+
         self.use_weight_only = False
         if config.quant_type == "weight_only_int8":
             self.use_weight_only = True
@@ -417,8 +419,9 @@ def __init__(self, config: LlamaConfig):
             self.shift = config.quantization_config.shift
             self.smooth = config.quantization_config.smooth
             self.shift_smooth_all_linears = config.quantization_config.shift_smooth_all_linears
-
-            self.use_fake_parameter = config.get("use_fake_parameter", False)
+            if self.use_fake_parameter:
+                self.shift_smooth_all_linears = True
+                config.quantization_config.shift_smooth_all_linears = True
 
         if self.use_weight_only:
             assert (
@@ -1066,6 +1069,22 @@ def set_state_dict(self, state_dict):
                     unfused_state_dict["mlp.up_proj.bias"] = paddle.zeros(
                         shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
                     )
+                else:
+                    unfused_state_dict["self_attn.q_proj.bias"] = state_dict[
+                        "llama.layers.{}.self_attn.q_proj.bias".format(idx)
+                    ]
+                    unfused_state_dict["self_attn.k_proj.bias"] = state_dict[
+                        "llama.layers.{}.self_attn.k_proj.bias".format(idx)
+                    ]
+                    unfused_state_dict["self_attn.v_proj.bias"] = state_dict[
+                        "llama.layers.{}.self_attn.v_proj.bias".format(idx)
+                    ]
+                    unfused_state_dict["mlp.gate_proj.bias"] = state_dict[
+                        "llama.layers.{}.mlp.gate_proj.bias".format(idx)
+                    ]
+                    unfused_state_dict["mlp.up_proj.bias"] = state_dict[
+                        "llama.layers.{}.mlp.up_proj.bias".format(idx)
+                    ]
 
                 self.transformer_block.ln_biases[idx].set_value(
                     paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.bias".format(idx)])
@@ -1074,16 +1093,6 @@ def set_state_dict(self, state_dict):
                     paddle.to_tensor(state_dict["llama.layers.{}.post_attention_layernorm.bias".format(idx)])
                 )
 
-                unfused_state_dict["self_attn.q_proj.bias"] = state_dict[
-                    "llama.layers.{}.self_attn.q_proj.bias".format(idx)
-                ]
-                unfused_state_dict["self_attn.k_proj.bias"] = state_dict[
-                    "llama.layers.{}.self_attn.k_proj.bias".format(idx)
-                ]
-                unfused_state_dict["self_attn.v_proj.bias"] = state_dict[
-                    "llama.layers.{}.self_attn.v_proj.bias".format(idx)
-                ]
-
                 concated_qkv_biases = np.concatenate(
                     [
                         unfused_state_dict["self_attn.q_proj.bias"],
@@ -1095,11 +1104,6 @@ def set_state_dict(self, state_dict):
 
                 self.transformer_block.qkv_biases[idx].set_value(paddle.to_tensor(concated_qkv_biases))
 
-                unfused_state_dict["mlp.gate_proj.bias"] = state_dict[
-                    "llama.layers.{}.mlp.gate_proj.bias".format(idx)
-                ]
-                unfused_state_dict["mlp.up_proj.bias"] = state_dict["llama.layers.{}.mlp.up_proj.bias".format(idx)]
-
                 concated_ffn1_bias = np.concatenate(
                     [unfused_state_dict["mlp.gate_proj.bias"], unfused_state_dict["mlp.up_proj.bias"]], axis=-1
                 )
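
The set_state_dict restructuring above leaves the bias-fusion step itself unchanged: whether the per-projection biases are read from the checkpoint or zero-filled as fakes, the q/k/v biases are concatenated into one fused QKV bias and the gate/up biases into one FFN1 bias. A small standalone illustration of that layout (the shapes are hypothetical):

    import numpy as np

    hidden_size, intermediate_size = 8, 16  # illustrative sizes only
    q_bias = np.zeros([hidden_size], dtype="float32")
    k_bias = np.zeros([hidden_size], dtype="float32")
    v_bias = np.zeros([hidden_size], dtype="float32")

    # Same pattern as the diff: fuse along the last axis before handing the
    # tensors to the fused transformer block.
    qkv_bias = np.concatenate([q_bias, k_bias, v_bias], axis=-1)  # shape [3 * hidden_size]

    gate_bias = np.zeros([intermediate_size], dtype="float32")
    up_bias = np.zeros([intermediate_size], dtype="float32")
    ffn1_bias = np.concatenate([gate_bias, up_bias], axis=-1)     # shape [2 * intermediate_size]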
