Skip to content

Commit 8170fae

Browse files
committed
added multi-iteration
1 parent d11e0a1 commit 8170fae

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

examples/star/inference.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
1+
import gc
12
from typing import List
23
from datasets import Dataset
34
from vllm import LLM, SamplingParams
45
from utils import generate_prompt
56

67

def cleanup(model):
    """Best-effort teardown of a vLLM model's GPU and distributed state.

    Destroys model-parallel/distributed state, drops the engine's model
    executor, tears down the process group, and flushes the CUDA cache so
    a fresh model can be loaded in the same process. If torch (or vllm)
    cannot be imported, simply drops the local reference instead.
    """
    try:
        import contextlib
        import torch

        # Nothing to release when no GPU is present.
        if not torch.cuda.is_available():
            return
        from vllm.distributed.parallel_state import (
            destroy_distributed_environment,
            destroy_model_parallel,
        )

        destroy_model_parallel()
        destroy_distributed_environment()
        del model.llm_engine.model_executor
        del model
        # destroy_process_group asserts when no group was initialized.
        with contextlib.suppress(AssertionError):
            torch.distributed.destroy_process_group()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    except ImportError:
        # torch/vllm unavailable; just release our reference.
        del model
728
def generate_predictions(
829
model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
930
) -> List[List[str]]:
@@ -41,6 +62,7 @@ def generate_predictions(
4162
for output in outputs:
4263
generated_texts = [one.text for one in output.outputs]
4364
results.append(generated_texts)
65+
cleanup(llm)
4466
return results
4567
# out_name = dataset_name.split("/")[-1]
4668
# out_name = f"wentingzhao/{out_name}_predictions_{n}"

examples/star/star.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,36 @@ def main():
1818
canonical_solution = f"```python\n{example['canonical_solution']}\n```"
1919
text = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(canonical_solution, example["prompt"])}]
2020
texts.append(text)
21-
print(text)
2221
ds[split] = ds[split].add_column(name="text", column=texts)
2322

24-
# sample
25-
all_samples = generate_predictions(
26-
args.model_name_or_path, ds["train"], args.temperature, args.n
27-
)
28-
assert len(ds["train"]) == len(all_samples)
29-
30-
# verify and construct the training set
31-
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
32-
passed_examples = []
33-
for example, execution_results, samples in zip(
34-
ds["train"], all_execution_results, all_samples
35-
):
36-
for execution_result, sample in zip(execution_results, samples):
37-
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
38-
if execution_result == 0:
39-
example["text"] = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(sample, example["prompt"])}]
40-
passed_examples.append(example)
41-
break
42-
raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})
43-
44-
# train
45-
train(raw_datasets, args.model_name_or_path, args)
23+
model_name = args.model_name_or_path
24+
for i in range(args.iteration):
25+
# sample
26+
all_samples = generate_predictions(
27+
model_name, ds["train"], args.temperature, args.n
28+
)
29+
ds["train"].add_column(name="sample", column=all_samples).to_json(f"{args.output_dir}/data/samples-iter{i}.json")
30+
assert len(ds["train"]) == len(all_samples)
31+
32+
# verify and construct the training set
33+
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
34+
passed_examples = []
35+
for example, execution_results, samples in zip(
36+
ds["train"], all_execution_results, all_samples
37+
):
38+
for execution_result, sample in zip(execution_results, samples):
39+
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
40+
if execution_result == 0:
41+
example["text"] = [{"role": "user", "message": generate_prompt(example["prompt"], example["test"])}, {"role": "assistant", "message": format_solution(sample, example["prompt"])}]
42+
passed_examples.append(example)
43+
break
44+
raw_datasets = DatasetDict({"train": Dataset.from_list(passed_examples), "validation": ds["validation"]})
45+
raw_datasets["train"].to_json(f"{args.output_dir}/data/verified-samples-iter{i}.json")
46+
47+
# train
48+
args.output_dir = f"{args.output_dir}/models-iter{i}"
49+
train(raw_datasets, model_name, args)
50+
model_name = args.output_dir
4651

4752

4853
if __name__ == "__main__":

examples/star/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def parse_args():
115115
)
116116
parser.add_argument("--temperature", type=float, default=1)
117117
parser.add_argument("-n", type=int, default=1)
118+
parser.add_argument("--iteration", type=int, default=1)
118119
parser.add_argument(
119120
"--dataset_name",
120121
type=str,

0 commit comments

Comments (0)