
Commit 8e2d551

Add accuracy reward (#4270)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Parent commit: 94aac4a

File tree: 15 files changed (+189 −316 lines)

docs/source/rewards.md

Lines changed: 5 additions & 5 deletions
@@ -2,14 +2,14 @@
 
 This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
 
-## Format rewards
+## accuracy_reward
 
-### think_format_reward
+[[autodoc]] rewards.accuracy_reward
 
-[[autodoc]] rewards.think_format_reward
+## think_format_reward
 
-## Other rewards
+[[autodoc]] rewards.think_format_reward
 
-### get_soft_overlong_punishment
+## get_soft_overlong_punishment
 
 [[autodoc]] rewards.get_soft_overlong_punishment
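The docs now expose the upstreamed reward alongside the existing format rewards. As a rough illustration of calling it directly, here is a minimal sketch, assuming the public `accuracy_reward` keeps the same signature as the inline version removed from the example scripts below (conversational completions plus a `solution` column, one reward per completion) and that the optional math-verify dependency is installed:

```python
from trl.rewards import accuracy_reward

# Two assistant completions and the same ground-truth solution for both.
completions = [
    [{"role": "assistant", "content": r"The answer is \boxed{\frac{1}{2}}."}],
    [{"role": "assistant", "content": r"The answer is \boxed{0.75}."}],
]
solution = [r"\frac{1}{2}", r"\frac{1}{2}"]

# First completion matches the gold answer, the second does not.
print(accuracy_reward(completions, solution=solution))  # expected: [1.0, 0.0]
```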

examples/scripts/grpo_vlm.py

Lines changed: 1 addition & 51 deletions
@@ -70,8 +70,6 @@
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -83,7 +81,7 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -149,54 +147,6 @@ def convert_to_rgb(example):
 train_dataset = dataset["train"]
 eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
-################
-# Reward Function for Training
-################
-def accuracy_reward(completions, solution: list[str], **kwargs):
-    """Reward function that checks if the completion matches the ground truth.
-    - If both gold and prediction are parseable → use math verification.
-    - If not parseable → compare as normalized text.
-    """
-    rewards = []
-    contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
-        try:
-            gold_parsed = parse(sol, extraction_mode="first_match")
-        except Exception:
-            gold_parsed = []
-
-        if len(gold_parsed) != 0:
-            # Try parsing predicted answer too
-            try:
-                answer_parsed = parse(
-                    content,
-                    extraction_config=[
-                        LatexExtractionConfig(
-                            normalization_config=NormalizationConfig(
-                                nits=False,
-                                malformed_operators=False,
-                                basic_latex=True,
-                                boxed="all",
-                                units=True,
-                            ),
-                            boxed_match_priority=0,
-                            try_extract_without_anchor=False,
-                        )
-                    ],
-                    extraction_mode="first_match",
-                )
-                reward = float(verify(gold_parsed, answer_parsed))
-            except Exception as e:
-                print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                reward = None
-        else:
-            # fallback to text match
-            reward = float(content.strip().lower() == sol.strip().lower())
-
-        rewards.append(reward)
-
-    return rewards
-
 ################
 # Training
 ################
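Because the function keeps its name, the trainer wiring in these scripts stays the same; only the import moves to `trl.rewards`. For orientation, a hedged sketch of that wiring, simplified to text-only (the model id and toy dataset are placeholders, not taken from the commit; the real script trains a vision-language model on its own dataset):

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward, think_format_reward

# Toy dataset with the two columns the rewards expect: a conversational
# "prompt" and a ground-truth "solution" string.
train_dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": r"What is 1/2 + 1/4? Put the result in \boxed{}."}]],
    "solution": [r"\frac{3}{4}"],
})

training_args = GRPOConfig(output_dir="grpo-accuracy-demo")
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    reward_funcs=[think_format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```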

examples/scripts/gspo.py

Lines changed: 1 addition & 51 deletions
@@ -57,8 +57,6 @@
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -120,54 +118,6 @@ def make_conversation(example):
 train_dataset = train_dataset.remove_columns(["messages", "problem"])
 eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
 
[removed: the same inline "Reward Function for Training" block (the local accuracy_reward definition) shown in full in the grpo_vlm.py diff above]
 ################
 # Training
 ################
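The block deleted from these scripts relied on math-verify for equivalence checking; that behavior now lives behind `trl.rewards.accuracy_reward`. As a rough illustration of the underlying check, a tiny sketch using only the `parse` and `verify` calls that appeared in the removed code (the example strings are illustrative, and the printed result is what the equivalence check should give):

```python
from math_verify import parse, verify

gold = parse(r"\frac{1}{2}")                # parse the ground-truth expression
pred = parse(r"The answer is \boxed{0.5}")  # extract the boxed answer from a completion

print(verify(gold, pred))  # expected: True, since 0.5 and 1/2 are mathematically equivalent
```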

examples/scripts/gspo_vlm.py

Lines changed: 1 addition & 51 deletions
@@ -57,8 +57,6 @@
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -136,54 +134,6 @@ def convert_to_rgb(example):
 train_dataset = dataset["train"]
 eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
[removed: the same inline accuracy_reward block, identical to the one shown in full in the grpo_vlm.py diff above]
 ################
 # Training
 ################

examples/scripts/online_dpo_vlm.py

Lines changed: 1 addition & 51 deletions
@@ -87,8 +87,6 @@
 import torch
 import transformers
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from transformers import AutoConfig, AutoProcessor, GenerationConfig
 
 from trl import (
@@ -102,7 +100,7 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -192,54 +190,6 @@ def convert_to_rgb(example):
 train_dataset = dataset["train"]
 eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
[removed: the same inline accuracy_reward block as in grpo_vlm.py above; its header comment here reads "Reward Function for Training (same as GRPO VLM)"]
 ################
 # Training
 ################

examples/scripts/rloo.py

Lines changed: 1 addition & 49 deletions
@@ -33,12 +33,10 @@
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from peft import LoraConfig
 
 from trl import RLOOConfig, RLOOTrainer
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -67,52 +65,6 @@ def make_conversation(example):
 train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
 eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
 
[removed: the same inline accuracy_reward definition shown in full in the grpo_vlm.py diff above, introduced here by a plain "# Reward function for training" comment]
 # Training
 training_args = RLOOConfig(
     output_dir="Qwen3-0.6B-RLOO",
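The RLOO script follows the same pattern: the local helper disappears and the shared implementation is imported instead. A minimal sketch of the resulting trainer call, assuming the refactored `RLOOTrainer` accepts `reward_funcs` the same way `GRPOTrainer` does (the base model id is an assumption inferred from the `output_dir` above):

```python
from trl import RLOOConfig, RLOOTrainer
from trl.rewards import accuracy_reward, think_format_reward

training_args = RLOOConfig(output_dir="Qwen3-0.6B-RLOO")
trainer = RLOOTrainer(
    model="Qwen/Qwen3-0.6B",  # assumed base checkpoint, not stated in this diff
    reward_funcs=[think_format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,  # built via make_conversation, as in the diff context above
)
trainer.train()
```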
