1
1
"""Main STaR Loop"""
2
2
3
- import argparse
4
3
from datasets import Dataset , load_dataset
5
4
from inference import generate_predictions
6
- from utils import execute_tests
5
+ from train import train
6
+ from utils import execute_tests , parse_args
7
7
8
8
9
9
def main ():
10
- parser = argparse .ArgumentParser ()
11
- parser .add_argument ("--model_name" , type = str , required = True , help = "model to use" )
12
- parser .add_argument (
13
- "--dataset_name" , type = str , required = True , help = "dataset to use"
14
- )
15
- parser .add_argument ("--temperature" , type = float , default = 1 )
16
- parser .add_argument ("-n" , type = int , default = 1 )
17
- args = parser .parse_args ()
18
-
10
+ args = parse_args ()
19
11
ds = load_dataset (args .dataset_name )
20
12
assert "train" in ds
21
13
all_samples = generate_predictions (
22
- args .model_name , ds ["train" ], args .temperature , args .n
14
+ args .model_name_or_path , ds ["train" ], args .temperature , args .n
23
15
)
24
16
assert len (ds ["train" ]) == len (all_samples )
25
17
all_traces , all_execution_results = execute_tests (ds ["train" ], all_samples )
@@ -28,13 +20,15 @@ def main():
28
20
ds ["train" ], all_execution_results , all_samples
29
21
):
30
22
for execution_result , sample in zip (execution_results , samples ):
23
+ # pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
31
24
if execution_result == 0 :
32
25
example ["prediction" ] = sample
33
26
passed_examples .append (example )
34
27
break
35
28
new_ds = Dataset .from_list (passed_examples )
36
29
new_ds .to_json ("star_training.json" )
37
30
print (len (passed_examples ) / len (ds ["train" ]))
31
+ train (args )
38
32
39
33
40
34
if __name__ == "__main__" :
0 commit comments