1
1
from typing import List
2
2
from datasets import Dataset
3
3
from vllm import LLM , SamplingParams
4
+ from utils import generate_prompt
5
+
4
6
5
7
def generate_predictions (
6
- model_name : str ,
7
- dataset : Dataset ,
8
- temperature : float = 1.0 ,
9
- n : int = 1
8
+ model_name : str , dataset : Dataset , temperature : float = 1.0 , n : int = 1
10
9
) -> List [List [str ]]:
11
- """
12
- Generate predictions for a given dataset using a specified language model and
10
+ """Generate predictions for a given dataset using a specified language model and
13
11
sampling parameters. The function loads the dataset, constructs prompts from
14
12
each example, and obtains generated predictions. The resulting predictions are
15
13
then added as a new column to the dataset.
16
14
17
15
Args:
16
+ ----
18
17
model_name (str): Name of the model to use for generation.
19
18
dataset (Dataset): The Dataset object.
20
19
temperature (float, optional): Temperature setting for the model's
21
20
sampling strategy. Default is 1.0.
22
21
n (int, optional): Number of sampling runs per prompt. Default is 1.
23
22
24
23
Returns:
24
+ -------
25
25
predictions (List[List[str]]): Predictions on the dataset.
26
+
26
27
"""
27
28
sampling_params = SamplingParams (n = n , temperature = temperature , max_tokens = 512 )
28
29
llm = LLM (model = model_name )
@@ -31,19 +32,7 @@ def generate_predictions(
31
32
for example in dataset :
32
33
prompt = example ["prompt" ]
33
34
test = example ["test" ]
34
- prompt = f"""Write a Python function implementation for the following prompt:
35
-
36
- { prompt }
37
-
38
- Your code should satisfy these tests:
39
-
40
- { test }
41
-
42
- Return only the implementation code, no tests or explanations. Be sure to include the relevant import statements:
43
- ```python
44
- code
45
- ```
46
- """
35
+ prompt = generate_prompt (prompt , test )
47
36
prompts .append (prompt )
48
37
49
38
outputs = llm .generate (prompts , sampling_params )
@@ -53,7 +42,6 @@ def generate_predictions(
53
42
generated_texts = [one .text for one in output .outputs ]
54
43
results .append (generated_texts )
55
44
return results
56
- #out_name = dataset_name.split("/")[-1]
57
- #out_name = f"wentingzhao/{out_name}_predictions_{n}"
58
- #ds.push_to_hub(out_name)
59
-
45
+ # out_name = dataset_name.split("/")[-1]
46
+ # out_name = f"wentingzhao/{out_name}_predictions_{n}"
47
+ # ds.push_to_hub(out_name)
0 commit comments