33 | 33 |
34 | 34 | import torch |
35 | 35 | from datasets import load_dataset |
36 | | -from latex2sympy2_extended import NormalizationConfig |
37 | | -from math_verify import LatexExtractionConfig, parse, verify |
38 | 36 | from peft import LoraConfig |
39 | 37 |
40 | 38 | from trl import RLOOConfig, RLOOTrainer |
41 | | -from trl.rewards import think_format_reward |
| 39 | +from trl.rewards import accuracy_reward, think_format_reward |
42 | 40 |
43 | 41 |
44 | 42 | # Enable logging in a Hugging Face Space |
@@ -67,52 +65,6 @@ def make_conversation(example): |
67 | 65 | train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"]) |
68 | 66 | eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"]) |
69 | 67 |
70 | | - # Reward function for training |
71 | | - def accuracy_reward(completions, solution: list[str], **kwargs): |
72 | | - """Reward function that checks if the completion matches the ground truth. |
73 | | - - If both gold and prediction are parseable → use math verification. |
74 | | - - If not parseable → compare as normalized text. |
75 | | - """ |
76 | | - rewards = [] |
77 | | - contents = [completion[0]["content"] for completion in completions] |
78 | | - for content, sol in zip(contents, solution): |
79 | | - try: |
80 | | - gold_parsed = parse(sol, extraction_mode="first_match") |
81 | | - except Exception: |
82 | | - gold_parsed = [] |
83 | | - |
84 | | - if len(gold_parsed) != 0: |
85 | | - # Try parsing predicted answer too |
86 | | - try: |
87 | | - answer_parsed = parse( |
88 | | - content, |
89 | | - extraction_config=[ |
90 | | - LatexExtractionConfig( |
91 | | - normalization_config=NormalizationConfig( |
92 | | - nits=False, |
93 | | - malformed_operators=False, |
94 | | - basic_latex=True, |
95 | | - boxed="all", |
96 | | - units=True, |
97 | | - ), |
98 | | - boxed_match_priority=0, |
99 | | - try_extract_without_anchor=False, |
100 | | - ) |
101 | | - ], |
102 | | - extraction_mode="first_match", |
103 | | - ) |
104 | | - reward = float(verify(gold_parsed, answer_parsed)) |
105 | | - except Exception as e: |
106 | | - print(f"verify failed: {e}, answer: {content}, gold: {sol}") |
107 | | - reward = None |
108 | | - else: |
109 | | - # fallback to text match |
110 | | - reward = float(content.strip().lower() == sol.strip().lower()) |
111 | | - |
112 | | - rewards.append(reward) |
113 | | - |
114 | | - return rewards |
115 | | - |
116 | 68 | # Training |
117 | 69 | training_args = RLOOConfig( |
118 | 70 | output_dir="Qwen3-0.6B-RLOO", |
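
The diff above swaps the script's hand-rolled accuracy_reward for the implementation shipped in trl.rewards, so only the import changes. A minimal sketch of the drop-in usage follows, assuming trl.rewards.accuracy_reward keeps the (completions, solution, **kwargs) signature of the deleted helper and that RLOOTrainer accepts a reward_funcs list as the script implies; the toy completion, one-row dataset, and model id below are illustrative placeholders, not part of the commit.

from datasets import Dataset
from trl import RLOOConfig, RLOOTrainer
from trl.rewards import accuracy_reward, think_format_reward

# Sanity-check the imported rewards on a toy conversational completion.
# accuracy_reward relies on math-verify under the hood, like the deleted helper.
completions = [[{"role": "assistant", "content": r"<think>2 + 2 = 4</think>\boxed{4}"}]]
print(accuracy_reward(completions, solution=[r"\boxed{4}"]))  # expected: [1.0]
print(think_format_reward(completions))                       # expected: [1.0]

# Wire both rewards into the trainer; extra dataset columns such as
# "solution" are forwarded to the reward functions as keyword arguments.
train_dataset = Dataset.from_list(
    [{"prompt": [{"role": "user", "content": "What is 2 + 2?"}], "solution": r"\boxed{4}"}]
)
trainer = RLOOTrainer(
    model="Qwen/Qwen3-0.6B",  # assumed from output_dir="Qwen3-0.6B-RLOO"
    reward_funcs=[think_format_reward, accuracy_reward],
    args=RLOOConfig(output_dir="Qwen3-0.6B-RLOO"),
    train_dataset=train_dataset,
)
trainer.train()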