
Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into develop

DesmonDay committed May 7, 2024
2 parents 1935b46 + 9f3cf82 commit edc04f3
Showing 25 changed files with 586 additions and 177 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -32,7 +32,7 @@

* **2024.04.24 [PaddleNLP v2.8](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.8.0)**: Added the self-developed RsLoRA+ algorithm with extremely fast convergence, substantially improving PEFT training convergence speed and quality; introduced high-performance generation acceleration into the RLHF PPO algorithm, removing the generation-speed bottleneck in PPO training and giving PPO training a large performance lead. Generalized support for FastFNN, FusedQKV, and other LLM training performance optimizations, making LLM training faster and more stable.

* **2024.01.04 [PaddleNLP v2.7](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.7.1)**: Comprehensively upgraded large model experience with a unified toolchain entry point. Unified the implementation code for pre-training, fine-tuning, compression, inference, and deployment under the `PaddleNLP/llm` directory. A brand-new large [model toolchain documentation](https://paddlenlp.readthedocs.io/zh/latest/llm/finetune.html) offers one-stop guidance from getting started with large models to production deployment. The Unified Checkpoint mechanism greatly improves the generality of large model checkpoint storage. Upgraded efficient fine-tuning, supporting efficient fine-tuning combined with LoRA, as well as algorithms such as QLoRA.
* **2024.01.04 [PaddleNLP v2.7](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.7.1)**: Comprehensively upgraded large model experience with a unified toolchain entry point. Unified the implementation code for pre-training, fine-tuning, compression, inference, and deployment under the `PaddleNLP/llm` directory. A brand-new [large model toolchain documentation](https://paddlenlp.readthedocs.io/zh/latest/llm/finetune.html) offers one-stop guidance from getting started with large models to production deployment. The Unified Checkpoint mechanism greatly improves the generality of large model checkpoint storage. Upgraded efficient fine-tuning, supporting efficient fine-tuning combined with LoRA, as well as algorithms such as QLoRA.

* **2023.08.15 [PaddleNLP v2.6](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.6.0)**: Released the [full-pipeline LLM toolchain](./llm), covering pre-training, fine-tuning, compression, inference, and deployment, providing users with an end-to-end LLM solution and a one-stop development experience; built in the [4D parallel distributed Trainer](./docs/trainer.md), [efficient fine-tuning algorithms LoRA/Prefix Tuning](./llm#33-lora), [self-developed INT8/INT4 quantization algorithms](./llm#6-量化), and more; full support for mainstream LLMs such as [LLaMA 1/2](./llm/llama), [BLOOM](.llm/bloom), [ChatGLM 1/2](./llm/chatglm), [GLM](./llm/glm), and [OPT](./llm/opt).

llm/README.md: 2 changes (1 addition, 1 deletion)
@@ -56,7 +56,7 @@ PaddleNLP adds PaddlePaddle's 4D parallel strategies to the Trainer API; users only need to modify the Tra

This project supports pre-training of large models such as LLaMA, GPT-3, BaiChuan, and Qwen. Users can simply switch the config file to run it with one command.

For the detailed data preparation process, see [here](https://paddlenlp.readthedocs.io/zh/latest/pretraining/dataset.html) : https://paddlenlp.readthedocs.io/zh/latest/pretraining/dataset.html
For the detailed data preparation process, see [here](https://paddlenlp.readthedocs.io/zh/latest/llm/pretraining/dataset.html) : https://paddlenlp.readthedocs.io/zh/latest/llm/pretraining/dataset.html

To make it easy to run and test this model, the project provides a preprocessed training sample of 100k docs:
```shell
llm/data.py: 4 changes (2 additions, 2 deletions)
@@ -90,7 +90,7 @@ def tokenize_example(tokenizer, example, data_args):
return tokenized_source, tokenized_target_input_ids


def tokenize_rounds_example(tokenizer, example, data_args):
def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
"""tokenize multi-rounds examples with chat_template.json
Args:
@@ -117,7 +117,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):

# 1. only tokenize input_ids
conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
conversations, context_data=context_data
conversations, context_data=context_data, **kwargs
)
system_ids = conversation_result.pop("system", []) or []

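
A minimal, self-contained sketch of the keyword-forwarding pattern this change introduces: any extra option passed to the outer tokenize helper now reaches the inner chat encoder unchanged. The function and flag names below are illustrative stand-ins, not PaddleNLP APIs.

```python
def encode_chat_inputs(conversations, context_data=None, add_generation_prompt=False):
    # Stand-in for tokenizer.encode_chat_inputs; the flag is a made-up example option.
    return {"conversations": conversations, "generation_prompt": add_generation_prompt}


def tokenize_rounds(conversations, **kwargs):
    # **kwargs lets callers reach encoder options without widening this signature.
    return encode_chat_inputs(conversations, context_data={}, **kwargs)


print(tokenize_rounds([["Hello", "Hi there!"]], add_generation_prompt=True))
```
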
llm/llama/auto_parallel/run_pretrain_auto.py: 17 changes (16 additions, 1 deletion)
@@ -86,6 +86,12 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
eliminate_transpose: bool = field(
default=False,
metadata={
"help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -132,6 +138,11 @@ def __post_init__(self):
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

if self.eliminate_transpose:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("eliminate_transpose")

logger.info(self.strategy)


@@ -574,7 +585,11 @@ def fn(layer):
# Create the learning_rate scheduler and optimizer
if training_args.decay_steps is None:
training_args.decay_steps = training_args.max_steps
warmup_steps = training_args.warmup_ratio * training_args.max_steps

if training_args.warmup_steps > 0:
warmup_steps = training_args.warmup_steps
else:
warmup_steps = training_args.warmup_ratio * training_args.max_steps

lr_scheduler = None
if training_args.lr_scheduler_type.value == "cosine":
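
A hedged sketch of the warmup-step selection logic added above, written as a standalone function: an explicitly set `warmup_steps` now takes precedence over the ratio-derived value. Plain Python, mirroring the branch rather than calling PaddleNLP.

```python
def resolve_warmup_steps(warmup_steps: int, warmup_ratio: float, max_steps: int) -> float:
    # An explicit warmup_steps value wins; otherwise fall back to the ratio-based default.
    if warmup_steps > 0:
        return warmup_steps
    return warmup_ratio * max_steps


assert resolve_warmup_steps(200, 0.25, 1000) == 200    # explicit value takes precedence
assert resolve_warmup_steps(0, 0.25, 1000) == 250.0    # falls back to ratio * max_steps
```
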
llm/run_pretrain.py: 18 changes (18 additions, 0 deletions)
@@ -46,6 +46,7 @@
)
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device


def add_start_docstrings(*docstr):
@@ -218,6 +219,10 @@ class ModelArguments:
default=False,
metadata={"help": "recompute_use_reentrant"},
)
num_hidden_layers: Optional[int] = field(
default=None,
metadata={"help": "num_hidden_layers."},
)


def create_pretrained_dataset(
@@ -451,6 +456,9 @@ def main():
if model_args.no_recompute_layers is not None:
model_args.no_recompute_layers.sort()

config.num_hidden_layers = (
model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers
)
config.use_flash_attention = model_args.use_flash_attention
config.use_fused_rms_norm = model_args.use_fused_rms_norm
config.use_fast_layer_norm = model_args.use_fast_layer_norm
@@ -483,6 +491,16 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
# It's OK; proceed without the accumulate_steps optimization
pass

print("Final pre-training config:", config)

# Set the dtype for loading model
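
A small, self-contained sketch of the override pattern used for `num_hidden_layers` above: a command-line value replaces the checkpoint's setting only when it is explicitly provided, which is handy for shrinking a model in smoke tests. The `Config` dataclass and function name here are hypothetical, not PaddleNLP types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    num_hidden_layers: int = 32  # value loaded from the pretrained config


def apply_layer_override(config: Config, num_hidden_layers: Optional[int]) -> Config:
    # None means "not set on the command line", so the original value is kept.
    config.num_hidden_layers = (
        num_hidden_layers if num_hidden_layers is not None else config.num_hidden_layers
    )
    return config


print(apply_layer_override(Config(), 2).num_hidden_layers)     # 2 (e.g. a quick debug run)
print(apply_layer_override(Config(), None).num_hidden_layers)  # 32 (checkpoint default kept)
```
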
llm/utils.py: 2 changes (2 additions, 0 deletions)
@@ -125,9 +125,11 @@ def get_lora_target_modules(model):
".*v_proj.*",
".*k_proj.*",
".*o_proj.*",
".*qkv_proj.*",
".*gate_proj.*",
".*down_proj.*",
".*up_proj.*",
".*gate_up_fused_proj.*",
]
elif model.base_model_prefix == "opt":
target_modules = [
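
A hedged illustration of how these patterns select layers for LoRA: each entry is a regular expression matched against parameter names, so the two new entries also pick up fused QKV and fused gate/up projections. The layer names and the use of `re.fullmatch` below are simplifications for the sketch, not PaddleNLP internals.

```python
import re

target_modules = [".*q_proj.*", ".*qkv_proj.*", ".*up_proj.*", ".*gate_up_fused_proj.*"]
layer_names = [
    "llama.layers.0.self_attn.qkv_proj",      # fused QKV, caught by the new .*qkv_proj.* entry
    "llama.layers.0.mlp.gate_up_fused_proj",  # fused gate/up, caught by .*gate_up_fused_proj.*
    "llama.layers.0.mlp.down_proj",           # not in this shortened pattern list
]

for name in layer_names:
    hit = any(re.fullmatch(pattern, name) for pattern in target_modules)
    print(f"{name}: {'LoRA target' if hit else 'skipped'}")
```
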
paddlenlp/taskflow/text_feature_extraction.py: 3 changes (1 addition, 2 deletions)
@@ -228,8 +228,7 @@ def _parse_batch(batch_examples):
)
else:
tokenized_inputs = self._tokenizer(
text=[""] * len(batch_examples),
text_pair=batch_examples,
text=batch_examples,
padding="max_length",
truncation=True,
max_seq_len=self.max_seq_len,
paddlenlp/transformers/linear_utils.py: 59 changes (59 additions, 0 deletions)
@@ -0,0 +1,59 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file is used for replacing Paddle's native Linear implementations with vendors' customized implementations
"""

import paddle.distributed.fleet.meta_parallel as mpu
from paddle import nn
from paddle.distributed.fleet.utils import sequence_parallel_utils

from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnSeqParallelLinear,
MC2RowSeqParallelLinear,
)
from paddlenlp.utils.tools import get_env_device

Linear = nn.Linear
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear

if get_env_device() == "npu":
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear
RowSequenceParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
try:
from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear
from paddle_xpu.layers.nn import Linear as XPULinear
from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear
from paddle_xpu.layers.nn.sequence_parallel import (
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

Linear = XPULinear
ColumnParallelLinear = XPUColumnParallelLinear
RowParallelLinear = XPURowParallelLinear
ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear
RowSequenceParallelLinear = XPURowSequenceParallelLinear
except ImportError:
# If paddle_xpu is not installed, just use Paddle's native Linear implementations
pass
else:
# By default, use Paddle's native Linear implementations
pass
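
A hedged usage sketch of the new module: model code can import the aliases from `paddlenlp.transformers.linear_utils` instead of `paddle.nn` or fleet classes directly, so the NPU/XPU substitutions above are picked up transparently. Assumes Paddle and a PaddleNLP build containing this commit are installed; on plain CPU/GPU hardware the aliases simply resolve to the native classes.

```python
import paddle

from paddlenlp.transformers import linear_utils

# On default hardware this is paddle.nn.Linear; on XPU with paddle_xpu installed
# it would be the vendor-optimized Linear, with no change to the calling code.
dense = linear_utils.Linear(16, 32)
out = dense(paddle.randn([4, 16]))
print(type(dense).__name__, out.shape)  # e.g. "Linear" [4, 32]
```
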