[LLM-paddle] add llama1-7b pretrain with callback #239

Merged: 15 commits merged on Sep 28, 2023
Changes from 1 commit
fix llama1-7b fig files and trainer
fix llama1-7b docker run cmd
modify docker paddle version
LaiXinyi823 committed Sep 7, 2023
commit 0678669aee623396ff1e18e314abe3390f8e2ed9
16 changes: 1 addition & 15 deletions .gitignore
@@ -6,18 +6,4 @@ __pycache__/
.pytest_cache
training/result/*
inference/result/*
inference/onnxs/*

training/benchmarks/llama/memory_profiler_train_create.log
training/benchmarks/llama/run.sh
training/benchmarks/llama/paddle/data
training/benchmarks/llama/paddle/dataloaders/sentencepiece.bpe.model
training/=2.0.0rc
training/creat.sh
training/benchmarks/llama/paddle/proxy.sh
training/benchmarks/llama/paddle/run.sh
training/benchmarks/llama/paddle/creat.sh
training/benchmarks/llama/paddle/test.sh
training/benchmarks/llama/paddle/creat.sh
training/benchmarks/llama/paddle/test.sh
start_paddle_task_1.py
inference/onnxs/*
23 changes: 18 additions & 5 deletions training/benchmarks/driver/dist_paddle.py
@@ -51,10 +51,20 @@ def init_dist_training_env(config):
if dist.get_world_size() <= 1:
config.device = paddle.device.get_device()
config.world_size = get_world_size()
config.dataset_world_size = get_world_size()
else:
dist.init_parallel_env()
config.device = paddle.device.get_device()
config.world_size = get_world_size()
config.dataset_world_size = get_world_size()
if config.sharding:
strategy = fleet.DistributedStrategy()
hybrid_configs = {
"dp_degree": 1,
"sharding_degree": config.world_size,
}
strategy.hybrid_configs = hybrid_configs
fleet.init(is_collective=True, strategy=strategy)
print('------------------------')
print('device numbers:', config.world_size)
print('the processing uses', config.device)
@@ -97,18 +107,21 @@ def format_step(step):
s += "Validation Iteration: {} ".format(step[2])
return s

def all_gather(tensor_list, tensor):
return dist.all_gather(tensor_list, tensor)
def _nested_gather(tr_loss):
tr_log_losses = []
dist.all_gather(tr_log_losses, tr_loss)
tr_log_losses = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in tr_log_losses]
concat = paddle.concat(tr_log_losses, axis=0)
return concat

def fused_allreduce_gradients(params):
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
return fused_allreduce_gradients(params, None)

def group_sharded_parallel(model, optimizer, level, scaler):
return dist.sharding.group_sharded_parallel(model, optimizer, level, scaler=scaler)
def group_sharded_parallel(model, optimizer, level, scaler, **extra_kwargs):
return dist.sharding.group_sharded_parallel(model, optimizer, level, scaler=scaler, **extra_kwargs)

def get_data_parallel_group():
fleet.init()
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
return dp_group
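
Note on dist_paddle.py above: _nested_gather all-gathers each rank's running loss and concatenates the results so the driver can average it across data-parallel workers. A minimal standalone sketch of the same pattern, assuming dist.init_parallel_env() has already been called and tr_loss is a scalar tensor:

import paddle
import paddle.distributed as dist

def mean_loss_across_ranks(tr_loss):
    # Collect one copy of the loss tensor from every data-parallel rank.
    gathered = []
    dist.all_gather(gathered, tr_loss)
    # Promote 0-D scalars to 1-D so they can be concatenated.
    gathered = [t if len(t.shape) > 0 else t.reshape_([-1]) for t in gathered]
    # Average the per-rank losses into a single value.
    return paddle.concat(gathered, axis=0).mean()
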
7 changes: 2 additions & 5 deletions training/benchmarks/llama/paddle/config/_base.py
@@ -81,9 +81,6 @@
# full core_attn
recompute_granularity: str = "full"

# virtual_pp_degree
virtual_pp_degree: int = 1

# Pre-training from existing paddlenlp model weights. Default False and the model will train from scratch. If set to True, the model_name_or_path argument must exist in the paddlenlp models.
continue_training: bool = False

@@ -152,9 +149,9 @@
# fp16 config args
# =========================================================
# Run model in fp16 mode
fp16: bool = True
amp: bool = True

fp16_opt_level = 'O2'
amp_opt_level = 'O2'

bf16: bool = False

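Note on _base.py above: the fp16/fp16_opt_level flags are renamed to amp/amp_opt_level, matching Paddle's automatic mixed precision terminology. A minimal sketch of how such flags are commonly wired into paddle.amp; this is illustrative only, the benchmark's actual trainer may differ, and model(**batch)[0] is an assumption about the model's output:

import paddle

# One-time setup driven by the flags above; the scaler is only used when amp is on.
scaler = paddle.amp.GradScaler(init_loss_scaling=2.0 ** 15)

def amp_train_step(model, optimizer, batch, config):
    with paddle.amp.auto_cast(enable=config.amp, level=config.amp_opt_level):
        loss = model(**batch)[0]
    if config.amp:
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    optimizer.clear_grad()
    return loss
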
2 changes: 1 addition & 1 deletion training/benchmarks/llama/paddle/config/mutable_params.py
@@ -1,7 +1,7 @@

mutable_params = [
"split", "max_seq_length", "per_device_train_batch_size", "per_device_eval_batch_size",
"use_flash_attention", "use_fused_rms_norm", "fp16", "fp16_opt_level", "gradient_accumulation_steps",
"use_flash_attention", "use_fused_rms_norm", "amp", "amp_opt_level", "gradient_accumulation_steps",
"max_steps", "eval_steps", "learning_rate", "min_learning_rate", "weight_decay", "warmup_ratio",
"seed", "sharding", "use_recompute"
]
16 changes: 0 additions & 16 deletions training/benchmarks/llama/paddle/dataloaders/dataloader.py
@@ -21,7 +21,6 @@
from .dataset import GPTDataset
from paddlenlp.utils.log import logger
from paddle.io import DataLoader, DistributedBatchSampler
from icecream import ic

def get_train_data_file(config):
input_dir = os.path.join(config.base_path, config.data_dir)
@@ -163,14 +162,6 @@ def _collate_data(data, stack_fn=Stack()):
return train_dataset, valid_dataset, test_dataset, _collate_data

def _get_train_sampler(config,train_dataset) -> Optional[paddle.io.Sampler]:
if config.world_size <= 1:
return paddle.io.BatchSampler(
dataset=train_dataset,
shuffle=False,
batch_size=config.per_device_train_batch_size,
drop_last=config.dataloader_drop_last,
)

return DistributedBatchSampler(
train_dataset,
batch_size=config.per_device_train_batch_size,
@@ -181,13 +172,6 @@ def _get_train_sampler(config,train_dataset) -> Optional[paddle.io.Sampler]:
)

def _get_eval_sampler(config, eval_dataset):
if config.world_size <= 1:
return paddle.io.BatchSampler(
eval_dataset,
batch_size=config.per_device_eval_batch_size,
shuffle=False,
drop_last=False,
)
return DistributedBatchSampler(
eval_dataset,
num_replicas=config.world_size,
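Note on dataloader.py above: with the world_size <= 1 branches deleted, both helpers always return a DistributedBatchSampler, which also behaves correctly with a single replica. A minimal sketch of the equivalent eval-loader construction; build_eval_loader is a hypothetical helper and the argument values mirror the config fields above:

from paddle.io import DataLoader, DistributedBatchSampler

def build_eval_loader(eval_dataset, config, collate_fn):
    sampler = DistributedBatchSampler(
        eval_dataset,
        batch_size=config.per_device_eval_batch_size,
        num_replicas=config.world_size,  # number of data-parallel workers
        shuffle=False,
        drop_last=False,
    )
    return DataLoader(eval_dataset, batch_sampler=sampler, collate_fn=collate_fn)
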
2 changes: 0 additions & 2 deletions training/benchmarks/llama/paddle/dataloaders/dataset.py
@@ -20,8 +20,6 @@
import paddle
from paddle.io import DataLoader

from icecream import ic

def construct_samples_and_shuffle_data(
name, data_prefix, documents, sizes, num_samples, seq_length, seed, build_data_file
):
21 changes: 1 addition & 20 deletions training/benchmarks/llama/paddle/dataloaders/tokenizer.py
@@ -14,26 +14,7 @@ class LlamaTokenizer(PretrainedTokenizer):
resource_files_names = {
"vocab_file": "sentencepiece.bpe.model",
}
pretrained_resource_files_map = {
"vocab_file": {
"__internal_testing__/micro-random-llama": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
"__internal_testing__/tiny-random-llama": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
"facebook/llama-7b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
"facebook/llama-13b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
"facebook/llama-30b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
"facebook/llama-65b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
},
}

pretrained_init_configuration = {
"__internal_testing__/micro-random-llama": {},
"__internal_testing__/tiny-random-llama": {},
"facebook/llama-7b": {},
"facebook/llama-13b": {},
"facebook/llama-30b": {},
"facebook/llama-65b": {},
}


def __init__(
self,
vocab_file,
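Note on tokenizer.py above: with the pretrained download maps removed, the tokenizer is expected to be built from a local sentencepiece model rather than fetched by a pretrained name. A hypothetical usage sketch of the LlamaTokenizer class defined in this file; the file path is a placeholder:

# Build the tokenizer from a local sentencepiece model file.
tokenizer = LlamaTokenizer(vocab_file="path/to/sentencepiece.bpe.model")
input_ids = tokenizer("an example sentence")["input_ids"]
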
17 changes: 8 additions & 9 deletions training/benchmarks/llama/paddle/model/__init__.py
@@ -1,5 +1,4 @@
from model.models.modeling import LlamaConfig, LlamaForCausalLM
from memory_profiler import profile
from contextlib import contextmanager
import paddle
# @profile(precision=4, stream=open("memory_profiler_train_create.log", "w+"))
@@ -21,7 +20,7 @@ def create_model(config):
use_cache=config.use_cache,
use_recompute=config.use_recompute,
use_flash_attention=config.use_flash_attention,
fp16_opt_level=config.fp16_opt_level)
fp16_opt_level=config.amp_opt_level)
@contextmanager
def dtype_guard(dtype="float32"):
origin_dtype = paddle.get_default_dtype()
@@ -31,13 +30,13 @@ def dtype_guard(dtype="float32"):
finally:
paddle.set_default_dtype(origin_dtype)

# if config.fp16:
# dtype = "float16"
# else:
# dtype = "float32"
if config.amp:
dtype = "float16"
else:
dtype = "float32"

# with dtype_guard(dtype):
# model = LlamaForCausalLM(llama_config)
with dtype_guard(dtype):
model = LlamaForCausalLM(llama_config)

model = LlamaForCausalLM(llama_config)
# model = LlamaForCausalLM(llama_config)
return llama_config, model
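
Note on model/__init__.py above: re-enabling dtype_guard means the model's parameters are created directly in fp16 when config.amp is set, instead of being built in fp32 and cast afterwards. A standalone sketch of the same pattern:

from contextlib import contextmanager
import paddle

@contextmanager
def dtype_guard(dtype="float32"):
    # Temporarily switch Paddle's default dtype so newly created tensors use it.
    origin_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(origin_dtype)

with dtype_guard("float16"):
    x = paddle.ones([2, 2])  # created as float16
print(x.dtype)  # paddle.float16; the default dtype is restored afterwards
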
10 changes: 5 additions & 5 deletions training/benchmarks/llama/paddle/model/models/modeling.py
@@ -105,13 +105,13 @@ def __init__(
recompute_granularity="full",
use_flash_attention=False,
use_fused_rms_norm=False,
tensor_parallel_output=True,
tensor_parallel_output=False,
lm_shift_labels=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
fp16_opt_level="O0",
amp_opt_level="O0",
**kwargs,
):
super().__init__(
@@ -138,7 +138,7 @@ def __init__(
self.use_fused_rms_norm = use_fused_rms_norm
self.tensor_parallel_output = tensor_parallel_output
self.lm_shift_labels = lm_shift_labels
self.fp16_opt_level = fp16_opt_level
self.amp_opt_level = amp_opt_level

self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
@@ -250,7 +250,7 @@ def scaled_dot_product_attention(
attn_weights, paddle.full([1], float(finfo(query_states.dtype).min), dtype=attn_weights.dtype)
)

if config.fp16_opt_level is not None:
if config.amp_opt_level is not None:
with paddle.amp.auto_cast(False):
attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
else:
@@ -318,7 +318,7 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if self.config.fp16_opt_level is not None:
if self.config.amp_opt_level is not None:
with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
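Note on modeling.py above: the amp_opt_level checks force the attention softmax and the RMSNorm variance computation into fp32 even when AMP is active, which avoids fp16 overflow and precision loss. A minimal sketch of that pattern:

import paddle
import paddle.nn.functional as F

def softmax_in_fp32(attn_weights):
    # Locally disable AMP so the softmax runs in full precision, then cast back.
    with paddle.amp.auto_cast(False):
        probs = F.softmax(attn_weights, axis=-1, dtype="float32")
    return probs.astype(attn_weights.dtype)
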
8 changes: 4 additions & 4 deletions training/benchmarks/llama/paddle/optimizers/factory.py
@@ -11,17 +11,17 @@ def create_optimizer(config, model: nn.Layer, lr_scheduler):
def apply_decay_param_fun(x):
return x in decay_parameters


optimizer = AdamW(parameters=parameter_list,
learning_rate=lr_scheduler,
apply_decay_param_fun=apply_decay_param_fun,
beta1=config.adam_beta1,
beta2=config.adam_beta2,
epsilon=config.adam_epsilon,
weight_decay=config.weight_decay,
grad_clip=nn.ClipGradByGlobalNorm(config.max_grad_norm)
if config.max_grad_norm > 0
else None,
grad_clip=None,
# grad_clip=nn.ClipGradByGlobalNorm(config.max_grad_norm)
# if config.max_grad_norm > 0
# else None,
multi_precision=True
)
return optimizer
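
Note on factory.py above: the commit disables global-norm gradient clipping (grad_clip=None) while keeping multi_precision=True so AdamW maintains fp32 master weights for fp16 training. Restoring clipping inside create_optimizer would mirror the commented lines, for example:

import paddle.nn as nn

# Hypothetical re-enable of global-norm clipping, mirroring the commented code above.
grad_clip = (nn.ClipGradByGlobalNorm(config.max_grad_norm)
             if config.max_grad_norm > 0 else None)
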
28 changes: 7 additions & 21 deletions training/benchmarks/llama/paddle/run_pretraining.py
@@ -80,13 +80,13 @@ def main():
trainer.init()

dist_paddle.barrier()
init_evaluation_start = time.time()
training_state.eval_avg_loss = evaluator.evaluate(trainer, config)
init_evaluation_end = time.time()
init_evaluation_info = dict(
eval_loss=training_state.eval_avg_loss,
time=init_evaluation_end - init_evaluation_start)
llama_driver.event(Event.INIT_EVALUATION, init_evaluation_info)
# init_evaluation_start = time.time()
# training_state.eval_avg_loss = evaluator.evaluate(trainer, config)
# init_evaluation_end = time.time()
# init_evaluation_info = dict(
# eval_loss=training_state.eval_avg_loss,
# time=init_evaluation_end - init_evaluation_start)
# llama_driver.event(Event.INIT_EVALUATION, init_evaluation_info)

if not config.do_train:
return config, training_state
@@ -110,7 +110,6 @@ def main():
training_state.raw_train_time = time.time() - train_start_time

return config, training_state, trainer.tr_loss
# return None, None, None

if __name__ == "__main__":
now = time.time()
@@ -134,16 +133,3 @@ def main():
else:
finished_info = {"e2e_time": e2e_time}
logger.log(Event.FINISHED, message=finished_info, stacklevel=0)

# visualize loss
# ic(trainer.tr_loss)
# print(sum(tr_loss) / len(tr_loss))

# plt.switch_backend('Agg')

# plt.figure()
# plt.plot(tr_loss,'b',label = 'loss')
# plt.ylabel('loss')
# plt.xlabel('perf_step')
# plt.legend()
# plt.savefig("./step_loss_dp.jpg")
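
Note on run_pretraining.py above: the initial evaluation block is commented out and the ad-hoc loss-plotting scratch code is removed by this commit. If the trainer.tr_loss list returned by main() still needs to be visualized, a standalone equivalent of the removed code would be:

import matplotlib
matplotlib.use("Agg")  # headless backend, as in the removed code
import matplotlib.pyplot as plt

def save_loss_curve(tr_loss, path="./step_loss_dp.jpg"):
    # tr_loss may contain Paddle tensors; convert entries to plain floats for plotting.
    values = [float(v) for v in tr_loss]
    plt.figure()
    plt.plot(values, "b", label="loss")
    plt.xlabel("perf_step")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig(path)
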
1 change: 0 additions & 1 deletion training/benchmarks/llama/paddle/schedulers/factory.py
@@ -1,5 +1,4 @@
from paddlenlp.transformers import CosineAnnealingWithWarmupDecay, LinearAnnealingWithWarmupDecay
from icecream import ic

def create_scheduler(config):
if config.decay_steps is None: