[LLM][TRL] Support DPO with Pipeline Parallel (#9039)
* support dpo/kto pp
lugimzzz authored Sep 20, 2024
1 parent 8212b53 commit bc55104
Showing 46 changed files with 2,241 additions and 758 deletions.
94 changes: 83 additions & 11 deletions llm/alignment/dpo/dpo_argument.py
@@ -17,6 +17,7 @@
from typing import Optional

from paddlenlp.trainer import TrainingArguments
from paddlenlp.trainer.trainer_utils import IntervalStrategy


def add_start_docstrings(*docstr):
@@ -42,9 +43,66 @@ class DPOTrainingArguments(TrainingArguments):
default="",
metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
)
dpo_beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
dpo_label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
dpo_loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
autotuner_benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)
benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)

def __post_init__(self):
super().__post_init__()
if self.autotuner_benchmark:
self.num_train_epochs = 1
self.max_steps = 5
self.do_train = True
self.do_export = False
self.do_predict = False
self.do_eval = False
self.overwrite_output_dir = True
self.load_best_model_at_end = False
self.report_to = []
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO
if not self.disable_tqdm:
self.logging_steps = 1
self.logging_strategy = IntervalStrategy.STEPS
if self.benchmark:
self.do_train = True
self.do_export = False
self.do_predict = False
self.do_eval = False
self.overwrite_output_dir = True
self.load_best_model_at_end = False
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO
if not self.disable_tqdm:
self.logging_steps = 1
self.logging_strategy = IntervalStrategy.STEPS
if self.max_steps > 0:
self.num_train_epochs = 1


@dataclass
class DPOConfig:
"""DPOConfig"""

beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"})
normalize_logps: bool = field(
default=True,
metadata={"help": "Apply logprobs normalization."},
)
label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"})
sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"})
dpop_lambda: float = field(default=50, metadata={"help": "lambda parameter for DPOP loss"})
ref_model_update_steps: int = field(default=-1, metadata={"help": "Update the reference model state dict every N steps (-1 to disable)."})
reference_free: bool = field(default=False, metadata={"help": "No reference model."})
lora: bool = field(default=False, metadata={"help": "Use LoRA model."})


@dataclass
Expand All @@ -55,18 +113,16 @@ class DPODataArgument:
dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
autotuner_benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)
benchmark: bool = field(
default=False,
metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
)
greedy_zero_padding: bool = field(
default=False,
metadata={"help": "Whether to use Greedy Zero Padding data stream."},
)
lazy: bool = field(
default=False,
metadata={
"help": "Weather to return `MapDataset` or an `IterDataset`.True for `IterDataset`. False for `MapDataset`."
},
)


@dataclass
@@ -95,3 +151,19 @@ class DPOModelArgument:
default=False,
metadata={"help": "whether to use sequence parallel"},
)
tensor_parallel_output: bool = field(
default=True,
metadata={"help": "whether to use tensor_parallel_output"},
)
weight_quantize_algo: str = field(
default=None,
metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
)
# LoRA
lora_rank: int = field(default=8, metadata={"help": "Lora rank."})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
lora_alpha: int = field(default=-1, metadata={"help": "LoRA alpha scaling factor (-1 to derive it from lora_rank)."})
rslora_plus: bool = field(default=False, metadata={"help": "Enable rsLoRA+ (sets rslora, lora_plus_scale and lora_alpha together)."})
use_quick_lora: bool = field(default=True, metadata={"help": "Whether to use Quick LoRA."})
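
Note (not part of this commit): the default "sigmoid" loss_type that DPOConfig parameterizes is the standard DPO objective, and beta, label_smoothing, and reference_free typically enter it as in the minimal sketch below. This is written against the standard DPO formulation, not the criterion implemented in paddlenlp; the function name sigmoid_dpo_loss is hypothetical.

import paddle
import paddle.nn.functional as F

def sigmoid_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps,
                     beta=0.1, label_smoothing=0.0, reference_free=False):
    """Standard-form sigmoid DPO loss; illustrative only."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    # With reference_free=True the reference term drops out entirely.
    ref_logratios = (
        paddle.zeros_like(pi_logratios)
        if reference_free
        else ref_chosen_logps - ref_rejected_logps
    )
    logits = pi_logratios - ref_logratios
    # label_smoothing acts as conservative DPO: it assumes that fraction of
    # preference pairs is mislabeled.
    loss = (
        -F.log_sigmoid(beta * logits) * (1.0 - label_smoothing)
        - F.log_sigmoid(-beta * logits) * label_smoothing
    )
    return loss.mean()
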
159 changes: 103 additions & 56 deletions llm/alignment/dpo/run_dpo.py
@@ -20,18 +20,24 @@
from functools import partial

import paddle
from dpo_argument import DPODataArgument, DPOModelArgument, DPOTrainingArguments

from paddlenlp.datasets import ZeroPaddingMapDataset, load_dataset
from paddlenlp.trainer import (
IntervalStrategy,
PdArgumentParser,
get_last_checkpoint,
set_seed,
from dpo_argument import (
DPOConfig,
DPODataArgument,
DPOModelArgument,
DPOTrainingArguments,
)

from paddlenlp.datasets import (
ZeroPaddingIterableDataset,
ZeroPaddingMapDataset,
load_dataset,
)
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed
from paddlenlp.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
LlamaForCausalLM,
LlamaForCausalLMPipe,
@@ -43,47 +49,34 @@
preference_collate_fn,
preprocess_preference_data,
)
from paddlenlp.utils.llm_utils import get_lora_target_modules
from paddlenlp.utils.log import logger

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]


def main():
"""main"""
parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments, DPOConfig))
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines()
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
if training_args.max_steps > 0:
training_args.num_train_epochs = 1
if data_args.autotuner_benchmark:
training_args.num_train_epochs = 1
training_args.max_steps = 5
training_args.do_train = True
training_args.do_export = False
training_args.do_predict = False
training_args.do_eval = False
training_args.overwrite_output_dir = True
training_args.load_best_model_at_end = False
training_args.report_to = []
training_args.save_strategy = IntervalStrategy.NO
training_args.evaluation_strategy = IntervalStrategy.NO
if data_args.benchmark:
training_args.do_train = True
training_args.do_export = False
training_args.do_predict = False
training_args.do_eval = False
training_args.overwrite_output_dir = True
training_args.load_best_model_at_end = False
training_args.save_strategy = IntervalStrategy.NO
training_args.evaluation_strategy = IntervalStrategy.NO
model_args, data_args, training_args, dpo_config = parser.parse_args_into_dataclasses()

paddle.set_device(training_args.device)
set_seed(training_args.seed)
if dpo_config.loss_type == "orpo":
dpo_config.reference_free = True
dpo_config.sft_loss_ratio = 1.0
dpo_config.loss_type = "or"
logger.info("orpo loss_type is equal to sft_loss + pref_loss_ratio * or_loss.")
if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free:
dpo_config.reference_free = True
logger.warning(f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True.")

training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
training_args.print_config(dpo_config, "DPOConfig")

logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: "
@@ -116,51 +109,102 @@ def main():
tensor_parallel_rank=training_args.tensor_parallel_rank,
recompute_granularity=model_args.recompute_granularity,
use_flash_attention=model_args.use_flash_attention,
tensor_parallel_output=True,
tensor_parallel_output=model_args.tensor_parallel_output,
)
if training_args.pipeline_parallel_degree > 1:
raise ValueError("DPO does not support pipeline parallelism yet.")

if not data_args.autotuner_benchmark:
ref_model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
config = AutoConfig.from_pretrained(**model_kwargs)
model = AutoModelForCausalLM.from_config(config)
model.set_state_dict(ref_model.state_dict())
if training_args.pipeline_parallel_degree > 1:
model_class = AutoModelForCausalLMPipe
else:
model_class = AutoModelForCausalLM
if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
model = model_class.from_pretrained(**model_kwargs)
# for DPO save
model.config.dpo_config = None
if not dpo_config.reference_free and not dpo_config.lora:
config = AutoConfig.from_pretrained(**model_kwargs)
ref_model = model_class.from_config(config, dtype=dtype)
ref_model.set_state_dict(model.state_dict())
else:
ref_model = None
else:
config = AutoConfig.from_pretrained(**model_kwargs)
model = AutoModelForCausalLM.from_config(config)
ref_config = AutoConfig.from_pretrained(**model_kwargs)
ref_model = AutoModelForCausalLM.from_config(ref_config)
model.set_state_dict(ref_model.state_dict())
model = model_class.from_config(config, dtype=dtype)
if not dpo_config.reference_free and not dpo_config.lora:
ref_model = model_class.from_config(config, dtype=dtype)
else:
ref_model = None

if model_args.flash_mask and not model.config.use_flash_attention:
logger.warning("`flash_mask` must use with zero padding and flash attention.")
model.config.use_flash_attention = True

if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
raise NotImplementedError(f"{model.__class__} does not support flash mask.")
if training_args.sequence_parallel:

if model_args.sequence_parallel:
register_sequence_parallel_allreduce_hooks(
model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
)
if model_args.tokenizer_name_or_path is not None:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
else:
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# TODO: support chat template in next pr
# tokenizer.chat_template = None
tokenizer.chat_template = None
logger.info("Loading model & tokenizer successfully !")

if dpo_config.lora:
if training_args.sharding_parallel_degree > 1:
assert (
"enable_stage1_overlap" not in training_args.sharding_parallel_config
), "Currently not support enabling sharding_stage1_overlap in lora mode."
if model_args.lora_path is None:
target_modules = get_lora_target_modules(model)
if model_args.rslora_plus:
model_args.rslora = True
model_args.lora_plus_scale = 4
model_args.lora_alpha = 4
if model_args.weight_quantize_algo is not None:
if model_args.rslora or model_args.lora_plus_scale != 1.0:
logger.info("Weight quantization is not supported in LoRA+ and RsLoRA.")
if model_args.lora_alpha == -1:
if model_args.rslora:
model_args.lora_alpha = 4
else:
model_args.lora_alpha = 2 * model_args.lora_rank
lora_config = LoRAConfig(
target_modules=target_modules,
r=model_args.lora_rank,
lora_alpha=model_args.lora_alpha,
rslora=model_args.rslora,
lora_plus_scale=model_args.lora_plus_scale,
tensor_parallel_degree=training_args.tensor_parallel_degree,
dtype=dtype,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
)
model = LoRAModel(model, lora_config)
else:
model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

model.print_trainable_parameters()

logger.info("Start to create dataset")
trans_func = partial(preprocess_preference_data, tokenizer=tokenizer, data_args=data_args, model_args=model_args)
if data_args.lazy:
zero_padding_dataset = ZeroPaddingIterableDataset
else:
zero_padding_dataset = ZeroPaddingMapDataset
if training_args.do_train and training_args.should_load_dataset:
train_ds = load_dataset(
"json",
data_files=data_args.train_dataset_path,
lazy=data_args.lazy,
)[0]
logger.info("Creating train Zero Padding Data Stream. This may take a few minutes.")
train_ds = (
ZeroPaddingMapDataset(
zero_padding_dataset(
train_ds.map(trans_func),
tokenizer=tokenizer,
max_length=data_args.max_seq_len,
@@ -176,10 +220,11 @@ def main():
eval_ds = load_dataset(
"json",
data_files=data_args.dev_dataset_path,
lazy=data_args.lazy,
)[0]
logger.info("Creating dev Zero Padding Data Stream. This may take a few minutes.")
eval_ds = (
ZeroPaddingMapDataset(
zero_padding_dataset(
eval_ds.map(trans_func),
tokenizer=tokenizer,
max_length=data_args.max_seq_len,
@@ -194,6 +239,7 @@ def main():
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
dpo_config=dpo_config,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
@@ -202,17 +248,18 @@ def main():
preference_collate_fn,
max_seq_len=data_args.max_seq_len,
),
ignore_eos_token=True,
)

if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

if not data_args.autotuner_benchmark and not data_args.benchmark:
if not training_args.autotuner_benchmark and not training_args.benchmark:
trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if data_args.benchmark:
if training_args.benchmark:
total_effective_tokens, total_tokens = calculate_effective_tokens(
training_args, train_ds, data_args.max_seq_len
)
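
Aside (not part of the diff above): DPOConfig.ref_model_update_steps suggests periodically refreshing the frozen reference model from the current policy. A plausible sketch of such a refresh hook, reusing the same state_dict/set_state_dict pattern run_dpo.py uses when it first clones the reference model, is shown below; the helper name and call site are assumptions, not code from this commit.

def maybe_refresh_ref_model(model, ref_model, dpo_config, global_step):
    """Hypothetical helper: copy policy weights into the reference model
    every ref_model_update_steps steps (skipped when the interval is <= 0
    or when training is reference-free / LoRA-based and ref_model is None)."""
    if ref_model is None or dpo_config.ref_model_update_steps <= 0:
        return
    if global_step > 0 and global_step % dpo_config.ref_model_update_steps == 0:
        ref_model.set_state_dict(model.state_dict())
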
1 change: 0 additions & 1 deletion llm/config/baichuan/dpo_argument.json
@@ -24,7 +24,6 @@
"disable_tqdm": true,
"load_best_model_at_end": true,
"tensor_parallel_degree": 8,
"sharding_parallel_degree": 1,
"sharding": "stage1",
"use_flash_attention": true,
"recompute": false,
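
For context (again, not part of the commit): JSON configs like the excerpt above are consumed through parse_json_file_and_cmd_lines in run_dpo.py, which reads the JSON path from argv[1]. Assuming that helper also merges trailing command-line flags, the new DPOConfig fields can be overridden without editing the JSON; a minimal sketch:

import sys
from paddlenlp.trainer import PdArgumentParser
from dpo_argument import (
    DPOConfig,
    DPODataArgument,
    DPOModelArgument,
    DPOTrainingArguments,
)

# Hypothetical invocation, mirroring run_dpo.py's argument handling:
#   python run_dpo.py llm/config/baichuan/dpo_argument.json --loss_type simpo
sys.argv = ["run_dpo.py", "llm/config/baichuan/dpo_argument.json", "--loss_type", "simpo"]

parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments, DPOConfig))
model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines()
print(dpo_config.loss_type)  # expected "simpo", assuming CLI flags override the JSON
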