[LLM-paddle] add llama1-7b pretrain with callback #239
Merged
Commits (15)
24630bf  modify gitignore (LaiXinyi823)
cd32e31  add paddle llama (LaiXinyi823)
efb63cc  add recompute and sharding for llama7b (LaiXinyi823)
3e6bed6  adapte to the driver & fix start_paddle_task (LaiXinyi823)
0678669  fix llama1-7b fig files and trainer (LaiXinyi823)
45c4220  [callback] llama1-7B pretrain (LaiXinyi823)
256107f  modify the llama case config name in test_conf.py (LaiXinyi823)
af2fd2f  update config (DrownFish19)
dc0b439  Merge remote-tracking branch 'flagperf/main' into callback (DrownFish19)
925452b  Merge branch 'FlagOpen:main' into callback (DrownFish19)
9b247dc  update config (DrownFish19)
93d2542  Merge branch 'callback' of https://github.com/LaiXinyi823/FlagPerf in… (DrownFish19)
94bd7b9  add metrics in README.md (DrownFish19)
81ac81d  update README.md (DrownFish19)
4eb09f6  remove llama 13B files (DrownFish19)
@@ -1,3 +1,3 @@
from .event import Event
from .base import Driver
from .event import Event
from .log_event import LogEventManager
@@ -0,0 +1,114 @@
import os
from contextlib import contextmanager

import paddle
import paddle.distributed as dist
from paddlenlp.trainer import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from paddlenlp.trainer.trainer_utils import IntervalStrategy

from .base import Driver
from .event import Event
from typing import Dict


def barrier():
    if dist.is_initialized():
        dist.barrier()


def is_main_process():
    if dist.is_initialized():
        if "PADDLE_TRAINER_ID" in os.environ:
            return int(os.environ["PADDLE_TRAINER_ID"]) == 0
        else:
            return dist.get_rank() == 0

    return True


class PaddleCallback(TrainerCallback):
    """Forwards PaddleNLP Trainer lifecycle hooks to the FlagPerf driver as events."""

    def __init__(self, driver: Driver):
        self.driver = driver

    def on_init_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.INIT_END)

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.TRAIN_START)

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.TRAIN_END)

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.EPOCH_BEGIN, epoch=state.epoch)

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.EPOCH_END, epoch=state.epoch)

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.driver.event(Event.STEP_BEGIN, step=state.global_step + 1)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        logs = kwargs["metrics"]
        logs["global_step"] = state.global_step
        self.driver.event(Event.EVALUATE, result=logs)
        # Stop training once the evaluation perplexity reaches the target.
        if kwargs["metrics"]["eval_ppl"] < self.driver.config.target_ppl:
            control.should_training_stop = True

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs
    ):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            self.driver.logger.log(Event.STEP_END, message=logs)
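The callback only takes effect once it is registered with a trainer. Below is a minimal sketch of how a `PaddleCallback` is typically passed to a PaddleNLP `Trainer`; the `driver`, `model`, and dataset objects are placeholders (FlagPerf constructs them in its own run path), so treat this as an illustration rather than the case's actual wiring:

```python
# Illustrative wiring only; FlagPerf builds the Driver, model, and datasets elsewhere.
from paddlenlp.trainer import Trainer, TrainingArguments

driver = ...          # FlagPerf Driver instance (construction not shown in this diff)
model = ...           # e.g. a paddlenlp LLaMA model prepared for pre-training
train_dataset = ...   # pre-tokenized OpenWebText sample set
eval_dataset = ...

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="llama-paddle/output"),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[PaddleCallback(driver=driver)],  # hooks fire on init/train/epoch/step/evaluate/log
)
trainer.train()
```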
@@ -0,0 +1 @@
../llama1_7B/README.md
@@ -0,0 +1 @@
../llama1_7B/paddle/
@@ -0,0 +1,34 @@
### Model Information
#### Model Introduction
We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions of tokens, and show that it is possible to train state-of-the-art models using publicly available datasets exclusively, without resorting to proprietary and inaccessible datasets. In particular, LLaMA-13B outperforms GPT-3 (175B) on most benchmarks, and LLaMA-65B is competitive with the best models, Chinchilla-70B and PaLM-540B. We release all our models to the research community.

Please refer to this paper for a detailed description of LLaMA1:
[LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)

#### Model Code Source
Source of the Paddle case code:
https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/llama, licensed under the Apache License, Version 2.0.

#### Dataset
##### Test Dataset Download
The test dataset provides preprocessed training samples built from 100k OpenWebText documents:
```
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
```
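As an optional sanity check after downloading (not part of the benchmark flow), the files can be inspected with NumPy; this sketch makes no assumption about the array names inside the `.npz` index beyond listing them:

```python
# Optional: verify the downloaded sample set loads correctly. Not required by the case.
import numpy as np

ids = np.load("llama_openwebtext_100k_ids.npy", mmap_mode="r")  # token-id stream
idx = np.load("llama_openwebtext_100k_idx.npz")                 # document index arrays

print(ids.dtype, ids.shape)  # dtype and total token count of the sample set
print(idx.files)             # names of the arrays stored in the index archive
```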

##### Preprocessing
> No preprocessing is required.

#### Model Implementation
* Loaded automatically at runtime.

#### Model Checkpoint
* Downloaded automatically at runtime.
* Use of the Paddle LLaMA model weights must follow this [License](../../paddlenlp/transformers/llama/LICENSE).

### Framework and Chip Support
|            | Pytorch | Paddle | TensorFlow2 |
| ---------- | ------- | ------ | ----------- |
| Nvidia GPU | N/A     | ✅     | N/A         |
@@ -0,0 +1,2 @@
from ._base import *
from .mutable_params import mutable_params
@@ -0,0 +1,221 @@
# =========================================================
# Required parameters
# =========================================================
vendor: str = None

device: str = "gpu"


# =========================================================
# data
# =========================================================
# The name of the dataset to use (via the datasets library).
input_dir: str = "data"

# Train/valid/test data split.
split: str = "949,50,1"

# The maximum total input sequence length after tokenization. Sequences longer
# than this will be truncated, sequences shorter will be padded.
max_seq_length: int = 2048

# Mask token prob.
masked_lm_prob: float = 0.15

# Short sequence prob.
short_seq_prob: float = 0.

# Use share folder for data dir and output dir on multi machine.
share_folder: bool = False

# Whether to favor long ngrams
favor_longer_ngram: bool = False

# Max N Grams
max_ngrams: int = 3

# mmap/lazy format converted from preprocessed data.
data_impl: str = "mmap"

# Drop the last incomplete batch if it is not divisible by the batch size.
dataloader_drop_last: bool = False

# Number of subprocesses to use for data loading.
# 0 means that the data will be loaded in the main process.
dataloader_num_workers: int = 1


# =========================================================
# Model
# =========================================================
# Only support for llama pre-training for now.
model_type: str = "llama"

# Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html
model_name_or_path: str = "facebook/llama-7b"

# Pretrained tokenizer name or path if not the same as model_name
tokenizer_name_or_path: str = "facebook/llama-7b"

# Pre-training from existing paddlenlp model weights. Defaults to False, in which case the model
# trains from scratch. If set to True, model_name_or_path must exist in the paddlenlp models.
continue_training: bool = True

# use flash attention
use_flash_attention: bool = False

# use fused rms_norm
use_fused_rms_norm: bool = False

# =========================================================
# trainer args
# =========================================================
# The output directory where the model predictions and checkpoints will be written.
output_dir: str = None

# Whether to run training.
do_train: bool = True

# Whether to run eval on the dev set.
do_eval: bool = True

# Batch size per GPU core/CPU for training.
per_device_train_batch_size: int = 1

# Batch size per GPU core/CPU for evaluation.
per_device_eval_batch_size: int = 1

# Number of updates steps to accumulate before performing a backward/update pass.
gradient_accumulation_steps: int = 1

# If > 0: set total number of training steps to perform. Overrides num_train_epochs.
max_steps: int = -1

# Log every X updates steps.
logging_steps: int = 20
log_freq = logging_steps

# Random seed that will be set at the beginning of training.
seed: int = 42

# Whether or not to use Paddle Sharding Data Parallel training (in distributed training
# only). The base option should be `stage1`, `stage2` or `stage3` and you can add
# CPU offload to `stage2` or `stage3` like this: `stage2 offload` or `stage3 offload`.
# sharding: str = None

# tensor_parallel_degree means how many parts each transformer layer is split into.
# Default -1 means tensor parallel is not used. Suggest tensor_parallel_degree <= 8 for better performance.
# Note: this needs model support in the source code.
tensor_parallel_degree: int = -1

# pipeline_parallel_degree means how many stages all transformer layers are split into.
# Default -1 means pipeline parallel is not used.
# Note: this needs model support in the source code, see the llama modeling_pp.py file.
pipeline_parallel_degree: int = -1

# Recompute the forward pass to calculate gradients. Used for saving memory.
recompute: bool = True

# Whether or not to disable the tqdm progress bars.
disable_tqdm: bool = True

# Run an evaluation every X steps.
eval_steps: int = 1000

# Number of updates steps before two checkpoint saves if `save_strategy="steps"`.
save_steps: int = 5000

# The steps used to control the learning rate. If step > decay_steps, the min_lr is used.
decay_steps: int = None

# virtual_pp_degree
virtual_pp_degree: int = 1

# use sequence parallel. If mp_degree=1, sequence_parallel is forced to be False.
sequence_parallel: bool = False

# Whether to use distributed dataloader
distributed_dataloader: bool = True

# Granularity of recompute during training.
# Options: `full`, `full_attn`, `core_attn`.
# `full` recomputes the whole transformer layer.
# `full_attn` recomputes only the self-attention part.
# `core_attn` recomputes only the `softmax(qkT)v` part.
# Note: in terms of memory usage, `core_attn` > `full_attn` > `full`; if the chosen
# strategy causes OOM errors, switch to a cheaper one.
recompute_granularity: str = "full"

# target perplexity value
target_ppl: float = 10.0

# =========================================================
# fp16 config args
# =========================================================
# Whether to use fp16 (mixed) precision instead of 32-bit
fp16: bool = True

# For fp16: AMP optimization level selected in ['O0', 'O1', 'O2'].
fp16_opt_level: str = 'O0'

# Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA
# architecture or using CPU (no_cuda). This is an experimental API and it may change.
bf16: bool = False

# The value of initial scale_loss for fp16.
scale_loss: float = 1024.0


# =========================================================
# dist args
# =========================================================
# Whether to read local rank from ENVVAR
use_env: bool = True

# Communication backend for distributed training on gpus
dist_backend: str = "nccl"

local_rank: int = -1


# =========================================================
# lr_scheduler args
# =========================================================
# initial learning rate
learning_rate: float = 0.0001

# Minimum learning rate decayed to.
min_learning_rate: float = 1e-05

# Linear warmup over warmup_ratio fraction of total steps.
warmup_ratio: float = 0.01

# Linear warmup over warmup_steps.
warmup_steps: int = 0

# weight decay coefficient for L2 regularization
weight_decay: float = 0.01

# The scheduler type to use. Supports linear, cosine, constant, constant_with_warmup.
lr_scheduler_type: str = "linear"


# =========================================================
# optimizer args
# =========================================================
# Beta1 for AdamW optimizer
adam_beta1: float = 0.9

# Beta2 for AdamW optimizer
adam_beta2: float = 0.999

# Epsilon for AdamW optimizer.
adam_epsilon: float = 1e-8

# Max gradient norm.
max_grad_norm: float = 1.0


# =========================================================
# load and save args
# =========================================================
# Path to a directory containing a model checkpoint.
output_dir: str = "llama-paddle/output"
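For orientation, these module-level values correspond to settings that PaddleNLP's `TrainingArguments` ultimately consumes. The sketch below is illustrative only: FlagPerf feeds the config through its own loader and driver layer, and the fields shown should be checked against the installed PaddleNLP version.

```python
# Illustrative mapping of a few config values onto paddlenlp TrainingArguments.
# This is not the FlagPerf plumbing; it only shows the kind of arguments the values feed.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama-paddle/output",
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    max_steps=-1,
    logging_steps=20,
    eval_steps=1000,
    save_steps=5000,
    seed=42,
    fp16=True,
    fp16_opt_level="O0",
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.01,
    lr_scheduler_type="linear",
    recompute=True,      # PaddleNLP-specific field; verify availability in your version
    disable_tqdm=True,
)
```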
Review comment: Please write up the llama1_7B pretraining case following the format of the other cases, including the model description, dataset download, dataset preprocessing scripts, code source and open-source license, etc.