
📉 Add PEFT support for PPOTrainer #2344

Merged
14 commits merged on Nov 18, 2024
38 changes: 34 additions & 4 deletions examples/scripts/ppo/ppo.py
@@ -14,6 +14,7 @@

import shutil

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
@@ -23,7 +24,15 @@
HfArgumentParser,
)

from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
from trl import (
ModelConfig,
PPOConfig,
PPOTrainer,
ScriptArguments,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@@ -67,6 +76,20 @@
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
@@ -81,12 +104,18 @@
reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)

peft_config = get_peft_config(model_config)
if peft_config is None:
ref_policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
else:
ref_policy = None

################
# Dataset
################
@@ -131,6 +160,7 @@ def tokenize(element):
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
)
trainer.train()

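Note on the ref_policy = None branch above: when a peft_config is supplied, only the LoRA adapter weights are trained, so reference log-probabilities can be recovered from the same model by temporarily disabling the adapter rather than keeping a second frozen copy of the SFT model in memory. A minimal sketch of that idea, assuming a PEFT-wrapped policy (model name, prompt, and LoRA settings are illustrative, not the trainer's exact internals):

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
    policy = get_peft_model(
        AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m"),
        LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),
    )

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    with torch.no_grad():
        policy_logits = policy(**inputs).logits    # adapter active: current policy
        with policy.disable_adapter():             # adapter bypassed
            ref_logits = policy(**inputs).logits   # behaves like the frozen SFT weights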
52 changes: 43 additions & 9 deletions examples/scripts/ppo/ppo_tldr.py
@@ -14,6 +14,7 @@

import shutil

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
@@ -23,16 +24,24 @@
HfArgumentParser,
)

from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
from trl import (
ModelConfig,
PPOConfig,
PPOTrainer,
ScriptArguments,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
python examples/scripts/ppo/ppo_tldr.py \
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
--dataset_test_split validation \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--output_dir models/minimal/ppo_tldr \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 64 \
--total_episodes 30000 \
@@ -41,11 +50,13 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--missing_eos_penalty 1.0 \
--stop_token eos \
--response_length 53
--response_length 53 \
--eval_strategy steps \
--eval_steps 100

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
--dataset_test_split validation \
--output_dir models/minimal/ppo_tldr \
--learning_rate 3e-6 \
@@ -57,7 +68,9 @@
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--missing_eos_penalty 1.0 \
--stop_token eos
--stop_token eos \
--eval_strategy steps \
--eval_steps 100
"""


@@ -70,6 +83,20 @@
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
@@ -84,12 +111,18 @@
reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)

peft_config = get_peft_config(model_config)
if peft_config is None:
ref_policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
else:
ref_policy = None

################
# Dataset
################
@@ -138,6 +171,7 @@ def tokenize(element):
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
)
trainer.train()

33 changes: 33 additions & 0 deletions tests/test_ppo_trainer.py
@@ -14,6 +14,8 @@
import platform
import subprocess

from transformers.testing_utils import require_peft


def test():
command = """\
@@ -65,3 +67,34 @@ def test_num_train_epochs():
shell=True,
check=True,
)


@require_peft
def test_peft_support():
command = """\
python examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--lora_target_modules query_key_value dense
"""
if platform.system() == "Windows":
# windows CI does not work with subprocesses for some reason
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
return
subprocess.run(
command,
shell=True,
check=True,
)
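For reference, the --use_peft --lora_r 32 --lora_alpha 16 --lora_target_modules query_key_value dense flags exercised by this test map roughly to the following peft.LoraConfig (a sketch of what get_peft_config(model_config) assembles; defaults such as dropout and bias may differ across TRL versions):

    from peft import LoraConfig

    peft_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules=["query_key_value", "dense"],  # GPT-NeoX attention projections
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )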
7 changes: 7 additions & 0 deletions trl/trainer/ppo_config.py
@@ -14,6 +14,7 @@

import os
from dataclasses import dataclass
from typing import Optional

from ..trainer.utils import OnPolicyConfig

@@ -32,6 +33,10 @@ class PPOConfig(OnPolicyConfig):
Name of this experiment.
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
Path to the reward model.
model_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
num_ppo_epochs (`int`, *optional*, defaults to `4`):
Number of epochs to train.
whiten_rewards (`bool`, *optional*, defaults to `False`):
@@ -52,6 +57,8 @@

exp_name: str = os.path.basename(__file__)[: -len(".py")]
reward_model_path: str = "EleutherAI/pythia-160m"
model_adapter_name: Optional[str] = None
ref_adapter_name: Optional[str] = None
num_ppo_epochs: int = 4
whiten_rewards: bool = False
kl_coef: float = 0.05
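A short sketch of how the two new fields could be set when the policy carries separate trainable and reference adapters (adapter names and paths below are illustrative; with a single adapter both fields can typically stay None and reference outputs come from disabling the adapter, as sketched earlier):

    from trl import PPOConfig

    training_args = PPOConfig(
        output_dir="models/minimal/ppo_peft",
        reward_model_path="cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr",
        model_adapter_name="ppo_train",   # adapter updated during PPO
        ref_adapter_name="reference",     # frozen adapter used for reference log-probs
    )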