diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py
index 3945375aee43..17d8c7b8efa2 100644
--- a/llm/alignment/dpo/run_dpo.py
+++ b/llm/alignment/dpo/run_dpo.py
@@ -35,6 +35,7 @@
     AutoTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
+    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.trl import (
     DPOTrainer,
@@ -138,7 +139,10 @@ def main():
 
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
-
+    if training_args.sequence_parallel:
+        register_sequence_parallel_allreduce_hooks(
+            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
+        )
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 96db6a6ad697..d084a910ff65 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -53,6 +53,7 @@
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
     LlamaTokenizer,
+    register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.llm_utils import (
@@ -193,7 +194,10 @@ def neft_post_hook(module, input, output):
             neft_post_hook_handle = model.get_input_embeddings().register_forward_post_hook(neft_post_hook)
         else:
             raise NotImplementedError("Only support neftune for model with get_input_embeddings")
-
+    if training_args.sequence_parallel:
+        register_sequence_parallel_allreduce_hooks(
+            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
+        )
     # Load tokenizer & dataset
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
     # init chat_template for tokenizer
@@ -518,6 +522,8 @@ def compute_metrics_do_generation(eval_preds):
         training_args.pipeline_parallel_degree > 1
         or training_args.sequence_parallel
         or training_args.autotuner_benchmark
+        or data_args.zero_padding
+        or data_args.pad_to_max_length
     ):
         # NOTE(gongenlei): new add autotuner_benchmark
         max_length = data_args.max_length
diff --git a/llm/utils/argument.py b/llm/utils/argument.py
index 2424f325ee1a..60b6f89b3377 100644
--- a/llm/utils/argument.py
+++ b/llm/utils/argument.py
@@ -137,6 +137,10 @@ class DataArgument:
             "help": "@deprecated Please use `zero_padding`. Whether to use InTokens data stream, same as `zero_padding`."
         },
     )  # Alias for zero_padding
+    pad_to_max_length: bool = field(
+        default=False,
+        metadata={"help": "Pad the input sequence to `max_length`."},
+    )
 
     def __post_init__(self):
         if self.task_name_or_path is not None:
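
For reference, a minimal sketch of the guarded registration that both scripts now perform, pulled out on its own. It assumes `model` and `training_args` are the objects already built earlier in `main()`; the function name, arguments, and import path are taken directly from the diff above.

from paddlenlp.transformers import register_sequence_parallel_allreduce_hooks

if training_args.sequence_parallel:
    # Register hooks on the model so sequence-parallel gradients are
    # all-reduced consistently across gradient accumulation steps; the
    # last argument toggles the fused allreduce path.
    register_sequence_parallel_allreduce_hooks(
        model,
        training_args.gradient_accumulation_steps,
        training_args.fuse_sequence_parallel_allreduce,
    )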