Skip to content

Commit 45bd148

Browse files
authored
Merge pull request #10 from AutoResearch/carlosg/genargs
feat: Add arguments for model parameters
2 parents 1d80dfe + 5231456 commit 45bd148

File tree

8 files changed

+273
-34
lines changed

8 files changed

+273
-34
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# AutoDoc
22

3+
[![ssec](https://img.shields.io/badge/SSEC-Project-purple?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAAOCAQAAABedl5ZAAAACXBIWXMAAAHKAAABygHMtnUxAAAAGXRFWHRTb2Z0d2FyZQB3d3cuaW5rc2NhcGUub3Jnm+48GgAAAMNJREFUGBltwcEqwwEcAOAfc1F2sNsOTqSlNUopSv5jW1YzHHYY/6YtLa1Jy4mbl3Bz8QIeyKM4fMaUxr4vZnEpjWnmLMSYCysxTcddhF25+EvJia5hhCudULAePyRalvUteXIfBgYxJufRuaKuprKsbDjVUrUj40FNQ11PTzEmrCmrevPhRcVQai8m1PRVvOPZgX2JttWYsGhD3atbHWcyUqX4oqDtJkJiJHUYv+R1JbaNHJmP/+Q1HLu2GbNoSm3Ft0+Y1YMdPSTSwQAAAABJRU5ErkJggg==&style=plastic)](https://escience.washington.edu/software-engineering/ssec/)
4+
35
[![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/)
46

5-
[![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/)
7+
<!-- [![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/) -->
68

79

810
[![GitHub Workflow Status](https://github.com/autoresearch/autodoc/actions/workflows/smoke-test.yml/badge.svg)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
911
[![codecov](https://codecov.io/gh/AutoResearch/autodoc/branch/main/graph/badge.svg)](https://codecov.io/gh/AutoResearch/autodoc)
10-
[![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/)
12+
<!-- [![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/) -->
1113

1214
This project was automatically generated using the LINCC-Frameworks
1315
[python-project-template](https://github.com/lincc-frameworks/python-project-template). For more information about the project template see the

azureml/eval.yml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
22
command: >
33
python -m autora.doc.pipelines.main eval
44
${{inputs.data_dir}}/data.jsonl
5-
${{inputs.model_dir}}/llama-2-7b-chat-hf
6-
SYS_1
7-
INSTR_SWEETP_1
5+
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
6+
--sys-id ${{inputs.sys_id}}
7+
--instruc-id ${{inputs.instruc_id}}
8+
--param temperature=${{inputs.temperature}}
9+
--param top_k=${{inputs.top_k}}
10+
--param top_p=${{inputs.top_p}}
811
code: ../src
912
inputs:
1013
data_dir:
@@ -13,6 +16,11 @@ inputs:
1316
model_dir:
1417
type: uri_folder
1518
path: azureml://datastores/workspaceblobstore/paths/base_models
19+
temperature: 0.7
20+
top_p: 0.95
21+
top_k: 40
22+
sys_id: SYS_1
23+
instruc_id: INSTR_SWEETP_1
1624
# using a curated environment doesn't work because we need additional packages
1725
environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
1826
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
@@ -25,6 +33,6 @@ environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11
2533
# image: nvcr.io/nvidia/pytorch:23.10-py3
2634
conda_file: conda.yml
2735
display_name: autodoc_prediction
28-
compute: azureml:v100cluster
29-
experiment_name: autodoc_prediction
36+
compute: azureml:t4cluster
37+
experiment_name: evaluation
3038
description: |

azureml/generate.yml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,26 @@ command: >
33
python -m autora.doc.pipelines.main generate
44
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
55
--output ./outputs/output.txt
6+
--sys-id ${{inputs.sys_id}}
7+
--instruc-id ${{inputs.instruc_id}}
8+
--param temperature=${{inputs.temperature}}
9+
--param top_k=${{inputs.top_k}}
10+
--param top_p=${{inputs.top_p}}
611
autora/doc/pipelines/main.py
712
code: ../src
813
inputs:
914
model_dir:
1015
type: uri_folder
1116
path: azureml://datastores/workspaceblobstore/paths/base_models
17+
temperature: 0.7
18+
top_p: 0.95
19+
top_k: 40
20+
sys_id: SYS_1
21+
instruc_id: INSTR_SWEETP_1
1222
environment:
1323
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
1424
conda_file: conda.yml
1525
display_name: autodoc_prediction
16-
compute: azureml:v100cluster
17-
experiment_name: autodoc_prediction
26+
compute: azureml:t4cluster
27+
experiment_name: prediction
1828
description: |

notebooks/generate.ipynb

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%load_ext autoreload\n",
10+
"%autoreload 2\n",
11+
"from autora.doc.runtime.predict_hf import Predictor\n",
12+
"from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"# model = \"../../models\" # if model has been previously downloaded via huggingface-cli\n",
22+
"model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
23+
"pred = Predictor(model)"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"TEST_CODE = \"\"\"\n",
33+
"from sweetpea import *\n",
34+
"from sweetpea.primitives import *\n",
35+
"\n",
36+
"number_list = [125, 132, 139, 146, 160, 167, 174, 181]\n",
37+
"letter_list = ['b', 'd', 'f', 'h', 's', 'u', 'w', 'y']\n",
38+
"\n",
39+
"number = Factor(\"number\", number_list)\n",
40+
"letter = Factor(\"letter\", letter_list)\n",
41+
"task = Factor(\"task\", [\"number task\", \"letter task\", \"free choice task\"])\n",
42+
"\n",
43+
"\n",
44+
"def is_forced_trial_switch(task):\n",
45+
" return (task[-1] == \"number task\" and task[0] == \"letter task\") or \\\n",
46+
" (task[-1] == \"letter task\" and task[0] == \"number task\")\n",
47+
"\n",
48+
"\n",
49+
"def is_forced_trial_repeat(task):\n",
50+
" return (task[-1] == \"number task\" and task[0] == \"number task\") or \\\n",
51+
" (task[-1] == \"letter task\" and task[0] == \"letter task\")\n",
52+
"\n",
53+
"\n",
54+
"def is_free_trial_transition(task):\n",
55+
" return task[-1] != \"free choice task\" and task[0] == \"free choice task\"\n",
56+
"\n",
57+
"\n",
58+
"def is_free_trial_repeat(task):\n",
59+
" return task[-1] == \"free choice task\" and task[0] == \"free choice task\"\n",
60+
"\n",
61+
"\n",
62+
"def is_not_relevant_transition(task):\n",
63+
" return not (is_forced_trial_repeat(task) or is_forced_trial_switch(task) or is_free_trial_repeat(\n",
64+
" task) or is_free_trial_transition(task))\n",
65+
"\n",
66+
"\n",
67+
"transit = Factor(\"task transition\", [\n",
68+
" DerivedLevel(\"forced switch\", transition(is_forced_trial_switch, [task]), 3),\n",
69+
" DerivedLevel(\"forced repeat\", transition(is_forced_trial_repeat, [task])),\n",
70+
" DerivedLevel(\"free transition\", transition(is_free_trial_transition, [task]), 4),\n",
71+
" DerivedLevel(\"free repeat\", transition(is_free_trial_repeat, [task]), 4),\n",
72+
" DerivedLevel(\"forced first\", transition(is_not_relevant_transition, [task]), 4)\n",
73+
"])\n",
74+
"design = [letter, number, task, transit]\n",
75+
"crossing = [[letter], [number], [transit]]\n",
76+
"constraints = [MinimumTrials(256)]\n",
77+
"\n",
78+
"block = MultiCrossBlock(design, crossing, constraints)\n",
79+
"\n",
80+
"experiment = synthesize_trials(block, 1)\n",
81+
"\n",
82+
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
83+
"\"\"\""
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"output = pred.predict(\n",
93+
" SYS[SystemPrompts.SYS_1],\n",
94+
" INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n",
95+
" [TEST_CODE],\n",
96+
" temperature=0.05,\n",
97+
" top_k=10,\n",
98+
" num_ret_seq=3,\n",
99+
")[0]\n",
100+
"for i, o in enumerate(output):\n",
101+
" print(f\"******** Output {i} ********\\n{o}*************\\n\")"
102+
]
103+
}
104+
],
105+
"metadata": {
106+
"kernelspec": {
107+
"display_name": "autodoc",
108+
"language": "python",
109+
"name": "python3"
110+
},
111+
"language_info": {
112+
"codemirror_mode": {
113+
"name": "ipython",
114+
"version": 3
115+
},
116+
"file_extension": ".py",
117+
"mimetype": "text/x-python",
118+
"name": "python",
119+
"nbconvert_exporter": "python",
120+
"pygments_lexer": "ipython3",
121+
"version": "3.8.18"
122+
}
123+
},
124+
"nbformat": 4,
125+
"nbformat_minor": 2
126+
}

src/autora/doc/pipelines/main.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import logging
23
from timeit import default_timer as timer
34
from typing import List
@@ -16,13 +17,24 @@
1617
logger = logging.getLogger(__name__)
1718

1819

19-
@app.command()
20-
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
20+
@app.command(help="Evaluate model on a data file")
21+
def eval(
22+
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
23+
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
24+
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
25+
instruc_id: InstructionPrompts = typer.Option(
26+
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
27+
),
28+
param: List[str] = typer.Option(
29+
[], help="Additional float parameters to pass to the model as name=float pairs"
30+
),
31+
) -> List[List[str]]:
2132
import jsonlines
2233
import mlflow
2334

2435
mlflow.autolog()
2536

37+
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
2638
run = mlflow.active_run()
2739

2840
sys_prompt = SYS[sys_id]
@@ -33,6 +45,7 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins
3345
logger.info(f"Active run_id: {run.info.run_id}")
3446
logger.info(f"running predict with {data_file}")
3547
logger.info(f"model path: {model_path}")
48+
mlflow.log_params(param_dict)
3649

3750
with jsonlines.open(data_file) as reader:
3851
items = [item for item in reader]
@@ -41,16 +54,19 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins
4154

4255
pred = Predictor(model_path)
4356
timer_start = timer()
44-
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
57+
predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
4558
timer_end = timer()
4659
pred_time = timer_end - timer_start
4760
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
4861
for i in range(len(inputs)):
4962
mlflow.log_text(labels[i], f"label_{i}.txt")
5063
mlflow.log_text(inputs[i], f"input_{i}.py")
51-
mlflow.log_text(predictions[i], f"prediction_{i}.txt")
64+
for j in range(len(predictions[i])):
65+
mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
5266

53-
tokens = pred.tokenize(predictions)["input_ids"]
67+
# flatten predictions for counting tokens
68+
predictions_flat = list(itertools.chain.from_iterable(predictions))
69+
tokens = pred.tokenize(predictions_flat)["input_ids"]
5470
total_tokens = sum([len(token) for token in tokens])
5571
mlflow.log_metric("total_tokens", total_tokens)
5672
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
@@ -59,18 +75,28 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins
5975

6076
@app.command()
6177
def generate(
62-
python_file: str,
63-
model_path: str = "meta-llama/llama-2-7b-chat-hf",
64-
output: str = "output.txt",
65-
sys_id: SystemPrompts = SystemPrompts.SYS_1,
66-
instruc_id: InstructionPrompts = InstructionPrompts.INSTR_SWEETP_1,
78+
python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
79+
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
80+
output: str = typer.Option("output.txt", help="Output file"),
81+
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
82+
instruc_id: InstructionPrompts = typer.Option(
83+
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
84+
),
85+
param: List[str] = typer.Option(
86+
[], help="Additional float parameters to pass to the model as name=float pairs"
87+
),
6788
) -> None:
89+
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
90+
"""
91+
Generate documentation from python file
92+
"""
6893
with open(python_file, "r") as f:
69-
inputs = [f.read()]
94+
input = f.read()
7095
sys_prompt = SYS[sys_id]
7196
instr_prompt = INSTR[instruc_id]
7297
pred = Predictor(model_path)
73-
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
98+
# grab first result since we only passed one input
99+
predictions = pred.predict(sys_prompt, instr_prompt, [input], **param_dict)[0]
74100
assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
75101
logger.info(f"Writing output to {output}")
76102
with open(output, "w") as f:

src/autora/doc/runtime/predict_hf.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,36 @@ def __init__(self, model_path: str):
2727
tokenizer=self.tokenizer,
2828
)
2929

30-
def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
31-
logger.info(f"Generating {len(inputs)} predictions")
30+
def predict(
31+
self,
32+
sys: str,
33+
instr: str,
34+
inputs: List[str],
35+
temperature: float = 0.6,
36+
top_p: float = 0.95,
37+
top_k: float = 40,
38+
max_length: float = 2048,
39+
num_ret_seq: float = 1,
40+
) -> List[List[str]]:
41+
logger.info(
42+
f"Generating {len(inputs)} predictions. Temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, "
43+
f"max_length: {max_length}"
44+
)
3245
prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
33-
# TODO: Make these parameters configurable
3446
sequences = self.pipeline(
3547
prompts,
3648
do_sample=True,
37-
temperature=0.6,
38-
top_p=0.95,
39-
top_k=40,
40-
num_return_sequences=1,
49+
temperature=temperature,
50+
top_p=top_p,
51+
top_k=int(top_k),
52+
num_return_sequences=int(num_ret_seq),
4153
eos_token_id=self.tokenizer.eos_token_id,
42-
max_length=2048,
54+
max_length=int(max_length),
4355
)
4456

45-
results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]
57+
results = [
58+
[Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences
59+
]
4660
logger.info(f"Generated {len(results)} results")
4761
return results
4862

0 commit comments

Comments
 (0)