1
1
"""Main STaR Loop"""
2
2
3
+ from copy import deepcopy
3
4
from datasets import Dataset , DatasetDict , load_dataset
4
5
from inference import generate_predictions
5
6
from train import train
@@ -21,13 +22,13 @@ def main():
21
22
ds [split ] = ds [split ].add_column (name = "text" , column = texts )
22
23
23
24
model_name = args .model_name_or_path
24
- ds [ "train" ] = ds [ "train" ]. select ( range ( 10 ) )
25
+ output_dir = deepcopy ( args . output_dir )
25
26
for i in range (args .iteration ):
26
27
# sample
27
28
all_samples = generate_predictions (
28
29
model_name , ds ["train" ], args .temperature , args .n
29
30
)
30
- ds ["train" ].add_column (name = "sample" , column = all_samples ).to_json (f"{ args . output_dir } /data/samples-iter{ i } .json" )
31
+ ds ["train" ].add_column (name = "sample" , column = all_samples ).to_json (f"{ output_dir } /data/samples-iter{ i } .json" )
31
32
assert len (ds ["train" ]) == len (all_samples )
32
33
33
34
# verify and construct the training set
@@ -43,10 +44,10 @@ def main():
43
44
passed_examples .append (example )
44
45
break
45
46
raw_datasets = DatasetDict ({"train" : Dataset .from_list (passed_examples ), "validation" : ds ["validation" ]})
46
- raw_datasets ["train" ].to_json (f"{ args . output_dir } /data/verified-samples-iter{ i } .json" )
47
+ raw_datasets ["train" ].to_json (f"{ output_dir } /data/verified-samples-iter{ i } .json" )
47
48
48
49
# train
49
- args .output_dir = f"{ args . output_dir } /models-iter{ i } "
50
+ args .output_dir = f"{ output_dir } /models-iter{ i } "
50
51
train (raw_datasets , model_name , args )
51
52
model_name = args .output_dir
52
53
0 commit comments