[cherry pick] Update sequence parallel linear import (#8706) (#8728)
* Update sequence parallel linear import

* update lora models

* update lora sequence parallel layer

* update import
DrownFish19 authored Jul 9, 2024
1 parent e773524 commit 88244e3
Showing 10 changed files with 169 additions and 150 deletions.
12 changes: 9 additions & 3 deletions paddlenlp/peft/lora/lora_layers.py
@@ -24,6 +24,8 @@
RowParallelLinear,
)

from ...transformers import linear_utils

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
AllGatherOp,
@@ -33,15 +35,19 @@
mark_as_sequence_parallel_parameter,
)
except:
pass
AllGatherOp = None
ReduceScatterOp = None
mark_as_sequence_parallel_parameter = None
ColumnSequenceParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear


from paddlenlp.transformers.mc2_parallel_linear import (
from ...transformers.mc2_parallel_linear import (
MC2ColumnParallelCoreLinear,
MC2ColumnSeqParallelCoreLinear,
MC2RowParallelCoreLinear,
MC2RowSeqParallelCoreLinear,
)

from .lora_quick_layers import quick_lora


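For context, a minimal sketch of the import pattern this hunk introduces: the collective ops are still taken from Paddle when available, while the sequence-parallel linear classes are resolved through paddlenlp.transformers.linear_utils. Names are copied from the hunk above; the exact placement of the assignments inside lora_layers.py may differ.

```python
# Condensed sketch of the fallback import pattern (names taken from the hunk
# above; exact structure in the real module may differ).
from paddlenlp.transformers import linear_utils

try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        AllGatherOp,
        ReduceScatterOp,
        mark_as_sequence_parallel_parameter,
    )
except ImportError:
    # Older Paddle builds do not ship these helpers; leave them unset so the
    # LoRA layers can check for None before using them.
    AllGatherOp = None
    ReduceScatterOp = None
    mark_as_sequence_parallel_parameter = None

# The sequence-parallel linear classes are no longer imported from Paddle
# directly; linear_utils resolves them to the real implementation or to a
# placeholder class when Paddle cannot provide them.
ColumnSequenceParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = linear_utils.RowSequenceParallelLinear
```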
10 changes: 3 additions & 7 deletions paddlenlp/peft/lora/lora_model.py
@@ -32,6 +32,7 @@
RowParallelLinear,
)

from ...transformers import linear_utils
from ...transformers.conversion_utils import ConversionMixin
from ...transformers.model_utils import (
PretrainedModel,
@@ -47,11 +48,6 @@
from .lora_config import LoRAConfig

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
)

from .lora_layers import (
ColumnParallelLoRALinear,
ColumnParallelLoRAMergedLinear,
@@ -470,7 +466,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, ColumnSequenceParallelLinear):
elif isinstance(module, linear_utils.ColumnSequenceParallelLinear):
# recover the original output_features
output_features = module.weight.shape[1] * module.world_size
lora_module = ColumnSequenceParallelLoRALinear(
@@ -499,7 +495,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
elif isinstance(module, RowSequenceParallelLinear):
elif isinstance(module, linear_utils.RowSequenceParallelLinear):
# recover the original output_features
lora_module = RowSequenceParallelLoRALinear(
in_features=module.weight.shape[0] * module.world_size,
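The module-replacement dispatch above now keys off the classes exposed by linear_utils instead of the ones imported directly from Paddle. A minimal sketch of that check, with a hypothetical helper name:

```python
# Hypothetical helper mirroring the isinstance dispatch in
# _find_and_replace_module; linear_utils always defines both classes (real
# implementation or placeholder), so these checks never raise NameError.
from paddlenlp.transformers import linear_utils


def sequence_parallel_kind(module):
    if isinstance(module, linear_utils.ColumnSequenceParallelLinear):
        # Column-parallel weights are sharded along the output dimension, so
        # the full size is weight.shape[1] * world_size, as in the hunk above.
        return "column"
    if isinstance(module, linear_utils.RowSequenceParallelLinear):
        # Row-parallel weights are sharded along the input dimension, so the
        # full size is weight.shape[0] * world_size.
        return "row"
    return None
```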
45 changes: 23 additions & 22 deletions paddlenlp/transformers/gemma/modeling.py
@@ -35,9 +35,7 @@

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
@@ -54,6 +52,8 @@
)
from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model

from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from .configuration import (
GEMMA_PRETRAINED_INIT_CONFIGURATION,
@@ -422,11 +422,11 @@ def __init__(self, config):
self.tensor_parallel_degree = config.tensor_parallel_degree

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.gate_proj = ColumnParallelLinear(
@@ -448,9 +448,9 @@
has_bias=False,
)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
# GeGLU
@@ -509,11 +509,11 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
self.use_fused_rope = False

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
self.q_proj = ColumnParallelLinear(
@@ -537,29 +537,29 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.config.num_attention_heads * self.head_dim,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -573,7 +573,7 @@ def __init__(self, config: GemmaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.config.num_attention_heads * self.head_dim,
self.hidden_size,
bias_attr=False,
@@ -992,10 +992,11 @@ def _init_weights(self, layer):
nn.Linear,
nn.Embedding,
mpu.VocabParallelEmbedding,
mpu.ColumnParallelLinear,
mpu.RowParallelLinear,
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
GemmaLMHead,
),
):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
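The same pattern repeats for every projection in the Gemma blocks: the class aliases are chosen once from linear_utils depending on config.sequence_parallel, and the non-tensor-parallel branch uses linear_utils.Linear instead of nn.Linear. A condensed sketch (the helper name is hypothetical; attribute names come from the diff):

```python
# Condensed sketch of the class selection used in GemmaMLP / GemmaAttention.
from paddlenlp.transformers import linear_utils
from paddlenlp.transformers.linear_utils import Linear


def select_linear_classes(config):
    if config.sequence_parallel:
        column_cls = linear_utils.ColumnSequenceParallelLinear
        row_cls = linear_utils.RowSequenceParallelLinear
    else:
        column_cls = linear_utils.ColumnParallelLinear
        row_cls = linear_utils.RowParallelLinear
    # When tensor_parallel_degree <= 1 the layers fall back to
    # linear_utils.Linear, so fused or device-specific Linear implementations
    # can be swapped in one central place instead of per model.
    return column_cls, row_cls, Linear
```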
52 changes: 26 additions & 26 deletions paddlenlp/transformers/gpt/modeling.py
@@ -22,6 +22,7 @@

import numpy as np
import paddle
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.incubate as incubate
import paddle.nn as nn
import paddle.nn.functional as F
@@ -32,9 +33,7 @@

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
@@ -45,7 +44,8 @@

from ...utils.converter import StateDictNameMapping
from ...utils.log import logger
from .. import PretrainedModel, register_base_model
from .. import PretrainedModel, linear_utils, register_base_model
from ..linear_utils import Linear
from ..model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -210,11 +210,11 @@ def __init__(
self.num_attention_heads = config.num_attention_heads # default, without tensor parallel

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

if config.tensor_parallel_degree > 1:
assert config.num_attention_heads % config.tensor_parallel_degree == 0
@@ -262,13 +262,13 @@ def __init__(
)
else:
if self.config.fuse_attention_qkv:
self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True)
self.qkv_proj = Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True)
else:
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.q_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.k_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.v_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)

self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True)
self.out_proj = Linear(config.hidden_size, config.hidden_size, bias_attr=True)

def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None):
if self.config.sequence_parallel:
@@ -583,11 +583,11 @@ def __init__(self, config: GPTConfig):
self.self_attn = MultiHeadAttention(config=config)

if config.sequence_parallel:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# TODO:config.fuse_attention_ffn @DrownFish19
if config.tensor_parallel_degree > 1:
Expand All @@ -607,8 +607,8 @@ def __init__(self, config: GPTConfig):
fuse_matmul_bias=self.config.use_fused_linear,
)
else:
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True)
self.linear1 = Linear(config.hidden_size, config.intermediate_size, bias_attr=True)
self.linear2 = Linear(config.intermediate_size, config.hidden_size, bias_attr=True)

self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5)
@@ -980,11 +980,11 @@ def _init_weights(self, layer):
(
nn.Linear,
nn.Embedding,
fleet.meta_parallel.VocabParallelEmbedding,
fleet.meta_parallel.ColumnParallelLinear,
fleet.meta_parallel.RowParallelLinear,
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
mpu.VocabParallelEmbedding,
mpu.RowParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
),
):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
@@ -1295,7 +1295,7 @@ def __init__(self, config):
super(GPTPretrainingCriterion, self).__init__()
self.config = config
if config.tensor_parallel_degree > 1 and config.tensor_parallel_output:
self.loss_func = fleet.meta_parallel.ParallelCrossEntropy(ignore_index=config.ignore_index)
self.loss_func = mpu.ParallelCrossEntropy(ignore_index=config.ignore_index)
else:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=config.ignore_index)

@@ -1660,7 +1660,7 @@ def __init__(self, config: GPTConfig):
self.gpt = GPTModel(config) # allow gpt to be config
dropout_p = config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
self.dropout = nn.Dropout(dropout_p)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = Linear(config.hidden_size, config.num_labels)

def forward(
self,
@@ -1774,7 +1774,7 @@ def __init__(self, config: GPTConfig):
super(GPTForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.gpt = GPTModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias_attr=False)
self.score = Linear(config.hidden_size, config.num_labels, bias_attr=False)

def forward(
self,
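One GPT-specific detail: fleet.meta_parallel is now imported under the mpu alias, so the pretraining criterion selects its loss through that alias. A minimal sketch of that selection, using only the names visible in the hunks above:

```python
# Minimal sketch of the loss selection in GPTPretrainingCriterion after the
# change; the config fields are the ones shown in the diff.
import paddle
import paddle.distributed.fleet.meta_parallel as mpu


def build_loss(config):
    if config.tensor_parallel_degree > 1 and config.tensor_parallel_output:
        # Logits stay sharded across tensor-parallel ranks, so the parallel
        # cross entropy from fleet.meta_parallel (mpu) is required.
        return mpu.ParallelCrossEntropy(ignore_index=config.ignore_index)
    return paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=config.ignore_index)
```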
21 changes: 19 additions & 2 deletions paddlenlp/transformers/linear_utils.py
@@ -37,8 +37,25 @@
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
except:
ColumnSequenceParallelLinear = None
RowSequenceParallelLinear = None

class ColumnSequenceParallelLinearPass(object):
"""
A dummy class for ColumnSequenceParallelLinear, used when the actual class
cannot be imported from sequence_parallel_utils.
"""

pass

class RowSequenceParallelLinearPass(object):
"""
A dummy class for RowSequenceParallelLinear, used when the actual class
cannot be imported from sequence_parallel_utils.
"""

pass

ColumnSequenceParallelLinear = ColumnSequenceParallelLinearPass
RowSequenceParallelLinear = RowSequenceParallelLinearPass

if get_env_device() == "npu":
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
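This is the piece that makes the rest of the commit safe: instead of binding the missing classes to None, linear_utils now defines dummy placeholder classes, so isinstance checks elsewhere (LoRA module surgery, _init_weights) remain valid and simply never match. A self-contained sketch of the pattern, with surrounding imports assumed:

```python
# Self-contained sketch of the placeholder pattern added to linear_utils.py.
# The bare except mirrors the original file; narrowing it to ImportError
# would be the stricter choice.
try:
    from paddle.distributed.fleet.utils import sequence_parallel_utils

    ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
    RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
except:  # noqa: E722

    class ColumnSequenceParallelLinearPass(object):
        """Dummy stand-in used when sequence_parallel_utils is unavailable."""

    class RowSequenceParallelLinearPass(object):
        """Dummy stand-in used when sequence_parallel_utils is unavailable."""

    ColumnSequenceParallelLinear = ColumnSequenceParallelLinearPass
    RowSequenceParallelLinear = RowSequenceParallelLinearPass

# isinstance(layer, ColumnSequenceParallelLinear) is now always well-defined,
# even on Paddle builds without sequence-parallel support; it just never
# matches a real layer in that case.
```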
8 changes: 3 additions & 5 deletions paddlenlp/transformers/llama/modeling.py
@@ -47,9 +47,7 @@ def swiglu(x, y=None):

try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
ColumnSequenceParallelLinear,
GatherOp,
RowSequenceParallelLinear,
ScatterOp,
mark_as_sequence_parallel_parameter,
)
@@ -1331,11 +1329,11 @@ def _init_weights(self, layer):
nn.Linear,
nn.Embedding,
mpu.VocabParallelEmbedding,
mpu.ColumnParallelLinear,
mpu.RowParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
LlamaLMHead,
ColumnSequenceParallelLinear,
RowSequenceParallelLinear,
),
):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
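As in Gemma and GPT, Llama's _init_weights now tests layers against the linear_utils classes rather than the direct Paddle imports. A hedged sketch of the broadened check follows; the initialization body here is illustrative only, not the exact logic from modeling.py:

```python
# Illustrative sketch of the broadened isinstance tuple in _init_weights.
import paddle
import paddle.distributed.fleet.meta_parallel as mpu
import paddle.nn as nn

from paddlenlp.transformers import linear_utils

INIT_TYPES = (
    nn.Linear,
    nn.Embedding,
    mpu.VocabParallelEmbedding,
    mpu.RowParallelLinear,
    mpu.ColumnParallelLinear,
    linear_utils.RowSequenceParallelLinear,
    linear_utils.ColumnSequenceParallelLinear,
)


def init_weights(layer, std=0.02):
    # Placeholder classes from linear_utils never match real layers, so this
    # check is simply a no-op on Paddle builds without sequence parallelism.
    if isinstance(layer, INIT_TYPES) and hasattr(layer, "weight"):
        layer.weight.set_value(paddle.normal(mean=0.0, std=std, shape=layer.weight.shape))
```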