generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
dpo_online.py
123 lines (110 loc) · 4.32 KB
/
dpo_online.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

python examples/scripts/dpo_online.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --learning_rate 5.0e-7 \
    --output_dir pythia-1b-tldr-online-dpo \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 16 \
    --warmup_ratio 0.1 \
    --missing_eos_penalty 1.0

With LoRA:

python examples/scripts/dpo_online.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --learning_rate 5.0e-6 \
    --output_dir pythia-1b-tldr-online-dpo \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 8 \
    --warmup_ratio 0.1 \
    --missing_eos_penalty 1.0 \
    --use_peft
"""
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig
from trl import (
DPOScriptArguments,
LogCompletionsCallback,
ModelConfig,
OnlineDPOConfig,
OnlineDPOTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
if __name__ == "__main__":
    # Parse command-line arguments (and an optional config file) into the
    # script, training, and model argument objects.
    parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Gradient-checkpointing options are a training-argument field: the
    # trainer below is given `training_args`, not `script_args`, so setting
    # them on `script_args` would have no effect.
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    # Resolve the dtype: "auto"/None are passed through unchanged, any other
    # string (e.g. "bfloat16") is looked up as an attribute of `torch`.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)

    # Keyword arguments shared by the policy model and the reward model.
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # The generation cache is disabled whenever gradient checkpointing is on.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        **model_kwargs,
    )

    # The reward model is loaded as a sequence classifier with a single label.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        training_args.reward_model_path,
        num_labels=1,
        trust_remote_code=model_config.trust_remote_code,
        **model_kwargs,
    )

    # Only the revision is forwarded to the tokenizer: the remaining entries
    # of `model_kwargs` (dtype, device map, quantization, attention
    # implementation, cache flag) are model-loading options, not tokenizer
    # settings.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.model_revision,
    )
    # Fall back to defaults when the checkpoint ships without a chat template
    # or a padding token.
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(script_args.dataset_name)

    trainer = OnlineDPOTrainer(
        model=model,
        reward_model=reward_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )

    # Register a callback that logs sampled completions for 8 prompts, using
    # the same sampling settings (length, temperature) as training.
    generation_config = GenerationConfig(
        max_new_tokens=training_args.max_new_tokens,
        do_sample=True,
        temperature=training_args.temperature,
    )
    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
    trainer.add_callback(completions_callback)

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)