Skip to content

Commit 570090b

Browse files
committed
more updates
1 parent b58a8bb commit 570090b

File tree

4 files changed

+228
-100
lines changed

4 files changed

+228
-100
lines changed

examples/star/inference.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
11
from typing import List
22
from datasets import Dataset
33
from vllm import LLM, SamplingParams
4+
from utils import generate_prompt
5+
46

57
def generate_predictions(
6-
model_name: str,
7-
dataset: Dataset,
8-
temperature: float = 1.0,
9-
n: int = 1
8+
model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
109
) -> List[List[str]]:
11-
"""
12-
Generate predictions for a given dataset using a specified language model and
10+
"""Generate predictions for a given dataset using a specified language model and
1311
sampling parameters. The function takes an already-loaded dataset, builds a
1412
prompt from each example, and obtains generated predictions. The resulting
1513
predictions are returned as one list of generated texts per example.
1614
1715
Args:
16+
----
1817
model_name (str): Name of the model to use for generation.
1918
dataset (Dataset): The Dataset object.
2019
temperature (float, optional): Temperature setting for the model's
2120
sampling strategy. Default is 1.0.
2221
n (int, optional): Number of sampling runs per prompt. Default is 1.
2322
2423
Returns:
24+
-------
2525
predictions (List[List[str]]): Predictions on the dataset.
26+
2627
"""
2728
sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=512)
2829
llm = LLM(model=model_name)
@@ -31,19 +32,7 @@ def generate_predictions(
3132
for example in dataset:
3233
prompt = example["prompt"]
3334
test = example["test"]
34-
prompt = f"""Write a Python function implementation for the following prompt:
35-
36-
{prompt}
37-
38-
Your code should satisfy these tests:
39-
40-
{test}
41-
42-
Return only the implementation code, no tests or explanations. Be sure to include the relevant import statements:
43-
```python
44-
code
45-
```
46-
"""
35+
prompt = generate_prompt(prompt, test)
4736
prompts.append(prompt)
4837

4938
outputs = llm.generate(prompts, sampling_params)
@@ -53,7 +42,6 @@ def generate_predictions(
5342
generated_texts = [one.text for one in output.outputs]
5443
results.append(generated_texts)
5544
return results
56-
#out_name = dataset_name.split("/")[-1]
57-
#out_name = f"wentingzhao/{out_name}_predictions_{n}"
58-
#ds.push_to_hub(out_name)
59-
45+
# out_name = dataset_name.split("/")[-1]
46+
# out_name = f"wentingzhao/{out_name}_predictions_{n}"
47+
# ds.push_to_hub(out_name)

examples/star/star.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Main STaR Loop"""
2+
23
import argparse
34
from datasets import Dataset, load_dataset
45
from inference import generate_predictions
@@ -8,24 +9,33 @@
89
def main():
910
parser = argparse.ArgumentParser()
1011
parser.add_argument("--model_name", type=str, required=True, help="model to use")
11-
parser.add_argument("--dataset_name", type=str, required=True, help="dataset to use")
12+
parser.add_argument(
13+
"--dataset_name", type=str, required=True, help="dataset to use"
14+
)
1215
parser.add_argument("--temperature", type=float, default=1)
1316
parser.add_argument("-n", type=int, default=1)
1417
args = parser.parse_args()
1518

1619
ds = load_dataset(args.dataset_name)
1720
assert "train" in ds
18-
all_samples = generate_predictions(args.model_name, ds["train"], args.temperature, args.n)
21+
all_samples = generate_predictions(
22+
args.model_name, ds["train"], args.temperature, args.n
23+
)
1924
assert len(ds["train"]) == len(all_samples)
2025
all_traces, all_execution_results = execute_tests(ds["train"], all_samples)
2126
passed_examples = []
22-
for example, execution_results, samples in zip(ds["train"], all_execution_results, all_samples):
27+
for example, execution_results, samples in zip(
28+
ds["train"], all_execution_results, all_samples
29+
):
2330
for execution_result, sample in zip(execution_results, samples):
2431
if execution_result == 0:
25-
example['prediction'] = sample
32+
example["prediction"] = sample
2633
passed_examples.append(example)
2734
break
28-
print(len(passed_examples)/len(ds["train"]))
35+
new_ds = Dataset.from_list(passed_examples)
36+
new_ds.to_json("star_training.json")
37+
print(len(passed_examples) / len(ds["train"]))
38+
2939

30-
if __name__ == '__main__':
40+
if __name__ == "__main__":
3141
main()

0 commit comments

Comments
 (0)