
Commit 10294bc

feat: Support one-shot prompts (#26)
1 parent 61d2480 commit 10294bc

File tree

8 files changed (+175, -163 lines)


azureml/eval.yml

Lines changed: 3 additions & 5 deletions

@@ -3,8 +3,7 @@ command: >
   python -m autora.doc.pipelines.main eval
   ${{inputs.data_dir}}/data.jsonl
   --model-path ${{inputs.model_path}}
-  --sys-id ${{inputs.sys_id}}
-  --instruc-id ${{inputs.instruc_id}}
+  --prompt-id ${{inputs.prompt_id}}
   --param do_sample=${{inputs.do_sample}}
   --param temperature=${{inputs.temperature}}
   --param top_k=${{inputs.top_k}}
@@ -23,8 +22,7 @@ inputs:
   do_sample: 0
   top_p: 0.95
   top_k: 1
-  sys_id: SYS_1
-  instruc_id: INSTR_SWEETP_1
+  prompt_id: SWEETP_1
 # using a curated environment doesn't work because we need additional packages
 environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
   image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
@@ -37,6 +35,6 @@ environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11
   # image: nvcr.io/nvidia/pytorch:23.10-py3
   conda_file: conda.yml
 display_name: autodoc_prediction
-compute: azureml:t4cluster
+compute: azureml:v100cluster
 experiment_name: evaluation
 description: |

azureml/generate.yml

Lines changed: 3 additions & 5 deletions

@@ -3,9 +3,8 @@ command: >
   python -m autora.doc.pipelines.main generate
   --model-path ${{inputs.model_path}}
   --output ./outputs/output.txt
-  --sys-id ${{inputs.sys_id}}
-  --instruc-id ${{inputs.instruc_id}}
   --param do_sample=${{inputs.do_sample}}
+  --prompt-id ${{inputs.prompt_id}}
   --param temperature=${{inputs.temperature}}
   --param top_k=${{inputs.top_k}}
   --param top_p=${{inputs.top_p}}
@@ -21,12 +20,11 @@ inputs:
   do_sample: 0
   top_p: 0.95
   top_k: 40
-  sys_id: SYS_1
-  instruc_id: INSTR_SWEETP_1
+  prompt_id: SWEETP_1
 environment:
   image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
   conda_file: conda.yml
 display_name: autodoc_prediction
-compute: azureml:t4cluster
+compute: azureml:v100cluster
 experiment_name: prediction
 description: |
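
Both jobs now pass a single --prompt-id in place of the old --sys-id/--instruc-id pair. A minimal local smoke test of the renamed flag, mirroring the command in azureml/eval.yml (the data path and parameter values here are assumptions based on the inputs above):

# Sketch only: runs the eval pipeline locally with the renamed --prompt-id flag.
import subprocess

subprocess.run(
    [
        "python", "-m", "autora.doc.pipelines.main", "eval",
        "data/data.jsonl",  # assumed local path; the job reads ${{inputs.data_dir}}/data.jsonl
        "--model-path", "meta-llama/Llama-2-7b-chat-hf",
        "--prompt-id", "SWEETP_1",  # one ID replaces the sys_id + instruc_id pair
        "--param", "do_sample=0",
        "--param", "top_k=1",
        "--param", "top_p=0.95",
    ],
    check=True,
)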

data/autora/code1_sm.txt

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+iv = Variable(name="x", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))
+dv = Variable(name="y", type=ValueType.REAL)
+variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])
+
+conditions = random_pool(variables, num_samples=10, random_state=0)
+
+experimentalist = on_state(random_pool, output=["conditions"])
+
+sin_experiment = equation_experiment(
+    sp.simplify("sin(x)"), variables.independent_variables, variables.dependent_variables[0]
+)
+sin_runner = sin_experiment.experiment_runner
+
+experiment_runner = on_state(sin_runner, output=["experiment_data"])
+
+theorist = estimator_on_state(BMSRegressor(epochs=100))
+
+s = StandardState(
+    variables=variables, conditions=conditions, experiment_data=pd.DataFrame(columns=["x", "y"])
+)
+
+print("Pre-Defined State:")
+print(f"Number of datapoints collected: {len(s['experiment_data'])}")
+print(f"Derived models: {s['models']}")
+print("\n")
+
+for i in range(5):
+    s = experimentalist(s, num_samples=10, random_state=42)
+    s = experiment_runner(s, added_noise=1.0, random_state=42)
+    s = theorist(s)
+    print(f"\nCycle {i+1} Results:")
+    print(f"Number of datapoints collected: {len(s['experiment_data'])}")
+    print(f"Derived models: {s['models']}")
+    print("\n")

notebooks/generate.ipynb

Lines changed: 36 additions & 63 deletions

@@ -9,7 +9,7 @@
 "%load_ext autoreload\n",
 "%autoreload 2\n",
 "from autora.doc.runtime.predict_hf import Predictor\n",
-"from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts"
+"from autora.doc.runtime.prompts import PROMPTS, PromptIds"
 ]
 },
 {
@@ -29,57 +29,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"TEST_CODE = \"\"\"\n",
-"from sweetpea import *\n",
-"from sweetpea.primitives import *\n",
-"\n",
-"number_list = [125, 132, 139, 146, 160, 167, 174, 181]\n",
-"letter_list = ['b', 'd', 'f', 'h', 's', 'u', 'w', 'y']\n",
-"\n",
-"number = Factor(\"number\", number_list)\n",
-"letter = Factor(\"letter\", letter_list)\n",
-"task = Factor(\"task\", [\"number task\", \"letter task\", \"free choice task\"])\n",
-"\n",
-"\n",
-"def is_forced_trial_switch(task):\n",
-"    return (task[-1] == \"number task\" and task[0] == \"letter task\") or \\\n",
-"           (task[-1] == \"letter task\" and task[0] == \"number task\")\n",
-"\n",
-"\n",
-"def is_forced_trial_repeat(task):\n",
-"    return (task[-1] == \"number task\" and task[0] == \"number task\") or \\\n",
-"           (task[-1] == \"letter task\" and task[0] == \"letter task\")\n",
-"\n",
-"\n",
-"def is_free_trial_transition(task):\n",
-"    return task[-1] != \"free choice task\" and task[0] == \"free choice task\"\n",
-"\n",
-"\n",
-"def is_free_trial_repeat(task):\n",
-"    return task[-1] == \"free choice task\" and task[0] == \"free choice task\"\n",
-"\n",
-"\n",
-"def is_not_relevant_transition(task):\n",
-"    return not (is_forced_trial_repeat(task) or is_forced_trial_switch(task) or is_free_trial_repeat(\n",
-"        task) or is_free_trial_transition(task))\n",
-"\n",
-"\n",
-"transit = Factor(\"task transition\", [\n",
-"    DerivedLevel(\"forced switch\", transition(is_forced_trial_switch, [task]), 3),\n",
-"    DerivedLevel(\"forced repeat\", transition(is_forced_trial_repeat, [task])),\n",
-"    DerivedLevel(\"free transition\", transition(is_free_trial_transition, [task]), 4),\n",
-"    DerivedLevel(\"free repeat\", transition(is_free_trial_repeat, [task]), 4),\n",
-"    DerivedLevel(\"forced first\", transition(is_not_relevant_transition, [task]), 4)\n",
-"])\n",
-"design = [letter, number, task, transit]\n",
-"crossing = [[letter], [number], [transit]]\n",
-"constraints = [MinimumTrials(256)]\n",
-"\n",
-"block = MultiCrossBlock(design, crossing, constraints)\n",
-"\n",
-"experiment = synthesize_trials(block, 1)\n",
-"\n",
-"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
+"TEST_VAR_CODE = \"\"\"\n",
+"iv = Variable(name=\"x\", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))\n",
+"dv = Variable(name=\"y\", type=ValueType.REAL)\n",
+"variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])\n",
 "\"\"\""
 ]
 },
@@ -89,16 +42,36 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"output = pred.predict(\n",
-"    SYS[SystemPrompts.SYS_1],\n",
-"    INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n",
-"    [TEST_CODE],\n",
-"    temperature=0.05,\n",
-"    top_k=10,\n",
-"    num_ret_seq=3,\n",
-")[0]\n",
-"for i, o in enumerate(output):\n",
-"    print(f\"******** Output {i} ********\\n{o}*************\\n\")"
+"def test(promptid, code):\n",
+"    output = pred.predict(\n",
+"        PROMPTS[promptid],\n",
+"        [code],\n",
+"        do_sample=0,\n",
+"        max_length=800,\n",
+"        temperature=0.05,\n",
+"        top_k=10,\n",
+"        num_ret_seq=1,\n",
+"    )[0]\n",
+"    for i, o in enumerate(output):\n",
+"        print(f\"{promptid}\\n******* Output {i} ********\\n{o}\\n*************\\n\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE)"
 ]
 }
 ],
@@ -118,7 +91,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.18"
+"version": "3.11.5"
 }
 },
 "nbformat": 4,

src/autora/doc/pipelines/main.py

Lines changed: 11 additions & 17 deletions

@@ -10,7 +10,7 @@
 from nltk.translate.meteor_score import single_meteor_score
 
 from autora.doc.runtime.predict_hf import Predictor
-from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts
+from autora.doc.runtime.prompts import PROMPTS, PromptIds
 
 app = typer.Typer()
 logging.basicConfig(
@@ -51,10 +51,7 @@ def evaluate_documentation(predictions: List[List[str]], references: List[str])
 def eval(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
     model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
-    sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
-    instruc_id: InstructionPrompts = typer.Option(
-        InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
-    ),
+    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
@@ -67,15 +64,17 @@ def eval(
     param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
     run = mlflow.active_run()
 
-    sys_prompt = SYS[sys_id]
-    instr_prompt = INSTR[instruc_id]
+    prompt = PROMPTS[prompt_id]
     if run is None:
         run = mlflow.start_run()
     with run:
         logger.info(f"Active run_id: {run.info.run_id}")
         logger.info(f"running predict with {data_file}")
         logger.info(f"model path: {model_path}")
         mlflow.log_params(param_dict)
+        mlflow.log_param("prompt_id", prompt_id)
+        mlflow.log_param("model_path", model_path)
+        mlflow.log_param("data_file", data_file)
 
         with jsonlines.open(data_file) as reader:
             items = [item for item in reader]
@@ -84,10 +83,9 @@ def eval(
 
         pred = Predictor(model_path)
         timer_start = timer()
-        predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
-        bleu, meteor = evaluate_documentation(predictions, labels)
-
+        predictions = pred.predict(prompt, inputs, **param_dict)
         timer_end = timer()
+        bleu, meteor = evaluate_documentation(predictions, labels)
         pred_time = timer_end - timer_start
         mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
         for i in range(len(inputs)):
@@ -114,10 +112,7 @@ def generate(
     python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
     model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
     output: str = typer.Option("output.txt", help="Output file"),
-    sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
-    instruc_id: InstructionPrompts = typer.Option(
-        InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
-    ),
+    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
@@ -128,11 +123,10 @@ def generate(
     """
     with open(python_file, "r") as f:
         input = f.read()
-    sys_prompt = SYS[sys_id]
-    instr_prompt = INSTR[instruc_id]
+    prompt = PROMPTS[prompt_id]
     pred = Predictor(model_path)
     # grab first result since we only passed one input
-    predictions = pred.predict(sys_prompt, instr_prompt, [input], **param_dict)[0]
+    predictions = pred.predict(prompt, [input], **param_dict)[0]
     assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
     logger.info(f"Writing output to {output}")
     with open(output, "w") as f:
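
Both eval and generate now resolve a single PromptIds value through the PROMPTS table. The prompts module that defines them changed in this commit but is not shown here; a plausible shape for the registry, where everything beyond the three IDs visible in this diff is an assumption:

# Sketch of the consolidated registry consumed by main.py; template bodies
# are placeholders, but each must expose exactly one {code} slot, since
# predict_hf.py fills it with prompt_template.format(code=input).
from enum import Enum

class PromptIds(str, Enum):
    SWEETP_1 = "SWEETP_1"
    AUTORA_VARS_ZEROSHOT = "AUTORA_VARS_ZEROSHOT"
    AUTORA_VARS_ONESHOT = "AUTORA_VARS_ONESHOT"

PROMPTS = {
    PromptIds.SWEETP_1: "[INST] Describe this code:\n{code}\n[/INST]\n",
    # ... one entry per PromptIds member
}

prompt = PROMPTS[PromptIds.SWEETP_1]  # the single lookup main.py now performs
print(prompt.format(code="iv = Variable(name='x')"))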

src/autora/doc/runtime/predict_hf.py

Lines changed: 4 additions & 5 deletions

@@ -5,7 +5,7 @@
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE, TEMP_LLAMA2
+from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +29,7 @@ def __init__(self, model_path: str):
 
     def predict(
         self,
-        sys: str,
-        instr: str,
+        prompt_template: str,
         inputs: List[str],
         do_sample: float = 0.0,
         temperature: float = 0.01,
@@ -45,7 +44,7 @@ def predict(
             f"Generating {len(inputs)} predictions. do_sample: {do_sample}, temperature: {temperature}, top_p: {top_p},"
             f" top_k: {top_k}, max_length: {max_length}"
         )
-        prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
+        prompts = [prompt_template.format(code=input) for input in inputs]
        sequences = self.pipeline(
             prompts,
             do_sample=do_sample,
@@ -65,7 +64,7 @@ def predict(
 
     @staticmethod
     def trim_prompt(output: str) -> str:
-        marker = output.find(LLAMA2_INST_CLOSE)
+        marker = output.rfind(LLAMA2_INST_CLOSE)
         if marker == -1:
             logger.warning(f"Could not find end of prompt marker '{LLAMA2_INST_CLOSE}' in '{output}'")
             return output
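
The find-to-rfind switch in trim_prompt is what keeps one-shot prompts from leaking into the output: assuming the one-shot template uses Llama 2's multi-turn form, the [/INST] close marker appears more than once in the echoed prompt, and trimming at the first occurrence would return the worked example plus the second instruction instead of the new generation. A small demonstration (the marker string and prompt text here are assumptions; the real constant lives in prompts.py):

# find() vs. rfind() on a one-shot style model output.
LLAMA2_INST_CLOSE = "[/INST]"  # assumed value for illustration

output = (
    "[INST] Describe: iv = Variable('x') [/INST] "
    "Defines an independent variable. "
    "[INST] Describe: dv = Variable('y') [/INST] "
    "Defines a dependent variable."
)

first = output.find(LLAMA2_INST_CLOSE) + len(LLAMA2_INST_CLOSE)
last = output.rfind(LLAMA2_INST_CLOSE) + len(LLAMA2_INST_CLOSE)

print(output[first:].strip())  # wrong: still contains the example answer and second prompt
print(output[last:].strip())   # right: only the newly generated documentation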
