[llm] Support pad to max_length & fix sequence parallel (sp) bug #9040

Merged: 4 commits, Aug 29, 2024

llm/alignment/dpo/run_dpo.py (5 additions, 1 deletion)
@@ -35,6 +35,7 @@
AutoTokenizer,
LlamaForCausalLM,
LlamaForCausalLMPipe,
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.trl import (
DPOTrainer,
@@ -138,7 +139,10 @@ def main():

if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
raise NotImplementedError(f"{model.__class__} not support flash mask.")

if training_args.sequence_parallel:
register_sequence_parallel_allreduce_hooks(
model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
)
if model_args.tokenizer_name_or_path is not None:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
else:
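
For context on the sequence parallel (sp) part of this fix: in Megatron-style sequence parallelism, parameters that only ever see a shard of the sequence (typically LayerNorm weights and biases) produce partial gradients on each rank, so those gradients must be summed across the model-parallel group before the optimizer step. `register_sequence_parallel_allreduce_hooks` registers hooks that perform this reduction, taking the gradient accumulation steps and a fuse flag as seen in the diff. The snippet below is only a conceptual sketch of that reduction, not the PaddleNLP implementation; the `is_sequence_parallel` marker attribute and the explicit post-backward loop are assumptions for illustration.

import paddle.distributed as dist


def allreduce_sequence_parallel_grads(model, mp_group):
    """Conceptual sketch: sum partial grads of sequence-parallel params.

    `mp_group` is assumed to be the model-parallel communication group
    (e.g. obtained from fleet's hybrid communicate group).
    """
    for param in model.parameters():
        # `is_sequence_parallel` is an assumed marker attribute; the real helper
        # decides this from how the layer was built.
        if getattr(param, "is_sequence_parallel", False) and param.grad is not None:
            dist.all_reduce(param.grad, group=mp_group)  # in-place SUM across ranks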

llm/run_finetune.py (7 additions, 1 deletion)
@@ -52,6 +52,7 @@
LlamaForCausalLM,
LlamaForCausalLMPipe,
LlamaTokenizer,
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.utils.llm_utils import (
@@ -197,7 +198,10 @@ def neft_post_hook(module, input, output):
neft_post_hook_handle = model.get_input_embeddings().register_forward_post_hook(neft_post_hook)
else:
raise NotImplementedError("Only support neftune for model with get_input_embeddings")

if training_args.sequence_parallel:
register_sequence_parallel_allreduce_hooks(
model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
)
# Load tokenizer & dataset
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
# init chat_template for tokenizer
@@ -522,6 +526,8 @@ def compute_metrics_do_generation(eval_preds):
training_args.pipeline_parallel_degree > 1
or training_args.sequence_parallel
or training_args.autotuner_benchmark
or data_args.zero_padding
or data_args.pad_to_max_length
):
# NOTE(gongenlei): new add autotuner_benchmark
max_length = data_args.max_length
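
With the new condition, enabling `pad_to_max_length` (like zero padding, pipeline/sequence parallelism, or the autotuner benchmark) makes the data pipeline use a fixed `max_length` instead of padding dynamically per batch, so every batch has a static shape. Below is a rough illustration of that difference, assuming PaddleNLP's HF-style tokenizer call (`padding` / `truncation` / `max_length` arguments); the checkpoint name and texts are placeholders:

from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")  # placeholder checkpoint

texts = ["a short prompt", "a noticeably longer prompt used for comparison"]

# Dynamic padding: pad only to the longest sequence in this batch,
# so tensor shapes vary from batch to batch.
dynamic = tokenizer(texts, padding=True, truncation=True, max_length=1024)

# pad_to_max_length behaviour: every sequence is padded to max_length,
# giving one static shape, which is what pipeline/sequence parallelism
# and autotuner benchmarking expect.
fixed = tokenizer(texts, padding="max_length", truncation=True, max_length=1024)

print(len(dynamic["input_ids"][0]), len(fixed["input_ids"][0]))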

llm/utils/argument.py (4 additions, 0 deletions)
@@ -132,6 +132,10 @@ class DataArgument:
"help": "@deprecated Please use `zero_padding`. Whether to use InTokens data stream, same as `zero_padding`."
},
) # Alias for zero_padding
pad_to_max_length: bool = field(
default=False,
metadata={"help": "Pad the input sequence to `max_length`."},
)

def __post_init__(self):
if self.task_name_or_path is not None:
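
The new flag is an ordinary boolean `DataArgument` field, so it can be passed on the command line or through the JSON config like any other data argument. Below is a hypothetical parsing sketch; it assumes `PdArgumentParser` from `paddlenlp.trainer` and uses a stand-in dataclass with the same field rather than importing the real `DataArgument` from llm/utils/argument.py:

from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser


@dataclass
class DataArgumentSketch:
    # Stand-in for llm/utils/argument.py::DataArgument, reduced to two fields.
    max_length: int = field(default=2048, metadata={"help": "Maximum sequence length."})
    pad_to_max_length: bool = field(
        default=False,
        metadata={"help": "Pad the input sequence to `max_length`."},
    )


parser = PdArgumentParser(DataArgumentSketch)
(data_args,) = parser.parse_args_into_dataclasses(
    ["--max_length", "1024", "--pad_to_max_length", "true"]
)
print(data_args.max_length, data_args.pad_to_max_length)  # 1024 True

When the flag is set for run_finetune.py, the `max_length = data_args.max_length` branch shown above is taken, so sequences are padded to a fixed length instead of dynamically per batch.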