🕹️ CLI refactor #2380

Merged
merged 53 commits into from
Dec 13, 2024
Commits
c77503a
Refactor main function in dpo.py
qgallouedec Nov 21, 2024
66254d8
Update setup.py and add cli.py
qgallouedec Nov 21, 2024
49f34d1
Add examples to package data
qgallouedec Nov 21, 2024
35ee4e6
style
qgallouedec Nov 21, 2024
09e1257
Refactor setup.py file
qgallouedec Nov 21, 2024
4ccf137
Add new file t.py
qgallouedec Nov 21, 2024
d397ab8
Move dpo to package
qgallouedec Nov 21, 2024
dad9ecc
Update MANIFEST.in and setup.py, refactor trl/cli.py
qgallouedec Nov 21, 2024
3024d5b
Add __init__.py to trl/scripts directory
qgallouedec Nov 21, 2024
60583c2
Add license header to __init__.py
qgallouedec Nov 21, 2024
e5ec9e7
Merge branch 'main' into cli-refactor
qgallouedec Nov 22, 2024
5eb1adf
File moved instruction
qgallouedec Nov 22, 2024
5f37d46
Merge branch 'cli-refactor' of https://github.com/huggingface/trl int…
qgallouedec Nov 22, 2024
793ce44
Add Apache License and update file path
qgallouedec Nov 22, 2024
3713056
Merge branch 'main' into cli-refactor
qgallouedec Nov 24, 2024
1c6261b
Merge branch 'main' into cli-refactor
qgallouedec Nov 25, 2024
bf27b36
Move dpo.py to new location
qgallouedec Nov 25, 2024
30e0e50
Merge branch 'cli-refactor' of https://github.com/huggingface/trl int…
qgallouedec Nov 25, 2024
adac644
Refactor CLI and DPO script
qgallouedec Nov 25, 2024
5c5a254
Merge branch 'main' into cli-refactor
qgallouedec Nov 26, 2024
923ba0c
Merge branch 'main' into cli-refactor
qgallouedec Nov 28, 2024
a15a41c
Refactor import structure in scripts package
qgallouedec Nov 28, 2024
f0a20b2
Merge branch 'main' into cli-refactor
qgallouedec Dec 4, 2024
7a0a4f0
env
qgallouedec Dec 4, 2024
167f23f
rm config from chat arg
qgallouedec Dec 4, 2024
084e33a
rm old cli
qgallouedec Dec 4, 2024
70dd253
chat init
qgallouedec Dec 4, 2024
972f7c6
test cli [skip ci]
qgallouedec Dec 5, 2024
1386d41
Add `datast_config_name` to `ScriptArguments` (#2440)
qgallouedec Dec 5, 2024
bf289d8
add missing arg
qgallouedec Dec 5, 2024
d811b1b
Add test cases for 'trl sft' and 'trl dpo' commands
qgallouedec Dec 5, 2024
61706af
Add sft.py script and update cli.py to include sft command
qgallouedec Dec 5, 2024
d9094e2
Move sft script
qgallouedec Dec 5, 2024
7d2e62c
chat
qgallouedec Dec 5, 2024
d468545
style [ci skip]
qgallouedec Dec 5, 2024
93d423c
kto
qgallouedec Dec 5, 2024
9ee485a
rm example config
qgallouedec Dec 5, 2024
5f86e61
first step on doc
qgallouedec Dec 5, 2024
779062b
see #2442
qgallouedec Dec 5, 2024
0892264
see #2443
qgallouedec Dec 5, 2024
746baec
fix chat windows
qgallouedec Dec 5, 2024
2fc0b6f
©️ Copyrights update (#2454)
qgallouedec Dec 10, 2024
6941e0f
💬 Fix chat for windows (#2443)
qgallouedec Dec 10, 2024
b202b15
🆔 Add `datast_config` to `ScriptArguments` (#2440)
qgallouedec Dec 10, 2024
2401463
🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)
qgallouedec Dec 10, 2024
be0ca9b
Merge branch 'main' into cli-refactor
qgallouedec Dec 10, 2024
c0209f9
Fix config name
qgallouedec Dec 10, 2024
cbff826
Merge branch 'main' into cli-refactor
qgallouedec Dec 13, 2024
0263435
Remove `make dev` in favor of `pip install -e .[dev]`
qgallouedec Dec 13, 2024
590afa0
Merge branch 'cli-refactor' of https://github.com/huggingface/trl int…
qgallouedec Dec 13, 2024
98458a0
Update script paths and remove old symlink related things
qgallouedec Dec 13, 2024
65f31f6
Fix chat script path [ci skip]
qgallouedec Dec 13, 2024
3a3be53
style
qgallouedec Dec 13, 2024
Add sft.py script and update cli.py to include sft command
qgallouedec committed Dec 5, 2024
commit 61706afe388f41402c54be876e566b62e11caaa3
11 changes: 11 additions & 0 deletions trl/cli.py
@@ -20,6 +20,7 @@

from .scripts.dpo import make_parser as make_dpo_parser
from .scripts.env import print_env
from .scripts.sft import make_parser as make_sft_parser
from .scripts.utils import TrlParser


@@ -32,6 +33,7 @@ def main():
    # Add the subparsers for every script
    make_dpo_parser(subparsers)
    subparsers.add_parser("env", help="Print the environment information")
    make_sft_parser(subparsers)

    # Parse the arguments
    args = parser.parse_args()
@@ -48,6 +50,15 @@ def main():
    elif args.command == "env":
        print_env()

    elif args.command == "sft":
        # Get the default args for the launch command
        sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py")
        args = launch_command_parser().parse_args([sft_training_script])

        # Feed the args to the launch command
        args.training_script_args = sys.argv[2:]  # remove "trl" and "sft"
        launch_command(args)  # launch training


if __name__ == "__main__":
    main()
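The `sft` branch above resolves the bundled script path and forwards everything after `trl sft` to it. A minimal sketch of that routing, assuming nothing beyond the standard library: `build_launch_args` and `package_dir` are hypothetical names used for illustration only, and the sketch returns what would be launched instead of calling `launch_command`, whose import lies outside the hunk shown.

```python
import argparse
import os


def build_launch_args(argv, package_dir):
    """Return (script_path, forwarded_args) for a `trl <command> ...` style argv.

    Hypothetical helper: it mirrors the path resolution and argument
    forwarding in the diff, but does not launch anything.
    """
    parser = argparse.ArgumentParser(prog="trl")
    subparsers = parser.add_subparsers(dest="command", required=True)
    subparsers.add_parser("sft")
    subparsers.add_parser("dpo")

    # Only the subcommand name is parsed here; the rest is passed through untouched.
    args = parser.parse_args(argv[1:2])

    # Same layout as in the diff: <package>/scripts/<command>.py
    script = os.path.join(package_dir, "scripts", f"{args.command}.py")

    # Drop the program name and the subcommand, as sys.argv[2:] does in cli.py.
    return script, argv[2:]


if __name__ == "__main__":
    script, forwarded = build_launch_args(
        ["trl", "sft", "--learning_rate", "2.0e-5"], "/path/to/trl"
    )
    print(script, forwarded)
```

Keeping the forwarded arguments as an untouched list means the training script stays the single place where training options are parsed.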
125 changes: 125 additions & 0 deletions trl/scripts/sft.py
@@ -0,0 +1,125 @@
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Full training
python examples/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-SFT \
    --push_to_hub

# LoRA
python examples/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir Qwen2-0.5B-SFT \
    --push_to_hub
"""

import argparse

from datasets import load_dataset
from transformers import AutoTokenizer

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_config):
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_config = parser.parse_args_and_config()
    main(script_args, training_args, model_config)
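`make_parser` above serves two callers: run directly, it builds a standalone parser; called from the CLI, it registers an `sft` subparser and hands `dataclass_types` through `add_parser`. A minimal sketch of why that keyword can be forwarded, assuming only standard `argparse` behavior (extra `add_parser` keywords go to the parser class constructor, which defaults to the parent parser's class). `ToyParser`, `ToyArguments`, and the `toy` command are hypothetical stand-ins, not TRL's `TrlParser` or its dataclasses.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class ToyArguments:
    # Hypothetical fields, chosen only to give the parser something to expose.
    dataset_name: str = "demo"
    learning_rate: float = 2e-5


class ToyParser(argparse.ArgumentParser):
    """Stand-in parser that turns dataclass fields into command-line flags."""

    def __init__(self, dataclass_types=(), **kwargs):
        super().__init__(**kwargs)
        self.dataclass_types = tuple(dataclass_types)
        for dataclass_type in self.dataclass_types:
            for field in fields(dataclass_type):
                self.add_argument(f"--{field.name}", type=field.type, default=field.default)


def make_parser(subparsers=None):
    dataclass_types = (ToyArguments,)
    if subparsers is not None:
        # argparse pops `help` and passes the remaining keywords to the parser class.
        return subparsers.add_parser("toy", help="Run the toy script", dataclass_types=dataclass_types)
    return ToyParser(dataclass_types=dataclass_types)


if __name__ == "__main__":
    # Standalone use, as when the script is executed directly.
    print(make_parser().parse_args(["--learning_rate", "1e-4"]))

    # Subcommand use, as when a top-level CLI owns the root parser.
    root = ToyParser(prog="cli")
    make_parser(root.add_subparsers(dest="command"))
    print(root.parse_args(["toy", "--dataset_name", "x"]))
```

With one factory per script, the command-line surface of the script and of its subcommand cannot drift apart, since both are built from the same dataclass tuple.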