diff --git a/README.md b/README.md index 89ec20eff8..4074912498 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir **DPO:** ```bash -trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf +trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf ``` **Chat:** diff --git a/commands/run_dpo.sh b/commands/run_dpo.sh index f9f3ab507f..c265a294e3 100644 --- a/commands/run_dpo.sh +++ b/commands/run_dpo.sh @@ -3,7 +3,7 @@ # but defaults to QLoRA + PEFT OUTPUT_DIR="test_dpo/" MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM" -DATASET_NAME="trl-internal-testing/hh-rlhf-trl-style" +DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style" MAX_STEPS=5 BATCH_SIZE=2 SEQ_LEN=128 diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index a3e818867d..85d9aef55b 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -65,7 +65,7 @@ The SFT CLI is based on the `examples/scripts/sft.py` script. To use the DPO CLI, you need to have a dataset in the TRL format such as -* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style +* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style * TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style These datasets always have at least three columns `prompt, chosen, rejected`: @@ -78,7 +78,7 @@ These datasets always have at least three columns `prompt, chosen, rejected`: To do a quick start, you can run the following command: ```bash -trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style +trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style ``` diff --git a/examples/datasets/anthropic_hh.py b/examples/datasets/anthropic_hh.py index 32274a2992..e021fa9ece 100644 --- a/examples/datasets/anthropic_hh.py +++ b/examples/datasets/anthropic_hh.py @@ -24,7 +24,9 @@ class ScriptArguments: debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"}) hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"}) - hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"}) + hf_repo_id: Optional[str] = field( + default="hh-rlhf-helpful-base-trl-style", metadata={"help": "The Hugging Face repository ID"} + ) revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"}) update_main_revision: Optional[bool] = field( default=True, metadata={"help": "Update the main revision of the repository"} @@ -64,7 +66,7 @@ def extract_dialogue(input_text): if args.hf_entity is None: args.hf_entity = api.whoami()["name"] full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - ds = load_dataset("Anthropic/hh-rlhf") + ds = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base") if args.debug: for key in ds: ds[key] = ds[key].select(range(50)) diff --git a/examples/datasets/tokenize_ds.py b/examples/datasets/tokenize_ds.py index 227af57f2e..0ba1876368 100644 --- a/examples/datasets/tokenize_ds.py +++ b/examples/datasets/tokenize_ds.py @@ -15,7 +15,9 @@ @dataclass class ScriptArguments: debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"}) - dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"}) + dataset: str = field( + default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"} + ) model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"}) diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 8e3e386405..77e0020e5f 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -64,7 +64,8 @@ @dataclass class ScriptArguments: dataset: str = field( - default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."} + default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", + metadata={"help": "The name of the dataset to use."}, ) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index d742988736..62df16f06d 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -15,7 +15,7 @@ """ # regular: python examples/scripts/dpo.py \ - --dataset_name=trl-internal-testing/hh-rlhf-trl-style \ + --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --learning_rate 1e-3 \ @@ -31,7 +31,7 @@ # peft: python examples/scripts/dpo.py \ - --dataset_name=trl-internal-testing/hh-rlhf-trl-style \ + --dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \ --model_name_or_path=gpt2 \ --per_device_train_batch_size 4 \ --learning_rate 1e-3 \ diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index de2c77e417..a5c3381847 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -64,7 +64,8 @@ @dataclass class ScriptArguments: dataset: str = field( - default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."} + default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", + metadata={"help": "The name of the dataset to use."}, ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4e331cfcf3..4fecef7d0a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -32,7 +32,7 @@ def test_sft_cli(): def test_dpo_cli(): try: subprocess.run( - "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check", + "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check", shell=True, check=True, )