Commit 7075cec: Update HH dataset on helpful only subset (huggingface#1613)
* Update HH dataset on helpful only subset
* format
vwxyzjn authored May 2, 2024
1 parent: adf17a5 · commit: 7075cec
Showing 9 changed files with 18 additions and 12 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -69,7 +69,7 @@ trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir
 **DPO:**
 
 ```bash
-trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf
+trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf
 ```
 
 **Chat:**
commands/run_dpo.sh (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 # but defaults to QLoRA + PEFT
 OUTPUT_DIR="test_dpo/"
 MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
-DATASET_NAME="trl-internal-testing/hh-rlhf-trl-style"
+DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
 MAX_STEPS=5
 BATCH_SIZE=2
 SEQ_LEN=128
docs/source/clis.mdx (4 changes: 2 additions & 2 deletions)

@@ -65,7 +65,7 @@ The SFT CLI is based on the `examples/scripts/sft.py` script.
 
 To use the DPO CLI, you need to have a dataset in the TRL format such as
 
-* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style
+* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
 * TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
 
 These datasets always have at least three columns `prompt, chosen, rejected`:
@@ -78,7 +78,7 @@ These datasets always have at least three columns `prompt, chosen, rejected`:
 To do a quick start, you can run the following command:
 
 ```bash
-trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
+trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
 ```
 
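Side note on the docs change above: a minimal sketch (not part of this diff) for checking that the renamed dataset still exposes the columns the DPO CLI expects. Only the three-column contract quoted in the docs is assumed here; any further schema details come from the dataset card.

```python
# Minimal sketch: confirm the renamed dataset keeps the columns the DPO
# CLI expects. Only the "prompt, chosen, rejected" contract stated in
# the docs above is assumed.
from datasets import load_dataset

ds = load_dataset("trl-internal-testing/hh-rlhf-helpful-base-trl-style", split="train")
print(ds.column_names)  # should include at least "prompt", "chosen", "rejected"
```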
examples/datasets/anthropic_hh.py (6 changes: 4 additions & 2 deletions)

@@ -24,7 +24,9 @@
 class ScriptArguments:
     debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
     hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
-    hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"})
+    hf_repo_id: Optional[str] = field(
+        default="hh-rlhf-helpful-base-trl-style", metadata={"help": "The Hugging Face repository ID"}
+    )
     revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
     update_main_revision: Optional[bool] = field(
         default=True, metadata={"help": "Update the main revision of the repository"}
@@ -64,7 +66,7 @@ def extract_dialogue(input_text):
     if args.hf_entity is None:
         args.hf_entity = api.whoami()["name"]
     full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
-    ds = load_dataset("Anthropic/hh-rlhf")
+    ds = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
     if args.debug:
         for key in ds:
             ds[key] = ds[key].select(range(50))
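For context, a minimal sketch (not part of this diff) of what the new `data_dir` argument selects, assuming the upstream `Anthropic/hh-rlhf` layout in which each subset (helpful-base, harmless-base, ...) lives in its own directory:

```python
# Minimal sketch: load only the helpful-base subset of Anthropic/hh-rlhf,
# as the updated script now does. Split names and the raw string
# "chosen"/"rejected" columns follow the upstream dataset card.
from datasets import load_dataset

ds = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
print(ds)              # DatasetDict with train/test splits for this subset
print(ds["train"][0])  # {"chosen": "...", "rejected": "..."} raw dialogues
```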
examples/datasets/tokenize_ds.py (4 changes: 3 additions & 1 deletion)

@@ -15,7 +15,9 @@
 @dataclass
 class ScriptArguments:
     debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
-    dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"})
+    dataset: str = field(
+        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
+    )
     model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
 
examples/scripts/cpo.py (3 changes: 2 additions & 1 deletion)

@@ -64,7 +64,8 @@
 @dataclass
 class ScriptArguments:
     dataset: str = field(
-        default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
+        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
+        metadata={"help": "The name of the dataset to use."},
     )
 
examples/scripts/dpo.py (4 changes: 2 additions & 2 deletions)

@@ -15,7 +15,7 @@
 """
 # regular:
 python examples/scripts/dpo.py \
-    --dataset_name=trl-internal-testing/hh-rlhf-trl-style \
+    --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
     --model_name_or_path=gpt2 \
     --per_device_train_batch_size 4 \
     --learning_rate 1e-3 \
@@ -31,7 +31,7 @@
 # peft:
 python examples/scripts/dpo.py \
-    --dataset_name=trl-internal-testing/hh-rlhf-trl-style \
+    --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
     --model_name_or_path=gpt2 \
     --per_device_train_batch_size 4 \
     --learning_rate 1e-3 \
examples/scripts/orpo.py (3 changes: 2 additions & 1 deletion)

@@ -64,7 +64,8 @@
 @dataclass
 class ScriptArguments:
     dataset: str = field(
-        default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
+        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
+        metadata={"help": "The name of the dataset to use."},
     )
 
tests/test_cli.py (2 changes: 1 addition & 1 deletion)

@@ -32,7 +32,7 @@ def test_sft_cli():
 def test_dpo_cli():
     try:
         subprocess.run(
-            "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
+            "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
             shell=True,
             check=True,
         )
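To run just this updated test locally, a standard pytest invocation should work (a sketch, assuming an editable install of TRL from the repository root):

```bash
# Sketch, assuming the repo root and an editable TRL install.
pip install -e .
pytest tests/test_cli.py -k test_dpo_cli
```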
