🐾 Process-supervised RM Trainer #2127

Merged
merged 140 commits into from
Dec 13, 2024
Changes from 136 commits
140 commits
357a8c6
initial skeleton
gaetanlop Sep 26, 2024
841f7a1
tokenize fn
gaetanlop Sep 26, 2024
641e899
adding bos and eos to tokenization fn
gaetanlop Sep 26, 2024
106bc0e
prmtrainer
gaetanlop Sep 27, 2024
0163dcc
fixing small typo in tokenize
gaetanlop Sep 27, 2024
c2720d7
typo in input_ids and labels construction
gaetanlop Sep 27, 2024
5034083
numpy dimension
gaetanlop Sep 27, 2024
8818b6a
introduce the stepwise reward trainer
gaetanlop Sep 28, 2024
b777d1c
update markdown files
gaetanlop Sep 28, 2024
afa9e0a
let user decide post step separator in config
gaetanlop Sep 28, 2024
2dd752d
doc post_step_separator
gaetanlop Sep 28, 2024
613d838
do not add post step_tokens to last step of the reasoning process
gaetanlop Sep 28, 2024
b96ef4d
renaming prm to stepwisereward
gaetanlop Sep 28, 2024
161f5de
formatting
gaetanlop Sep 28, 2024
93e6652
fix tokenize kwargs
gaetanlop Sep 28, 2024
3ec4ebe
adapt test to the new post_token args
gaetanlop Sep 28, 2024
1461a61
adding example script
gaetanlop Sep 28, 2024
8c4ac31
fix small typo
gaetanlop Sep 28, 2024
8b3fa52
add create_model_card and renaming
gaetanlop Oct 1, 2024
8e4e159
fixing booleans
gaetanlop Oct 1, 2024
c60bc40
Adding the new stepwise_preference instead of placeholders for datasets
gaetanlop Oct 1, 2024
614fb4e
formatting
gaetanlop Oct 1, 2024
c582464
Merge branch 'main' into prmtrainer
qgallouedec Oct 1, 2024
424af34
Merge branch 'main' into prmtrainer
kashif Oct 8, 2024
b00e32b
Update docs/source/_toctree.yml
gaetanlop Oct 12, 2024
d5f780a
Update examples/scripts/stepwise_reward_modeling.py
gaetanlop Oct 12, 2024
f02056a
Update trl/trainer/stepwise_reward_trainer.py
gaetanlop Oct 12, 2024
3ac323f
Update trl/trainer/stepwise_reward_trainer.py
gaetanlop Oct 12, 2024
436dfd7
update push to hub
gaetanlop Oct 12, 2024
f4e6d4e
step_separator can't be None
gaetanlop Oct 12, 2024
6947aef
Merge branch 'main' into prmtrainer
gaetanlop Oct 12, 2024
e0c0648
fix suggested typos
gaetanlop Oct 12, 2024
35de0ee
add citation
gaetanlop Oct 12, 2024
c3eb08e
reformat doc
gaetanlop Oct 12, 2024
898f621
reordering init
gaetanlop Oct 13, 2024
3a488e0
push to hub prm800k
gaetanlop Oct 13, 2024
a03aed8
changing dataset in example
gaetanlop Oct 13, 2024
e77eee2
change dataset format to align with the sky is blue example
gaetanlop Oct 13, 2024
6c62c69
Merge branch 'main' into prmtrainer
gaetanlop Oct 13, 2024
e8e93f1
fix tokenization column names
gaetanlop Oct 13, 2024
2059c51
fix num labels in openai example
gaetanlop Oct 13, 2024
701241b
add support for conversational dataset
gaetanlop Oct 13, 2024
6bb467b
remove training whitespace
gaetanlop Oct 13, 2024
6b2bd97
Merge branch 'main' into prmtrainer
gaetanlop Oct 14, 2024
2030a83
replace tokenizer with processing class
gaetanlop Oct 14, 2024
66baada
Merge branch 'prmtrainer' of https://github.com/gaetanlop/trl into pr…
gaetanlop Oct 14, 2024
b47eea5
Merge branch 'main' into prmtrainer
qgallouedec Nov 18, 2024
9b1693d
Merge branch 'main' into prmtrainer
gaetanlop Nov 24, 2024
086ea8f
Update docs/source/dataset_formats.mdx
gaetanlop Nov 24, 2024
fe440de
remove openai_prm800k
gaetanlop Nov 24, 2024
468502b
Update trl/trainer/stepwise_reward_trainer.py
gaetanlop Nov 24, 2024
d205064
Update trl/trainer/stepwise_reward_trainer.py
gaetanlop Nov 24, 2024
6128a7f
Merge branch 'prmtrainer' of https://github.com/gaetanlop/trl into pr…
gaetanlop Nov 24, 2024
faf1051
Update docs/source/stepwise_reward_trainer.mdx
gaetanlop Nov 24, 2024
dfe7e04
Update docs/source/stepwise_reward_trainer.mdx
gaetanlop Nov 24, 2024
fc702be
renaming
gaetanlop Nov 24, 2024
a65e30c
renaming
gaetanlop Nov 24, 2024
d53ad35
minor renamings in docs
gaetanlop Nov 24, 2024
24d2f1a
using prm800k instead of openai_prm800k
gaetanlop Nov 24, 2024
4fd282e
update num labels to 2 following the new format
gaetanlop Nov 24, 2024
2c9d2f3
changing doc examples to math examples
gaetanlop Nov 24, 2024
91a3de8
change reference to dataset_formats.mdx
gaetanlop Nov 24, 2024
97ef925
changing dataset config in test
gaetanlop Nov 24, 2024
754ba44
remove conversational dataset support
gaetanlop Nov 25, 2024
a7bac4e
remove conv dataset support
gaetanlop Nov 25, 2024
916f87e
fix bos token
gaetanlop Nov 25, 2024
364d7d8
fix scriptarguments in example
gaetanlop Nov 25, 2024
5a6970d
completion to completions
gaetanlop Nov 25, 2024
e445bad
remove valuerror for step_separator inside steps
gaetanlop Nov 25, 2024
fb15691
run precommit
gaetanlop Nov 25, 2024
1c76266
Merge branch 'main' into prmtrainer
gaetanlop Nov 25, 2024
9ae131a
Merge branch 'main' into prmtrainer
gaetanlop Nov 26, 2024
84c28fe
remove conv dataset support
gaetanlop Nov 26, 2024
16e4ef8
renaming zen dataset
gaetanlop Nov 26, 2024
147c375
remove unused printing
gaetanlop Nov 26, 2024
e310b0e
unknown label column
gaetanlop Nov 26, 2024
59f1e9f
introduce the train on last step arg
gaetanlop Nov 26, 2024
b057cf7
_tokenize support train_on_last_step
gaetanlop Nov 26, 2024
3a034d0
incorporate train_on_last_step to tests
gaetanlop Nov 26, 2024
8dce558
formatting
gaetanlop Nov 26, 2024
69adb5c
remove comments in trainer
gaetanlop Nov 26, 2024
be6e843
Refactor `tokenize_row`
qgallouedec Nov 26, 2024
e8c782d
Update max_completion_length parameter in StepwiseRewardConfig
qgallouedec Nov 26, 2024
4c83f41
Collator
qgallouedec Nov 26, 2024
a93138f
Update comment
qgallouedec Nov 26, 2024
072794a
Update type hint
qgallouedec Nov 26, 2024
5b10e38
fix table
qgallouedec Nov 26, 2024
5a8d0a2
Remove collator
qgallouedec Nov 26, 2024
f4ba54f
don't need pad token id
qgallouedec Nov 26, 2024
fd204d7
add error back
qgallouedec Nov 26, 2024
ebc8fb1
max length args
qgallouedec Nov 26, 2024
95a4a46
use tokenizer arg
qgallouedec Nov 26, 2024
46b6bd6
Update doc
qgallouedec Nov 26, 2024
201bdf2
label -> labels
qgallouedec Nov 26, 2024
4f28ed7
Merge pull request #1 from huggingface/prm-trainer-qgallouedec
gaetanlop Nov 27, 2024
0527531
Merge branch 'main' into prmtrainer
gaetanlop Nov 27, 2024
228aa31
fixing tokenization issues in tokenize row
gaetanlop Nov 27, 2024
aa33e62
correct labels for token classification
gaetanlop Nov 27, 2024
4cd0b79
adding max_length to tokenize_row
gaetanlop Nov 27, 2024
c58db4b
reformat tests
gaetanlop Nov 27, 2024
1385f46
adding tests for tokenize row
gaetanlop Nov 27, 2024
b2d45a8
fixing typos in comments
gaetanlop Nov 27, 2024
3d7d37d
update doc
gaetanlop Nov 28, 2024
ad3bd25
Add math_shepherd.py script for dataset processing
qgallouedec Nov 28, 2024
1cc6c8a
split the dataset
qgallouedec Nov 28, 2024
7273a3b
Merge pull request #2 from huggingface/prm-trainer-qgallouedec-2
gaetanlop Nov 29, 2024
b4e676b
Merge branch 'main' into prmtrainer
gaetanlop Nov 29, 2024
150500f
Merge branch 'main' into prmtrainer
qgallouedec Nov 29, 2024
30bb2c3
Merge branch 'main' into prmtrainer
gaetanlop Dec 1, 2024
32bb0b1
formatting
gaetanlop Dec 1, 2024
dec7bad
same evaluation method for the two training methods
gaetanlop Dec 2, 2024
e4fc400
adding filtering to example script
gaetanlop Dec 2, 2024
4ff8674
formatting
gaetanlop Dec 2, 2024
7787b98
Merge branch 'main' into prmtrainer
gaetanlop Dec 3, 2024
0d81c04
Merge branch 'main' into prmtrainer
qgallouedec Dec 9, 2024
049fdf9
Add features to avoid casting labels to bool in dataset tokenization
qgallouedec Dec 9, 2024
62b7465
Update docs/source/stepwise_reward_trainer.mdx [ci skip]
qgallouedec Dec 9, 2024
b62d74b
Add learning_rate parameter to StepwiseRewardConfig class
qgallouedec Dec 9, 2024
8d6a879
update doc
qgallouedec Dec 9, 2024
7da024c
Remove unused setup_chat_format function
qgallouedec Dec 9, 2024
c1f83ea
Fix warning message in stepwise_reward_modeling.py
qgallouedec Dec 9, 2024
a2d5837
Update logging steps in stepwise_reward_trainer.mdx
qgallouedec Dec 9, 2024
7146aff
little doc change [ci skip]
qgallouedec Dec 9, 2024
92be608
Merge branch 'main' into prmtrainer
qgallouedec Dec 10, 2024
ae677b1
Fix copyrights
qgallouedec Dec 10, 2024
7b88981
fix space after copyrights
qgallouedec Dec 10, 2024
c4faf19
Merge branch 'main' into prmtrainer
qgallouedec Dec 10, 2024
f164711
Update dataset loading in stepwise_reward_modeling.py
qgallouedec Dec 10, 2024
4572a21
refine compute_accuracy and proper test
qgallouedec Dec 10, 2024
75b50af
fix tests
qgallouedec Dec 10, 2024
2ebf9da
style
qgallouedec Dec 10, 2024
83e174e
Merge branch 'main' into prmtrainer
qgallouedec Dec 10, 2024
0d48cfa
Merge branch 'main' into prmtrainer
gaetanlop Dec 13, 2024
c4f6a62
renamings
gaetanlop Dec 13, 2024
81574f5
renaming in init
gaetanlop Dec 13, 2024
823825d
doc renaming
gaetanlop Dec 13, 2024
68e16f5
fix sorting and tag
qgallouedec Dec 13, 2024
9609ac8
experiemental [ci skip]
qgallouedec Dec 13, 2024
54011c9
trigger CI
qgallouedec Dec 13, 2024
686edfb
other doc fix
qgallouedec Dec 13, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -44,6 +44,8 @@
    title: PPO
  - local: reward_trainer
    title: Reward
  - local: prm_trainer
    title: PRM
  - local: rloo_trainer
    title: RLOO
  - local: sft_trainer
31 changes: 16 additions & 15 deletions docs/source/dataset_formats.mdx
@@ -254,21 +254,22 @@ stepwise_example = {

Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.

| Trainer | Expected dataset type |
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
| Trainer | Expected dataset type |
| ------------------------- | ------------------------------------------------------------------------------------------------------ |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |

<Tip>

117 changes: 117 additions & 0 deletions docs/source/prm_trainer.mdx
@@ -0,0 +1,117 @@
# PRM

## Overview

Process reward models were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.

The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).


## Quick start

This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model:

```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_prm.py
```

Distributed across 8 GPUs, the training takes approximately 1 hour.

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.


```python
from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
    "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
    "completions": [
        "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
        "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
        "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
    ],
    "labels": [True, False, False],
}


separator = "\n" # It's important to use the same separator as the one used during training

for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator between the prompt and each step
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```

```
Step 1 Predicted: True Label: True
Step 2 Predicted: False Label: False
Step 3 Predicted: False Label: False
```

It's a win!

## Expected dataset type

Process-supervised reward modeling requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset.
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.

The [`PRMTrainer`] only supports the [standard](dataset_formats#standard) dataset format.
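
For reference, a single training example could look like the following minimal sketch (the prompt and steps are illustrative, not taken from an actual dataset):

```python
# A minimal stepwise supervision example: each boolean in "labels"
# marks whether the corresponding entry in "completions" is a correct step.
stepwise_example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, 9.11 is the larger number.",
    ],
    "labels": [True, False],
}
```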

## Example script

We provide an example script to train a model using the process-supervised reward modeling method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py).

To use the process-supervised reward modeling script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:

```bash
accelerate launch examples/scripts/prm.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/math_shepherd \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-Reward-Math-Sheperd
```

## PRMTrainer

[[autodoc]] PRMTrainer

## PRMConfig

[[autodoc]] PRMConfig
131 changes: 131 additions & 0 deletions examples/datasets/math_shepherd.py
@@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from itertools import chain
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/math_shepherd"
    dataset_num_proc: Optional[int] = None


def process_example(example):
    # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")

    # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]

    # Sanity check that the label at each of these indexes is either "+" or "-"
    assert all(example["label"][idx] in ["+", "-"] for idx in indexes)

    # Get the labels
    labels = [example["label"][idx] == "+" for idx in indexes]

    # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]

    # Remove the last step (single ⶻ)
    steps = steps[:-1]

    # Get the prompt (first part) and completions (rest)
    prompt = steps[0]
    completions = steps[1:]

    # Remove the heading "ⶻ" and the final whitespace from the completions
    assert all(completion.startswith("ⶻ") for completion in completions)
    completions = [completion[1:].strip() for completion in completions]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # We spotted some dataset errors where there is an annotation in the prompt: we remove it
        labels = labels[1:]

    # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
    # (less common) "?".
    elif "Step 1:" in prompt:
        prompt, first_step = prompt.split("Step 1:")
        first_step = "Step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "step 1:" in prompt:
        prompt, first_step = prompt.split("step 1:")
        first_step = "step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "?" in prompt:
        prompt, first_step = prompt.split("?")
        prompt = prompt + "?"
        completions = [first_step.strip()] + completions
    else:
        raise ValueError(f"Prompt can't be processed: {prompt}")

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    assert len(completions) == len(labels)

    return {"prompt": prompt, "completions": completions, "labels": labels}


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")

    dataset = dataset.map(
        process_example,
        remove_columns=["input", "label", "task"],
        num_proc=script_args.dataset_num_proc,
    )
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
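
As a usage sketch, the script can be invoked with the command-line flags that `HfArgumentParser` generates from `ScriptArguments`; the target repository ID below is illustrative:

```bash
# Rebuild the processed dataset and push it to your own Hub namespace (illustrative repo ID)
python examples/datasets/math_shepherd.py --push_to_hub --repo_id your-username/math_shepherd
```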