Commit 3c7e0a0

refactor: Do one prediction per input sequence, easier experimentation (#27)
1 parent 10294bc commit 3c7e0a0
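
In short: Predictor.predict now returns exactly one prediction per input sequence (List[str]) rather than a list of candidate lists (List[List[str]]), max_length is replaced by max_new_tokens, and evaluation now returns BLEU/METEOR scores alongside the predictions. A minimal sketch of the new calling convention, using names that appear in the diffs below; the input strings are illustrative only, not part of this commit:

```python
# Sketch of the flattened prediction API introduced by this commit.
from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import PROMPTS, PromptIds

pred = Predictor("meta-llama/Llama-2-7b-chat-hf")
inputs = [  # hypothetical code snippets to document
    'iv = Variable(name="x", value_range=(0, 6.28))',
    'dv = Variable(name="y", type=ValueType.REAL)',
]

# Before: List[List[str]] (a candidate list per input); after: List[str], one string per input.
predictions = pred.predict(PROMPTS[PromptIds.AUTORA_VARS_ZEROSHOT], inputs, max_new_tokens=100.0)
assert len(predictions) == len(inputs)
for doc in predictions:
    print(doc)
```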

5 files changed: +180 -91 lines changed


notebooks/generate.ipynb

Lines changed: 71 additions & 12 deletions

@@ -8,8 +8,19 @@
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
-    "from autora.doc.runtime.predict_hf import Predictor\n",
-    "from autora.doc.runtime.prompts import PROMPTS, PromptIds"
+    "from autora.doc.runtime.predict_hf import Predictor, preprocess_code\n",
+    "from autora.doc.runtime.prompts import PROMPTS, PromptIds, PromptBuilder, SYS_GUIDES\n",
+    "from autora.doc.pipelines.main import evaluate_documentation\n",
+    "from autora.doc.pipelines.main import eval_prompt, load_data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = \"meta-llama/Llama-2-7b-chat-hf\""
    ]
   },
   {
@@ -18,11 +29,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# model = \"../../models\" # if model has been previously downloaded via huggingface-cli\n",
-    "model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
     "pred = Predictor(model)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Test generation for the variable declararion only"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -33,7 +49,8 @@
     "iv = Variable(name=\"x\", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))\n",
     "dv = Variable(name=\"y\", type=ValueType.REAL)\n",
     "variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])\n",
-    "\"\"\""
+    "\"\"\"\n",
+    "LABEL = \"The discovery problem is defined by a single independent variable $x \\in [0, 2 \\pi]$ and dependent variable $y$.\""
    ]
   },
   {
@@ -42,18 +59,46 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def test(promptid, code):\n",
+    "def test(promptid, code, label):\n",
     "    output = pred.predict(\n",
     "        PROMPTS[promptid],\n",
     "        [code],\n",
     "        do_sample=0,\n",
-    "        max_length=800,\n",
+    "        max_new_tokens=100,\n",
     "        temperature=0.05,\n",
     "        top_k=10,\n",
     "        num_ret_seq=1,\n",
-    "    )[0]\n",
-    "    for i, o in enumerate(output):\n",
-    "        print(f\"{promptid}\\n******* Output {i} ********\\n{o}\\n*************\\n\")"
+    "    )\n",
+    "    bleu, meteor = evaluate_documentation(output, [label])\n",
+    "    for i, o in enumerate(output[0]):\n",
+    "        print(f\"{promptid}\\n******* Output {i} ********. bleu={bleu}, meteor={meteor}\\n{o}\\n*************\\n\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Zero shot test\n",
+    "test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE, LABEL)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# One shot test\n",
+    "test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE, LABEL)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## One-shot generation for the complete code sample"
    ]
   },
   {
@@ -62,7 +107,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE)"
+    "data_file = \"../data/autora/data.jsonl\"\n",
+    "inputs, labels = load_data(data_file)\n",
+    "# preprocessing removes comments, import statements and empty lines\n",
+    "inputs = [preprocess_code(i) for i in inputs]\n",
+    "INSTR = \"Generate high-level, one or two paragraph documentation for the following experiment.\"\n",
+    "prompt = PromptBuilder(SYS_GUIDES, INSTR).add_example(f\"{inputs[0]}\", labels[0]).build()\n",
+    "print(prompt)"
    ]
   },
   {
@@ -71,8 +122,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE)"
+    "out, bleu, meteor = eval_prompt(data_file, pred, prompt, {\"max_new_tokens\": 800.0})\n",
+    "print(f\"bleu={bleu}, meteor={meteor}\\n{out[0][0]}\\n*************\\n\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {

src/autora/doc/pipelines/main.py

Lines changed: 50 additions & 38 deletions

@@ -1,7 +1,7 @@
 import itertools
 import logging
 from timeit import default_timer as timer
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import nltk
 import torch
@@ -20,13 +20,13 @@
 logger = logging.getLogger(__name__)


-def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
+def evaluate_documentation(predictions: List[str], references: List[str]) -> Tuple[float, float]:
     nltk.download("wordnet")

     # Tokenize references
     tokenized_references = [ref.split() for ref in references]
     # Currently there is only 1 prediction for 1 reference, need to avg in future
-    tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
+    tokenized_predictions = [pred.split() if pred else [] for pred in predictions]

     # Calculate BLEU score with smoothing function
     # SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
@@ -55,16 +55,13 @@ def eval(
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
-) -> List[List[str]]:
-    import jsonlines
+) -> Tuple[List[str], float, float]:
     import mlflow

     mlflow.autolog()
-
-    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
     run = mlflow.active_run()
+    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}

-    prompt = PROMPTS[prompt_id]
     if run is None:
         run = mlflow.start_run()
     with run:
@@ -75,36 +72,51 @@ def eval(
         mlflow.log_param("prompt_id", prompt_id)
         mlflow.log_param("model_path", model_path)
         mlflow.log_param("data_file", data_file)
+        prompt = PROMPTS[prompt_id]
+        pred = Predictor(model_path)
+        return eval_prompt(data_file, pred, prompt, param_dict)
+
+
+def load_data(data_file: str) -> Tuple[List[str], List[str]]:
+    import jsonlines
+
+    with jsonlines.open(data_file) as reader:
+        items = [item for item in reader]
+    inputs = [f"{item['instruction']}" for item in items]
+    labels = [item["output"] for item in items]
+    return inputs, labels
+
+
+def eval_prompt(
+    data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float]
+) -> Tuple[List[str], float, float]:
+    import mlflow

-        with jsonlines.open(data_file) as reader:
-            items = [item for item in reader]
-        inputs = [item["instruction"] for item in items]
-        labels = [item["output"] for item in items]
-
-        pred = Predictor(model_path)
-        timer_start = timer()
-        predictions = pred.predict(prompt, inputs, **param_dict)
-        timer_end = timer()
-        bleu, meteor = evaluate_documentation(predictions, labels)
-        pred_time = timer_end - timer_start
-        mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
-        for i in range(len(inputs)):
-            mlflow.log_text(labels[i], f"label_{i}.txt")
-            mlflow.log_text(inputs[i], f"input_{i}.py")
-            for j in range(len(predictions[i])):
-                mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
-        mlflow.log_text("bleu_score is ", str(bleu))
-        mlflow.log_text("meteor_score is ", str(meteor))
-
-        # flatten predictions for counting tokens
-        predictions_flat = list(itertools.chain.from_iterable(predictions))
-        tokens = pred.tokenize(predictions_flat)["input_ids"]
-        total_tokens = sum([len(token) for token in tokens])
-        mlflow.log_metric("total_tokens", total_tokens)
-        mlflow.log_metric("tokens/sec", total_tokens / pred_time)
-        mlflow.log_metric("bleu_score", round(bleu, 5))
-        mlflow.log_metric("meteor_score", round(meteor, 5))
-        return predictions
+    inputs, labels = load_data(data_file)
+
+    timer_start = timer()
+    predictions = pred.predict(prompt, inputs, **param_dict)
+    timer_end = timer()
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    pred_time = timer_end - timer_start
+    mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
+    for i in range(len(inputs)):
+        mlflow.log_text(labels[i], f"label_{i}.txt")
+        mlflow.log_text(inputs[i], f"input_{i}.py")
+        for j in range(len(predictions[i])):
+            mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
+    mlflow.log_text("bleu_score is ", str(bleu))
+    mlflow.log_text("meteor_score is ", str(meteor))
+
+    # flatten predictions for counting tokens
+    predictions_flat = list(itertools.chain.from_iterable(predictions))
+    tokens = pred.tokenize(predictions_flat)["input_ids"]
+    total_tokens = sum([len(token) for token in tokens])
+    mlflow.log_metric("total_tokens", total_tokens)
+    mlflow.log_metric("tokens/sec", total_tokens / pred_time)
+    mlflow.log_metric("bleu_score", round(bleu, 5))
+    mlflow.log_metric("meteor_score", round(meteor, 5))
+    return predictions, bleu, meteor


 @app.command()
@@ -126,7 +138,7 @@ def generate(
     prompt = PROMPTS[prompt_id]
     pred = Predictor(model_path)
     # grab first result since we only passed one input
-    predictions = pred.predict(prompt, [input], **param_dict)[0]
+    predictions = pred.predict(prompt, [input], **param_dict)
     assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
     logger.info(f"Writing output to {output}")
     with open(output, "w") as f:
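
With this refactor, evaluate_documentation takes a flat list of predictions, one per reference, and returns corpus-level BLEU and METEOR as a pair of floats; eval_prompt returns the predictions together with those two scores. A quick illustrative call against the new signature (the strings below are placeholders, not project data):

```python
# Sketch of the new 1:1 predictions/references contract (placeholder strings).
from autora.doc.pipelines.main import evaluate_documentation

predictions = ["The experiment sweeps x over [0, 2*pi] and measures the dependent variable y."]
references = ["The discovery problem is defined by a single independent variable x and dependent variable y."]

bleu, meteor = evaluate_documentation(predictions, references)
print(f"bleu={bleu:.3f}, meteor={meteor:.3f}")  # both are floats; higher is better
```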

src/autora/doc/runtime/predict_hf.py

Lines changed: 20 additions & 10 deletions

@@ -1,15 +1,25 @@
 import logging
-from typing import Dict, List
+from typing import Dict, Iterable, List

 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE
+from autora.doc.runtime.prompts import CODE_PLACEHOLDER, LLAMA2_INST_CLOSE

 logger = logging.getLogger(__name__)


+def preprocess_code(code: str) -> str:
+    lines: Iterable[str] = code.splitlines()
+    skip_starts = {"import", "from", "#"}
+    lines = filter(
+        lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
+        lines,
+    )
+    return "\n".join(lines)
+
+
 class Predictor:
     def __init__(self, model_path: str):
         config = self.get_config()
@@ -35,16 +45,18 @@ def predict(
         temperature: float = 0.01,
         top_p: float = 0.95,
         top_k: float = 1,
-        max_length: float = 2048,
+        max_new_tokens: float = 2048,
         num_ret_seq: float = 1,
-    ) -> List[List[str]]:
+    ) -> List[str]:
         # convert to bool in case it came in as a generate float param from the CLI
         do_sample = bool(do_sample)
         logger.info(
             f"Generating {len(inputs)} predictions. do_sample: {do_sample}, temperature: {temperature}, top_p: {top_p},"
-            f" top_k: {top_k}, max_length: {max_length}"
+            f" top_k: {top_k}, max_new_tokens: {max_new_tokens}"
         )
-        prompts = [prompt_template.format(code=input) for input in inputs]
+        prompts = [
+            prompt_template.replace(CODE_PLACEHOLDER, preprocess_code(input).strip("\n")) for input in inputs
+        ]
         sequences = self.pipeline(
             prompts,
             do_sample=do_sample,
@@ -53,12 +65,10 @@ def predict(
             top_k=int(top_k),
             num_return_sequences=int(num_ret_seq),
             eos_token_id=self.tokenizer.eos_token_id,
-            max_length=int(max_length),
+            max_new_tokens=int(max_new_tokens),
         )

-        results = [
-            [Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences
-        ]
+        results = [Predictor.trim_prompt(seq["generated_text"]) for sequence in sequences for seq in sequence]
         logger.info(f"Generated {len(results)} results")
         return results
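
The new preprocess_code helper is small enough to check by hand: it drops whole-line comments, import/from lines, and blank lines, and keeps everything else verbatim (trailing inline comments are not stripped). For example, with an input string invented for illustration:

```python
# preprocess_code keeps only non-empty, non-import, non-comment lines (input invented for illustration).
from autora.doc.runtime.predict_hf import preprocess_code

code = """\
import numpy as np
# set up the variables
iv = Variable(name="x", value_range=(0, 2 * np.pi))

dv = Variable(name="y")
"""
print(preprocess_code(code))
# Output:
# iv = Variable(name="x", value_range=(0, 2 * np.pi))
# dv = Variable(name="y")
```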
