[LLM Inference] Support Qwen2_Moe Inference with MultiGPU #9121

Merged: 5 commits, Sep 12, 2024
Changes from 3 commits
paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -381,10 +381,13 @@
assert config.ring_id != -1
assert config.num_heads % config.nranks == 0
assert config.dim_feedforward % config.nranks == 0
assert config.moe_config.shared_expert_intermediate_size % config.nranks == 0

self.num_heads = config.num_heads // config.nranks
self.kv_num_heads = config.kv_num_heads // config.nranks
dim_feedforward = config.dim_feedforward // config.nranks
self.dim_feedforward = dim_feedforward
shared_expert_intermediate_size = config.moe_config.shared_expert_intermediate_size // config.nranks
self.config.moe_config.shared_expert_intermediate_size = shared_expert_intermediate_size
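For reference, a minimal sketch of the per-rank split performed above; the numbers are illustrative placeholders, not values from any real Qwen2-MoE config:

```python
# Illustrative values only; the real numbers come from the model config.
nranks = 2                                  # tensor parallel degree
num_heads = 28
kv_num_heads = 4
dim_feedforward = 5632
shared_expert_intermediate_size = 20480

# Same divisibility checks as the asserts above.
assert num_heads % nranks == 0
assert dim_feedforward % nranks == 0
assert shared_expert_intermediate_size % nranks == 0

per_rank = {
    "num_heads": num_heads // nranks,
    "kv_num_heads": kv_num_heads // nranks,
    "dim_feedforward": dim_feedforward // nranks,
    "shared_expert_intermediate_size": shared_expert_intermediate_size // nranks,
}
print(per_rank)  # each rank holds 1/nranks of every sharded dimension
```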

self.num_layers = config.num_layers
assert self.num_layers > 0
@@ -641,6 +644,9 @@
# row parallel
_set_var_distributed(linear_weight)
_set_var_distributed(ffn2_weight)
if self.config.moe_config.use_shared_expert(i):
_set_var_distributed(shared_expert_ffn1_weight)
_set_var_distributed(shared_expert_ffn2_weight)
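As a side note, a minimal numpy sketch (not PaddleNLP code) of why the shared-expert ffn1 weight is split column-wise and the ffn2 weight row-wise: the rank-local partial products sum, via the layer's all-reduce, to the unsharded result.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, inter, nranks = 8, 16, 2
x = rng.standard_normal((1, hidden))
ffn1 = rng.standard_normal((hidden, inter))   # column parallel: split the output dim
ffn2 = rng.standard_normal((inter, hidden))   # row parallel: split the input dim

full = (x @ ffn1) @ ffn2
partial = sum(
    (x @ np.split(ffn1, nranks, axis=1)[r]) @ np.split(ffn2, nranks, axis=0)[r]
    for r in range(nranks)
)
# The element-wise activation used in the real layer sits between the two matmuls
# and does not change this block structure.
np.testing.assert_allclose(full, partial)
```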

self.ln_scales.append(ln_scale)
self.ln_biases.append(ln_bias)
@@ -2240,9 +2246,12 @@
assert config.ring_id != -1
assert config.num_heads % config.nranks == 0
assert config.dim_feedforward % config.nranks == 0
assert config.moe_config.shared_expert_intermediate_size % config.nranks == 0

self.num_heads = config.num_heads // config.nranks
self.kv_num_heads = config.kv_num_heads // config.nranks
self.dim_feedforward = config.dim_feedforward // config.nranks
shared_expert_intermediate_size = config.moe_config.shared_expert_intermediate_size // config.nranks
self.config.moe_config.shared_expert_intermediate_size = shared_expert_intermediate_size

self.num_layers = config.num_layers
assert self.num_layers > 0
110 changes: 76 additions & 34 deletions paddlenlp/experimental/transformers/qwen2_moe/modeling.py
@@ -18,6 +18,7 @@
import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet

from paddle.nn.quant import weight_quantize

from paddlenlp.experimental.transformers.fused_transformer_layers import (
@@ -34,6 +35,7 @@
)
from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
from paddlenlp.transformers import Qwen2MoeConfig, Qwen2MoePretrainedModel
from paddlenlp.transformers.conversion_utils import split_param_func

from paddlenlp.transformers.model_outputs import ( # CausalLMOutputWithCrossAttentions,
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
@@ -106,7 +108,26 @@
self.moe_intermediate_size = config.moe_intermediate_size
self.shared_expert_intermediate_size = config.shared_expert_intermediate_size

self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
)
else:
self.embed_tokens = nn.Embedding(
self.vocab_size,
self.hidden_size,
)

# get ring_id
ring_id = -1
try:
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
ring_id = model_parallel_group.id
except:
pass
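For intuition, a minimal numpy sketch (not the fleet implementation) of what VocabParallelEmbedding computes: each rank owns a contiguous vocab shard, tokens owned by other ranks contribute zero rows, and the all-reduce over the model-parallel group (identified by ring_id above) sums the shards back together.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, hidden, nranks = 8, 4, 2
table = rng.standard_normal((vocab, hidden))
tokens = np.array([1, 6, 3])

shard = vocab // nranks
partials = []
for rank in range(nranks):
    lo, hi = rank * shard, (rank + 1) * shard
    local = table[lo:hi]                             # this rank's slice of the table
    in_shard = (tokens >= lo) & (tokens < hi)
    idx = np.where(in_shard, tokens - lo, 0)         # clamp foreign ids to 0
    partials.append(local[idx] * in_shard[:, None])  # zero rows for foreign tokens

np.testing.assert_allclose(sum(partials), table[tokens])
```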

ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen2_moe.{}.ln_scale".format(i)) for i in range(self.num_layers)]
qkv_weight_attrs = [
@@ -216,8 +237,8 @@
quant_type=self.quant_type,
activation="swiglu",
num_layers=config.num_hidden_layers,
nranks=1,
ring_id=-1,
nranks=config.tensor_parallel_degree,
ring_id=ring_id,
ln_scale_attrs=ln_scale_attrs,
qkv_weight_attrs=qkv_weight_attrs,
qkv_weight_scale_attrs=qkv_weight_scale_attrs,
@@ -233,6 +254,7 @@
epsilon=self.rms_norm_eps,
norm_type="rmsnorm",
use_neox_rotary_style=self.use_neox,
rank_id=config.tensor_parallel_rank,
moe_config=moe_config,
)

@@ -258,6 +280,7 @@
@paddle.no_grad()
def set_state_dict(self, state_dict):
head_size = self.hidden_size // self.num_attention_heads
split_fn = split_param_func()

dtype = paddle.get_default_dtype()
embed_tokens_weight = paddle.to_tensor(state_dict["qwen2_moe.embed_tokens.weight"]).cast(
self.embed_tokens.weight.dtype
@@ -272,36 +295,47 @@
self.transformer_block.ln_scales[idx].dtype
)
self.transformer_block.ln_scales[idx].set_value(ln_scale)

unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"],
],
if "qwen2_moe.layers.{}.self_attn.qkv_proj.weight".format(idx) in state_dict.keys():
concated_qkv_weight = np.concatenate(
split_fn(
state_dict["qwen2_moe.layers.{}.self_attn.qkv_proj.weight".format(idx)],
is_qkv=True,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
),
axis=-1,
)
.transpose(1, 0)
.reshape(
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
).transpose(1, 0)
else:
unfused_state_dict = {}
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.q_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.k_proj.weight".format(idx)
]
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"] = state_dict[
"qwen2_moe.layers.{}.self_attn.v_proj.weight".format(idx)
]

concated_qkv_weight = (
np.concatenate(
[
unfused_state_dict["qwen2_moe.self_attn.q_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.k_proj.weight"],
unfused_state_dict["qwen2_moe.self_attn.v_proj.weight"],
],
axis=-1,
)
.transpose(1, 0)
.reshape(
(
self.num_attention_heads // self.config.tensor_parallel_degree
+ 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
)
* (head_size),
self.hidden_size,
)
* (head_size),
self.hidden_size,
)
)

qkv_weight = paddle.to_tensor(concated_qkv_weight).cast(dtype)

@@ -693,7 +727,7 @@
hidden_states = outputs[0]

# if labels is None,means we need full output, instead of tensor_parallel_output
# tensor_parallel_output is togather with ParallelCrossEntropy
# tensor_parallel_output is together with ParallelCrossEntropy
tensor_parallel_output = (
self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1
)
@@ -832,26 +866,34 @@
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
# "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
Collaborator
Please just delete this commented-out line.

}

# Column Linear
if config.fuse_attention_qkv:
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
else:
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)

# if we have enough num_key_value_heads to split, then split it.
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

if config.fuse_attention_ffn:
base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
fn, is_column=True, is_naive_2fuse=True
)
else:
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
for expert_idx in range(config.num_experts):
base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
base_actions["layers.0.mlp.shared_expert.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_expert.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_expert.down_proj.weight"] = partial(fn, is_column=False)

for key, action in base_actions.items():
if "layers.0." in key: