add fuse_attention_ffn support for qwen #8526

Merged 2 commits on Jun 6, 2024
Changes from all commits
4 changes: 4 additions & 0 deletions paddlenlp/transformers/qwen/configuration.py
@@ -43,6 +43,8 @@ def __init__(
use_flash_attention=False,
use_fused_rms_norm=False,
use_fused_rope=False,
fuse_attention_ffn=False,
sequence_parallel=False,
intermediate_size=22016,
tensor_parallel_output=True,
no_bias=True,
@@ -77,6 +79,8 @@ def __init__(
self.use_flash_attention = use_flash_attention
self.use_fused_rms_norm = use_fused_rms_norm
self.use_fused_rope = use_fused_rope
self.fuse_attention_ffn = fuse_attention_ffn
self.sequence_parallel = sequence_parallel
self.no_bias = no_bias

self.long_sequence_strategy_type = long_sequence_strategy_type
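For context, a minimal usage sketch of the new flag (assuming QWenConfig is exported from paddlenlp.transformers like other PaddleNLP configs; everything else keeps its defaults):

from paddlenlp.transformers import QWenConfig

# Illustrative only: turn on the fused gate/up FFN projection added in this PR.
config = QWenConfig(fuse_attention_ffn=True)
print(config.fuse_attention_ffn)  # True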
74 changes: 54 additions & 20 deletions paddlenlp/transformers/qwen/modeling.py
@@ -26,6 +26,16 @@
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:

def swiglu(x, y=None):
if y is None:
x, y = paddle.chunk(x, chunks=2, axis=-1)
return F.silu(x) * y
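A quick standalone sketch (plain Paddle ops, arbitrary shapes) of the fallback's two calling conventions; the first argument, or the first half of a packed tensor, is the gate that goes through SiLU:

import paddle
import paddle.nn.functional as F

def swiglu(x, y=None):  # same fallback definition as above
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y

x = paddle.randn([2, 8])
gate, up = paddle.chunk(x, chunks=2, axis=-1)
# The packed call splits internally; the two-tensor call takes gate first, then up.
print(paddle.allclose(swiglu(x), swiglu(gate, up)).item())  # True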


from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPast,
@@ -35,6 +45,7 @@
from paddlenlp.utils.log import logger

from ...utils.converter import StateDictNameMapping, init_name_mappings
from .. import linear_utils
from ..model_outputs import ModelOutput
from .configuration import QWenConfig

@@ -329,37 +340,60 @@
def __init__(self, config):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.fuse_attention_ffn = config.fuse_attention_ffn

if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.w1 = mpu.ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.w2 = mpu.ColumnParallelLinear(
config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.c_proj = mpu.RowParallelLinear(
if self.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(

config.hidden_size,
ff_dim_in * 2,
gather_output=False,
has_bias=False,
)
else:
self.w1 = ColumnParallelLinear(

config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.w2 = ColumnParallelLinear(

config.hidden_size,
ff_dim_in,
gather_output=False,
has_bias=False,
)
self.c_proj = RowParallelLinear(

ff_dim_in,
config.hidden_size,
input_is_parallel=True,
has_bias=False,
)
else:
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
if self.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)

A Collaborator commented on the tensor-parallel split mappings quoted below:
def get_tensor_parallel_split_mappings(num_hidden_layers):
final_actions = {}
base_actions = {
# Column Linear
"lm_head.weight": partial(fn, is_column=True),
"qwen.h.0.mlp.w2.weight": partial(fn, is_column=True),
"qwen.h.0.mlp.w1.weight": partial(fn, is_column=True),
"qwen.h.0.attn.c_attn.weight": partial(fn, is_column=True, is_naive_3fuse=True),
"qwen.h.0.attn.c_attn.bias": partial(fn, is_column=True, is_naive_3fuse=True),
# Row Linear
"qwen.wte.weight": partial(fn, is_column=False),
"qwen.h.0.mlp.c_proj.weight": partial(fn, is_column=False),
"qwen.h.0.attn.c_proj.weight": partial(fn, is_column=False),
}
for key, action in base_actions.items():
if "h.0." in key:
for i in range(num_hidden_layers):
final_actions[key.replace("h.0.", f"h.{i}.")] = action
final_actions[key] = action
return final_actions
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)

You'll need to adapt the split rules here: the tensor-parallel split rule for gate_up_fused_proj.

The Contributor (PR author) replied:

Sure. There's a follow-up PR adding SP (sequence parallel) support; I'll add that adaptation together with it in the next PR.
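For reference, a hedged sketch of what that adaptation might look like, modeled on the is_naive_3fuse entry used for c_attn above. The mapping key and the is_naive_2fuse flag are assumptions for illustration only, not part of this PR:

from functools import partial

# Hypothetical sketch, not part of this PR: extra split mappings for the new
# fused gate/up weight. `fn` is the same split/merge function used by
# get_tensor_parallel_split_mappings above; `is_naive_2fuse` is assumed to
# behave like `is_naive_3fuse`, but for a 2-way fusion.
def extra_fused_ffn_mappings(fn, num_hidden_layers):
    action = partial(fn, is_column=True, is_naive_2fuse=True)
    return {
        f"qwen.h.{i}.mlp.gate_up_fused_proj.weight": action
        for i in range(num_hidden_layers)
    }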

else:
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=not config.no_bias)

def forward(self, hidden_states):
# up
a1 = self.w1(hidden_states)
# gate
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
# down
# a1 = self.w1(hidden_states)
# # gate
# a2 = self.w2(hidden_states)
# intermediate_parallel = a1 * F.silu(a2)
if self.fuse_attention_ffn:
intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states))

else:
intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states))
output = self.c_proj(intermediate_parallel)
return output
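For reference, a small standalone check (arbitrary sizes, no parallelism) of why the fused branch matches the unfused one: the fused weight must be the gate (w2) and up (w1) weights concatenated along the output dimension, in that order:

import paddle
import paddle.nn.functional as F

paddle.seed(0)
hidden, ff = 8, 16
x = paddle.randn([2, 4, hidden])

w1 = paddle.nn.Linear(hidden, ff, bias_attr=False)  # "up" projection
w2 = paddle.nn.Linear(hidden, ff, bias_attr=False)  # "gate" projection
fused = paddle.nn.Linear(hidden, 2 * ff, bias_attr=False)
# Gate weights first, then up weights, matching the chunk order inside swiglu.
fused.weight.set_value(paddle.concat([w2.weight, w1.weight], axis=-1))

unfused_out = F.silu(w2(x)) * w1(x)                     # what swiglu(w2(x), w1(x)) computes
gate, up = paddle.chunk(fused(x), chunks=2, axis=-1)
fused_out = F.silu(gate) * up                           # what swiglu(fused(x)) computes
print(paddle.allclose(unfused_out, fused_out).item())   # True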
