Closed

Changes from 1 commit

Commits (39)
591f3c1  wip adding HTGen dataset and benchmark (Feb 20, 2025)
69b44ca  add API test (Feb 20, 2025)
b3a9587  wip adding HTGen dataset and benchmark (Feb 20, 2025)
41af428  add API test (Feb 20, 2025)
58cbde4  construct prompt in the dataset generator (Feb 22, 2025)
aa3523d  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 22, 2025)
cde339c  merge from upstream (Feb 22, 2025)
2d225e9  prompt construction (Feb 22, 2025)
d28e8f7  fix some typos and add more docstrings (Feb 22, 2025)
40fdef8  add reward (Feb 22, 2025)
4341395  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 23, 2025)
3eae18c  fix typos (Feb 23, 2025)
ce8d1b1  Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i… (Feb 23, 2025)
5e2ea33  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 25, 2025)
8535700  Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i… (Feb 25, 2025)
410b4f9  add unit test for code rewards (Feb 25, 2025)
465ae8c  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 25, 2025)
fbd20c7  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 26, 2025)
c567d55  Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 26, 2025)
d63b4a2  docstring (Feb 28, 2025)
0c9732c  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 2, 2025)
c84f645  fix makefile and tests (ocramz, Mar 2, 2025)
ff0db1e  fix code rewards test (ocramz, Mar 2, 2025)
1793479  add prompt and fix_triple reward (ocramz, Mar 2, 2025)
3303083  fix makefile to activate venv correctly (ocramz, Mar 2, 2025)
532a012  fix_triple task: add reward tests and docstrings (ocramz, Mar 2, 2025)
66969e8  readme (ocramz, Mar 2, 2025)
3f88b06  fix style and quality (ocramz, Mar 2, 2025)
604f66f  cannot reliably activate venv within makefile (ocramz, Mar 2, 2025)
ed0c484  ignore API json parsing errors (ocramz, Mar 2, 2025)
6e4298c  cleanup and docstrings (Mar 3, 2025)
aa97e62  add test for verify v2 endpoint (ocramz, Mar 3, 2025)
334b4b0  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 5, 2025)
3bd8689  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 7, 2025)
40736bd  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 14, 2025)
456543a  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 18, 2025)
bbf700a  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 23, 2025)
f3f2166  Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 29, 2025)
cde9a89  Merge branch 'main' into feature/htgen-dataset (ocramz, Apr 18, 2025)
add reward
Marco Zocca committed Feb 22, 2025
commit 40fdef852ccc54dbf6cc63d9fe773fbe831f5d75
59 changes: 52 additions & 7 deletions src/open_r1/rewards/code/htgen.py
@@ -13,10 +13,13 @@ def quotes(s:str):
"""markdown triple backticks for a piece of code"""
return f"```{s}```"

# totality check task

# TOTALITY_CHECK task
def mk_row_totality_check(o):
"""
Construct the prompt
NB: the rows have a 'prompt' column as required by the GRPOTrainer interface:
https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
"""
label = o['label']
pre = o['pre']
@@ -51,8 +54,9 @@ def mk_row_totality_check(o):

# # construct a row of the dataset
o_out = {
"problem": prompt_problem,
"solution": label_is_total
"prompt": prompt_problem,
"ground_truth": label_is_total,
"triple": {"pre": pre, "program":program, "post": post}
}

return o_out
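
For reference, each row built above exposes the `prompt` column that GRPOTrainer expects, plus `ground_truth` and the original `triple`. A minimal sketch of one such row; the concrete pre/program/post strings are invented for illustration and are not taken from the HTGen data:

```python
# Hypothetical example of a single row built by mk_row_totality_check;
# only the key names come from the diff above, the values are made up.
row = {
    "prompt": "Is the following Hoare triple total? ...",  # consumed by GRPOTrainer
    "ground_truth": True,                                   # label_is_total
    "triple": {                                             # kept for oracle-based rewards
        "pre": "x >= 0",
        "program": "y := x + 1",
        "post": "y >= 1",
    },
}
```
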
@@ -82,10 +86,51 @@ def mk_dataset_totality_check(

return dataset

def totality_check_reward(completions, solution, **kwargs):
def totality_check_reward(completions, ground_truth, **kwargs):
"""
verification callback for GRPOTrainer
:param completions: list of truthy values produced by the model
:param ground_truth: list of boolean ground truth values
:returns: list of floats, 1.0 where the prediction matches the ground truth and 0.0 otherwise
"""
# pass the completion together with the reference solution to 'verify_triple_X'
# and score the result
pass
if not isinstance(completions[0], bool):
completions = [bool(c) for c in completions]
def verify(predicted, actual):
if predicted == actual:
return 1.0
else:
return 0.0

return [verify(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)]




if __name__ == "__main__":
compls = [True]
ground_truth = ["True"]
res = totality_check_reward(compls, ground_truth)
print(res)


# # # verify against API

# def totality_oracle_reward(completions, triples, **kwargs):
# """
# verification callback for GRPOTrainer
# :param completions: list of truthy values produced by the model
# :param triples: list of program triple dicts {"pre":: string, "program":: string, "post":: string}
# """

# def verify(pre, program, post, is_total):
# res = verify_triple_33(
# preconditions = pre,
# program = program,
# postconditions = post,
# is_total = is_total
# )
# if res is not None:
# prediction = res['prediction_is_correct']
# return 1.0 if prediction else 0.0
# else:
# return 0.0
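
For context, a reward with the signature `totality_check_reward(completions, ground_truth, **kwargs)` can be handed to TRL's `GRPOTrainer` via `reward_funcs`; the trainer forwards dataset columns other than `prompt` (here `ground_truth`) to the reward as keyword arguments. A minimal sketch, assuming a toy hand-built dataset and a placeholder model id; neither is part of this PR, and the import path simply follows the file path in this diff:

```python
# Sketch only: the model id, output_dir and the toy dataset below are placeholders;
# in this PR the real rows would come from mk_dataset_totality_check.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from open_r1.rewards.code.htgen import totality_check_reward

train_dataset = Dataset.from_list(
    [
        {"prompt": "Is the following Hoare triple total? ...", "ground_truth": True},
        {"prompt": "Is the following Hoare triple total? ...", "ground_truth": False},
    ]
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",    # placeholder model id
    reward_funcs=[totality_check_reward],  # columns other than "prompt" are passed as kwargs
    args=GRPOConfig(output_dir="htgen-totality-grpo"),
    train_dataset=train_dataset,
)
trainer.train()
```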