Skip to content

Commit f6b2a71

Browse files
committed
updates
1 parent 570090b commit f6b2a71

File tree

5 files changed

+215
-412
lines changed

5 files changed

+215
-412
lines changed

examples/star/star.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
11
"""Main STaR Loop"""
22

3-
import argparse
43
from datasets import Dataset, load_dataset
54
from inference import generate_predictions
6-
from utils import execute_tests
5+
from train import train
6+
from utils import execute_tests, parse_args
77

88

99
def main():
10-
parser = argparse.ArgumentParser()
11-
parser.add_argument("--model_name", type=str, required=True, help="model to use")
12-
parser.add_argument(
13-
"--dataset_name", type=str, required=True, help="dataset to use"
14-
)
15-
parser.add_argument("--temperature", type=float, default=1)
16-
parser.add_argument("-n", type=int, default=1)
17-
args = parser.parse_args()
18-
10+
args = parse_args()
1911
ds = load_dataset(args.dataset_name)
2012
assert "train" in ds
2113
all_samples = generate_predictions(
22-
args.model_name, ds["train"], args.temperature, args.n
14+
args.model_name_or_path, ds["train"], args.temperature, args.n
2315
)
2416
assert len(ds["train"]) == len(all_samples)
2517
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
@@ -28,13 +20,15 @@ def main():
2820
ds["train"], all_execution_results, all_samples
2921
):
3022
for execution_result, sample in zip(execution_results, samples):
23+
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
3124
if execution_result == 0:
3225
example["prediction"] = sample
3326
passed_examples.append(example)
3427
break
3528
new_ds = Dataset.from_list(passed_examples)
3629
new_ds.to_json("star_training.json")
3730
print(len(passed_examples) / len(ds["train"]))
31+
train(args)
3832

3933

4034
if __name__ == "__main__":

0 commit comments

Comments
 (0)