[LLM Inference] Support Qwen2_Moe Inference with MultiGPU #9121

Merged: 5 commits, Sep 12, 2024
Changes from all commits
@@ -384,10 +384,13 @@ def __init__(self, config: FusedMultiTransformerConfig):
assert config.ring_id != -1
assert config.num_heads % config.nranks == 0
assert config.dim_feedforward % config.nranks == 0
assert config.moe_config.shared_expert_intermediate_size % config.nranks == 0
self.num_heads = config.num_heads // config.nranks
self.kv_num_heads = config.kv_num_heads // config.nranks
dim_feedforward = config.dim_feedforward // config.nranks
self.dim_feedforward = dim_feedforward
shared_expert_intermediate_size = config.moe_config.shared_expert_intermediate_size // config.nranks
self.config.moe_config.shared_expert_intermediate_size = shared_expert_intermediate_size

self.num_layers = config.num_layers
assert self.num_layers > 0
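For reference, the new shared-expert assert follows the same rule as the existing head and FFN checks: any dimension that gets sharded must divide evenly by the tensor-parallel world size. A minimal sketch of the per-rank arithmetic, using hypothetical sizes that are not taken from this PR:

# Hypothetical Qwen2-MoE-like sizes, for illustration only.
num_heads = 28
kv_num_heads = 4
dim_feedforward = 18944
shared_expert_intermediate_size = 20480
nranks = 4  # tensor-parallel degree

# Same divisibility rule the asserts above enforce.
for name, value in {
    "num_heads": num_heads,
    "kv_num_heads": kv_num_heads,
    "dim_feedforward": dim_feedforward,
    "shared_expert_intermediate_size": shared_expert_intermediate_size,
}.items():
    assert value % nranks == 0, f"{name}={value} not divisible by nranks={nranks}"

# Each rank then works on the sliced sizes, e.g. 20480 // 4 == 5120
# columns of the shared expert's first projection.
print(shared_expert_intermediate_size // nranks)  # 5120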
@@ -644,6 +647,9 @@ def __init__(self, config: FusedMultiTransformerConfig):
# row parallel
_set_var_distributed(linear_weight)
_set_var_distributed(ffn2_weight)
if self.config.moe_config.use_shared_expert(i):
_set_var_distributed(shared_expert_ffn1_weight)
_set_var_distributed(shared_expert_ffn2_weight)

self.ln_scales.append(ln_scale)
self.ln_biases.append(ln_bias)
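Marking the shared-expert FFN weights as distributed follows the usual scheme, consistent with the split mappings later in this diff (up/gate projections column-parallel, down projection row-parallel): the first projection is split by columns across ranks, the second by rows, so summing the per-rank outputs of the second projection (the allreduce step) recovers the unsharded result. A toy numpy check of that identity, purely illustrative and with the activation omitted:

import numpy as np

rng = np.random.default_rng(0)
hidden, inter, nranks = 8, 12, 2

x = rng.standard_normal((3, hidden))
ffn1 = rng.standard_normal((hidden, inter))  # column-parallel: sharded along axis 1
ffn2 = rng.standard_normal((inter, hidden))  # row-parallel: sharded along axis 0

reference = (x @ ffn1) @ ffn2

# Each "rank" keeps one column block of ffn1 and the matching row block of ffn2.
chunk = inter // nranks
partials = [
    (x @ ffn1[:, r * chunk : (r + 1) * chunk]) @ ffn2[r * chunk : (r + 1) * chunk, :]
    for r in range(nranks)
]

# The allreduce (sum over ranks) reproduces the unsharded computation.
assert np.allclose(sum(partials), reference)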
@@ -2243,9 +2249,12 @@ def __init__(self, config: FusedMultiTransformerConfig):
assert config.ring_id != -1
assert config.num_heads % config.nranks == 0
assert config.dim_feedforward % config.nranks == 0
assert config.moe_config.shared_expert_intermediate_size % config.nranks == 0
self.num_heads = config.num_heads // config.nranks
self.kv_num_heads = config.kv_num_heads // config.nranks
self.dim_feedforward = config.dim_feedforward // config.nranks
shared_expert_intermediate_size = config.moe_config.shared_expert_intermediate_size // config.nranks
self.config.moe_config.shared_expert_intermediate_size = shared_expert_intermediate_size

self.num_layers = config.num_layers
assert self.num_layers > 0
109 changes: 75 additions & 34 deletions paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -18,6 +18,7 @@
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize

from paddlenlp.experimental.transformers.fused_transformer_layers import (
@@ -34,6 +35,7 @@
)
from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
from paddlenlp.transformers import Qwen2MoeConfig, Qwen2MoePretrainedModel
from paddlenlp.transformers.conversion_utils import split_param_func
from paddlenlp.transformers.model_outputs import ( # CausalLMOutputWithCrossAttentions,
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
@@ -106,7 +108,26 @@ def __init__(self, config: Qwen2MoeConfig):
self.moe_intermediate_size = config.moe_intermediate_size
self.shared_expert_intermediate_size = config.shared_expert_intermediate_size

self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
)
else:
self.embed_tokens = nn.Embedding(
self.vocab_size,
self.hidden_size,
)
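The VocabParallelEmbedding branch shards the embedding table by vocabulary rows, one contiguous slice per tensor-parallel rank; token ids outside a rank's slice contribute zeros, and summing across ranks restores the ordinary lookup. A small numpy sketch of that idea (not the Paddle implementation; sizes are made up):

import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_size, tp = 16, 4, 2
table = rng.standard_normal((vocab_size, hidden_size))
token_ids = np.array([1, 9, 15])

reference = table[token_ids]  # ordinary, unsharded lookup

per_rank = vocab_size // tp
partials = []
for rank in range(tp):
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = table[lo:hi]                         # this rank's slice of the table
    local = token_ids - lo
    mask = (token_ids >= lo) & (token_ids < hi)  # tokens owned by this rank
    out = np.where(mask[:, None], shard[np.clip(local, 0, per_rank - 1)], 0.0)
    partials.append(out)

# allreduce (sum across ranks) matches the ordinary embedding lookup
assert np.allclose(sum(partials), reference)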

# get ring_id
ring_id = -1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
ring_id = model_parallel_group.id
except:
pass
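The try/except above only finds a valid ring_id when the process was launched inside a hybrid-parallel fleet; otherwise it silently keeps -1. A hedged sketch of the kind of initialization that makes fleet.get_hybrid_communicate_group() available; the script name and degrees are placeholders for a 2-GPU tensor-parallel run:

import paddle.distributed.fleet as fleet

# Typically launched via: python -m paddle.distributed.launch --gpus "0,1" predict.py
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,  # data parallel
    "mp_degree": 2,  # tensor (model) parallel; should match tensor_parallel_degree
    "pp_degree": 1,  # pipeline parallel
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
print(hcg.get_model_parallel_group().id)  # the ring_id picked up above
print(hcg.get_model_parallel_rank())      # this process's tensor-parallel rank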

ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen2_moe.{}.ln_scale".format(i)) for i in range(self.num_layers)]
qkv_weight_attrs = [
@@ -216,8 +237,8 @@ def __init__(self, config: Qwen2MoeConfig):
quant_type=self.quant_type,
activation="swiglu",
num_layers=config.num_hidden_layers,
nranks=1,
ring_id=-1,
nranks=config.tensor_parallel_degree,
ring_id=ring_id,
ln_scale_attrs=ln_scale_attrs,
qkv_weight_attrs=qkv_weight_attrs,
qkv_weight_scale_attrs=qkv_weight_scale_attrs,
@@ -233,6 +254,7 @@ def __init__(self, config: Qwen2MoeConfig):
epsilon=self.rms_norm_eps,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
moe_config=moe_config,
)

@@ -258,6 +280,7 @@ def set_input_embeddings(self, value):
@paddle.no_grad()
def set_state_dict(self, state_dict):
head_size = self.hidden_size // self.num_attention_heads
split_fn = split_param_func()
dtype = paddle.get_default_dtype()
embed_tokens_weight = paddle.to_tensor(state_dict["qwen2_moe.embed_tokens.weight"]).cast(
self.embed_tokens.weight.dtype
@@ -272,36 +295,47 @@ def set_state_dict(self, state_dict):
self.transformer_block.ln_scales[idx].dtype
)
self.transformer_block.ln_scales[idx].set_value(ln_scale)

unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"],
],
if "qwen2_moe.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
concated_qkv_weight = np.concatenate(
split_fn(
state_dict["qwen2_moe.layers.{}.self_attn.qkv_proj.weight".format(idx)],
is_qkv=True,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
),
axis=-1,
)
.transpose(1, 0)
.reshape(
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
).transpose(1, 0)
else:
unfused_state_dict = {}
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
)
* (head_size),
self.hidden_size,
)
* (head_size),
self.hidden_size,
)
)

qkv_weight = paddle.to_tensor(concated_qkv_weight).cast(dtype)
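For intuition on the new qkv_proj branch: with grouped-query attention the fused weight stacks the q, k and v projections along the output dimension, so splitting it means carving (num_heads + 2 * num_key_value_heads) * head_size columns into three blocks. A shape-only numpy sketch with hypothetical sizes; the real splitting is done by split_param_func:

import numpy as np

hidden_size, head_size = 3584, 128   # hypothetical sizes
num_heads, num_kv_heads = 28, 4

qkv = np.zeros((hidden_size, (num_heads + 2 * num_kv_heads) * head_size))

# Carve the fused output dimension into q, k and v blocks.
q, k, v = np.split(
    qkv, [num_heads * head_size, (num_heads + num_kv_heads) * head_size], axis=-1
)
print(q.shape, k.shape, v.shape)  # (3584, 3584) (3584, 512) (3584, 512)

# Re-fusing mirrors the np.concatenate(..., axis=-1) used in set_state_dict above.
refused = np.concatenate([q, k, v], axis=-1)
assert refused.shape == qkv.shape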

@@ -693,7 +727,7 @@ def forward(
hidden_states = outputs[0]

# if labels is None, means we need full output, instead of tensor_parallel_output
# tensor_parallel_output is togather with ParallelCrossEntropy
# tensor_parallel_output is together with ParallelCrossEntropy
tensor_parallel_output = (
self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1
)
@@ -832,26 +866,33 @@ def get_tensor_parallel_split_mappings(num_layers):
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
}

# Column Linear
if config.fuse_attention_qkv:
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
else:
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
# if we have enough num_key_value_heads to split, then split it.
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

if config.fuse_attention_ffn:
base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True
)
else:
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
for expert_idx in range(config.num_experts):
base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
base_actions["layers.0.mlp.shared_expert.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_expert.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_expert.down_proj.weight"] = partial(fn, is_column=False)

for key, action in base_actions.items():
if "layers.0." in key: