PPO / Reinforce Trainers (#1540)
* Add ppov2 trainer

* make eos trick optional, remove unused args

* quick fix

* precommit

* update debugging script

* fix out of bound `drop_last=True`; use built-in scheduler

* Add PPO examples

* push changes

* quick change

* quick change

* various bug fixes

* remove unnecessary grad accumulation setting

* push new changes

* fix DS3 model saving

* update ppo.py

* refactor

* quick change

* refactor

* update ppo trainer

* refactor

* quick test

* add ds2 /ds3 7 processes config

* add vllm trainer

* quick change

* experiment with reward normalization

* push changes

* quick push

* push changes

* push various changes

* refactor to use ModelConfig

* quick change

* refactor

* refactor

* Simplify DS logic

* quick update

* remove unnecessary files

* precommit

* deepspeed fix; handle edge case when eos_token_id = 0

* add PPO tldr example

* add TL;DR example

* fix undefined var

* utilize all samples in rloo

* quick setting

* remove the unnecessary `value_model`

* use exact_div

* allow saving the deepspeed model

* refactor

* remove dead code

* Use some shared utilities

* add some end-to-end test cases

* add PPOv2 docs and RLOO docs / tests

* update docs

* quick push

* fix ci

* fix type annotation for ci

* quick update

* update trainer docs
vwxyzjn authored May 22, 2024
1 parent 99f2c94 commit 13454d2
Showing 23 changed files with 3,114 additions and 11 deletions.
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
@@ -27,6 +27,10 @@
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: ppov2_trainer
title: PPOv2 Trainer
- local: rloo_trainer
title: RLOO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
257 changes: 257 additions & 0 deletions docs/source/ppov2_trainer.md

Large diffs are not rendered by default.
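Since the new docs page is not rendered here, below is a minimal sketch of how the PPOv2 trainer added in this PR might be wired up. It is a hedged illustration only: the class and argument names (`PPOv2Config`, `PPOv2Trainer`, `policy`, `ref_policy`, `reward_model`, `value_model`), the base checkpoint, and the tokenized-prompt dataset handling are assumptions inferred from the commit messages and example scripts, not a copy of the documented API.

```python
# Hedged sketch only -- names below (PPOv2Config, PPOv2Trainer, their keyword arguments,
# and the base model) are assumptions; see docs/source/ppov2_trainer.md for the real API.
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from trl import PPOv2Config, PPOv2Trainer

base_model = "EleutherAI/pythia-1b-deduped"  # hypothetical choice of base checkpoint
tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLM.from_pretrained(base_model)
ref_policy = AutoModelForCausalLM.from_pretrained(base_model)
reward_model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=1)
value_model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=1)

# Assumption: the trainer consumes tokenized prompts (an `input_ids` column),
# as the example scripts added in this PR appear to do.
raw_dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style", split="train")
train_dataset = raw_dataset.map(
    lambda row: {"input_ids": tokenizer(row["prompt"])["input_ids"]},
    remove_columns=raw_dataset.column_names,
)

trainer = PPOv2Trainer(
    config=PPOv2Config(output_dir="ppov2_tldr"),
    tokenizer=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    value_model=value_model,  # PPO learns a separate value function
    train_dataset=train_dataset,
)
trainer.train()
```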

301 changes: 301 additions & 0 deletions docs/source/rloo_trainer.md

Large diffs are not rendered by default.
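Likewise for the RLOO trainer, a minimal sketch under the same assumptions as the PPOv2 example above. The leave-one-out baseline removes the need for a value model (compare the "remove the unnecessary `value_model`" commit); `rloo_k` is an assumed name for the number of online completions sampled per prompt.

```python
# Hedged sketch only -- RLOOConfig / RLOOTrainer names and arguments are assumptions;
# see docs/source/rloo_trainer.md for the real API.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import RLOOConfig, RLOOTrainer

base_model = "EleutherAI/pythia-1b-deduped"  # hypothetical choice, as in the PPOv2 sketch
tokenizer = AutoTokenizer.from_pretrained(base_model, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLM.from_pretrained(base_model)
ref_policy = AutoModelForCausalLM.from_pretrained(base_model)
reward_model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=1)

raw_dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style", split="train")
train_dataset = raw_dataset.map(
    lambda row: {"input_ids": tokenizer(row["prompt"])["input_ids"]},
    remove_columns=raw_dataset.column_names,
)

# REINFORCE leave-one-out: each of the k completions sampled per prompt is scored, and the
# mean reward of the other k-1 completions is its baseline, so no value model is needed.
config = RLOOConfig(output_dir="rloo_tldr", rloo_k=2)  # `rloo_k` is an assumed parameter name

trainer = RLOOTrainer(
    config=config,
    tokenizer=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=train_dataset,
)
trainer.train()
```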

1 change: 0 additions & 1 deletion examples/accelerate_configs/deepspeed_zero2.yaml
@@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
3 changes: 1 addition & 2 deletions examples/accelerate_configs/deepspeed_zero3.yaml
@@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
@@ -12,7 +11,7 @@ distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
79 changes: 76 additions & 3 deletions examples/datasets/tldr_preference.py
@@ -27,6 +27,9 @@ class ScriptArguments:
hf_repo_id: Optional[str] = field(
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
sft_hf_repo_id: Optional[str] = field(
default="tldr-preference-sft-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
@@ -39,7 +42,11 @@ class ScriptArguments:
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
full_sft_repo_id = f"{args.hf_entity}/{args.sft_hf_repo_id}"

################
# Preference dataset
################
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
if args.debug:
for key in ds:
@@ -92,22 +99,88 @@ def process(row):
repo_type="dataset",
)

sft_card = RepoCard.load(
preference_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
preference_card.text = f"""\
# TRL's TL;DR Preference Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Source of the dataset
We take the dataset from https://huggingface.co/datasets/openai/summarize_from_feedback.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
preference_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)

################
# SFT dataset
################
sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")
if args.debug:
for key in sft_ds:
sft_ds[key] = sft_ds[key].select(range(50))

def sft_process(row):
row["prompt"] = tldr_format_str.format(**row)
row["messages"] = [
{"role": "user", "content": row["prompt"]},
{"role": "assistant", "content": row["summary"]},
]
return row

sft_ds = sft_ds.map(
sft_process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
for key in sft_ds: # reorder columns
sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)

# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)

for revision in revisions:
sft_ds.push_to_hub(full_sft_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_sft_repo_id}/tree/{revision}"

# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_sft_repo_id,
repo_type="dataset",
)

sft_card = RepoCard.load(
full_sft_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's TL;DR SFT Dataset
We preprocess the dataset using our standard `prompt, messages` format.
## Source of the dataset
We take the dataset from https://huggingface.co/datasets/vwxyzjn/summarize_from_feedback_tldr_3_filtered.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
89 changes: 89 additions & 0 deletions examples/scripts/evals/generate_tldr.py
@@ -0,0 +1,89 @@
import shlex
import subprocess
import sys
from collections import defaultdict
from dataclasses import dataclass

import pandas as pd
from datasets import load_dataset
from gpt_tldr_judge import LLMJudgeConfig, llm_judge
from transformers import AutoTokenizer, HfArgumentParser
from vllm import SamplingParams, SingleGPULLM


"""
python -i examples/scripts/evals/generate_tldr.py \
--model_name_or_path vwxyzjn/rloo_tldr \
--output_path examples/scripts/minimal/evals/rloo_tldr.csv \
--n 1000
python -i examples/scripts/evals/generate_tldr.py \
--model_name_or_path vwxyzjn/ppo_tldr \
--output_path examples/scripts/minimal/evals/ppo_tldr.csv \
--n 1000
"""


@dataclass
class Args:
output_path: str
model_name_or_path: str
model_revision: str = "main"
n: int = 1000


def run_command(command: str):
command_list = shlex.split(command)
print(f"running {command}")
subprocess.run(command_list, stderr=sys.stderr, stdout=sys.stdout)


MAX_TOKENS = 200 # a very generous max token length
parser = HfArgumentParser(Args)
args = parser.parse_args_into_dataclasses()[0]
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
)
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
prompts = raw_datasets["test"]["prompt"]
if args.n is not None:
prompts = prompts[: args.n]
reference_summaries = [message[-1]["content"] for message in raw_datasets["test"]["messages"]]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=MAX_TOKENS)
llm = SingleGPULLM(
model=args.model_name_or_path,
revision=args.model_revision,
tensor_parallel_size=1,
device="cuda:0",
)
outputs = llm.generate(prompts, sampling_params)
table = defaultdict(list)

# Print the outputs.
for output, reference_response in zip(outputs, reference_summaries):
prompt = output.prompt
generated_text = output.outputs[0].text
table["prompt"].append(prompt)
table["model_response"].append(generated_text.strip()) # need `strip()` because of the leading space
table["model_response_len"].append(len(output.outputs[0].token_ids))
table["reference_response"].append(reference_response)
table["reference_response_len"].append(
len(tokenizer(f" {reference_response}")["input_ids"])
) # prepend leading space

df = pd.DataFrame(table)
df.to_csv(args.output_path)

#####
# GPT as a judge
####
df["response0"] = df["model_response"]
df["response1"] = df["reference_response"]
judged_df = llm_judge(
LLMJudgeConfig(
n=args.n,
model="gpt-3.5-turbo-0125",
),
df,
)
judged_df.to_csv(args.output_path.replace(".csv", "_judged.csv"))
141 changes: 141 additions & 0 deletions examples/scripts/evals/gpt_tldr_judge.py
@@ -0,0 +1,141 @@
# you can download the CSV from https://wandb.ai/costa-huang/tldr_summarize/runs/gb2dian5

import asyncio
import random
import time
from dataclasses import dataclass
from typing import Optional

import pandas as pd
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from transformers import HfArgumentParser


@dataclass
class LLMJudgeConfig:
n: int = 64
model: str = "gpt-3.5-turbo-0125"
max_parallel_requests: Optional[int] = None

def __post_init__(self):
if "gpt-3.5" in self.model:
# gpt-3.5 generates so fast that it will exceed the
# token limit per minute
self.max_parallel_requests = 11
elif "gpt-4" in self.model:
self.max_parallel_requests = 13


@dataclass
class Args:
csv: str = "trained_response.csv"
output_path: Optional[str] = None
num_trails: int = 1


TEMPLATE = r"""
Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence.
### Post:
{{post}}
### Summary A:
{{response0}}
### Summary B:
{{response1}}
### Instructions:
FIRST provide a one-sentence comparison of the two summaries, explaining which \
you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. Your response should use the format:
Comparison: <one-sentence comparison and explanation>
Preferred: <"A" or "B">
"""


def llm_judge(ljc: LLMJudgeConfig, df: pd.DataFrame):
limiter = asyncio.Semaphore(ljc.max_parallel_requests)
async_client = AsyncOpenAI()

async def process_text(post: str, response0: str, response1: str, i: int):
text = TEMPLATE.replace("{{post}}", post)
text = text.replace("{{response0}}", response0)
text = text.replace("{{response1}}", response1) # Ensure this split logic is correct for your data

async with limiter:
response = None
while response is None:
try:
response = await async_client.chat.completions.create(
model=ljc.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": text},
],
)
r = response.choices[0].message.content
except Exception as e:
print(f"error in {i}: {e}")
time.sleep(30) # deal with rate limit
continue

try:
comparison = r.split("Comparison:")[1].split("Preferred:")[0].strip()
preferred = r.split("Preferred:")[1].strip()
return comparison, preferred, i, text + r
except Exception as e:
print(f"error in {i} {e}")
return "", random.choice(["A", "B"]), i, text + r

async def main(ljc: LLMJudgeConfig, df: pd.DataFrame):
"""`df` should have columns: `prompt`, `response0`, `response1`"""
tasks = []
df["explanation"] = [None for _ in range(len(df))]
df["preferred"] = [None for _ in range(len(df))]
df["shuffled_index"] = [None for _ in range(len(df))]
df["entire_conversation"] = [None for _ in range(len(df))]
r = range(min(ljc.n, len(df)))
if ljc.n == -1:
r = range(len(df))
for i in r:
post = df["prompt"].iloc[i].strip()
# shuffle the index to avoid GPT-4's preference bias toward the order of the content
shuffled_index = random.randint(0, 1)
df.at[i, "shuffled_index"] = shuffled_index
responses = [
df["response0"].iloc[i].strip(),
df["response1"].iloc[i].strip(),
]
response0 = responses[shuffled_index]
response1 = responses[1 - shuffled_index]
task = asyncio.create_task(process_text(post, response0, response1, i))
tasks.append(task)

results = await tqdm_asyncio.gather(*tasks)

for _, (comparison, preferred, i, entire_conversation) in enumerate(results):
df.at[i, "explanation"] = comparison
df.at[i, "entire_conversation"] = entire_conversation
preferred_label = (
"response0"
if (df.at[i, "shuffled_index"] == 0 and preferred == "A")
or (df.at[i, "shuffled_index"] == 1 and preferred == "B")
else "response1"
)
df.at[i, "preferred"] = preferred_label
print(df["preferred"].value_counts())
return df

return asyncio.run(main(ljc, df))


if __name__ == "__main__":
args, ljc = HfArgumentParser((Args, LLMJudgeConfig)).parse_args_into_dataclasses()
df = pd.read_csv(args.csv)
df["reference_response"] = df["reference_response"].map(lambda x: x.split("<|endoftext|>")[0].strip())
df["prompt"] = df["query"].map(lambda x: x.strip())
df["response0"] = df["model_response"].map(lambda x: x.strip())
df["response1"] = df["reference_response"].map(lambda x: x.strip())
judge_df = llm_judge(ljc, df)
judge_df.to_csv(args.output_path)
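For completeness, a small sketch of calling `llm_judge` directly as a library function, following the docstring of `main` above: the DataFrame needs `prompt`, `response0`, and `response1` columns. The rows below are invented, and an `OPENAI_API_KEY` must be available in the environment for `AsyncOpenAI`.

```python
# Hedged usage sketch with made-up data; requires OPENAI_API_KEY to be set.
import pandas as pd

from gpt_tldr_judge import LLMJudgeConfig, llm_judge

df = pd.DataFrame(
    {
        "prompt": ["SUBREDDIT: r/...\n\nPOST: ...\n\nTL;DR:"],
        "response0": ["Model summary of the post."],
        "response1": ["Reference summary of the post."],
    }
)
judged = llm_judge(LLMJudgeConfig(n=1, model="gpt-3.5-turbo-0125"), df)
print(judged[["preferred", "explanation"]])
```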
