Skip to content

Commit d1d6d4c

Browse files
authored
Merge pull request #6 from AutoResearch/carlosg/generate
feat: Generate command
2 parents 4f8d900 + 4b40f34 commit d1d6d4c

File tree

8 files changed

+67
-20
lines changed

8 files changed

+67
-20
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ az storage blob upload --account-name <account> --container <container>> --file
7878

7979
Prediction
8080
```sh
81-
az ml job create -f azureml/predict.yml --set display_name="Test prediction job" --web
81+
az ml job create -f azureml/eval.yml --set display_name="Test prediction job" --web
8282
```
8383

8484
Notes:

azureml/predict.yml renamed to azureml/eval.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
22
command: >
3-
python -m autora.doc.pipelines.main predict
3+
python -m autora.doc.pipelines.main eval
44
${{inputs.data_dir}}/data.jsonl
55
${{inputs.model_dir}}/llama-2-7b-chat-hf
66
SYS_1

azureml/generate.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
2+
command: >
3+
python -m autora.doc.pipelines.main generate
4+
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
5+
--output ./outputs/output.txt
6+
autora/doc/pipelines/main.py
7+
code: ../src
8+
inputs:
9+
model_dir:
10+
type: uri_folder
11+
path: azureml://datastores/workspaceblobstore/paths/base_models
12+
environment:
13+
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
14+
conda_file: conda.yml
15+
display_name: autodoc_prediction
16+
compute: azureml:v100cluster
17+
experiment_name: autodoc_prediction
18+
description: |

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ classifiers = [
1616
]
1717
dynamic = ["version"]
1818
dependencies = [
19-
"transformers>=4.35.2",
2019
"typer",
2120
"scipy",
2221
# This works, while installing from pytorch and cuda from conda does not",
@@ -42,17 +41,18 @@ dev = [
4241
"nbsphinx", # Used to integrate Python notebooks into Sphinx documentation
4342
"ipython", # Also used in building notebooks into Sphinx
4443
"matplotlib", # Used in sample notebook intro_notebook.ipynb
45-
"numpy", # Used in sample notebook intro_notebook.ipynb
4644
"ipykernel",
4745
]
4846
train = [
47+
"jsonlines",
4948
"mlflow",
50-
"azureml-mlflow",
49+
]
50+
azure = [
5151
"azureml-core",
52-
"jsonlines",
52+
"azureml-mlflow",
5353
]
54-
55-
train_cuda = [
54+
cuda = [
55+
"transformers>=4.35.2",
5656
"bitsandbytes>=0.41.2.post2",
5757
"accelerate>=0.24.1",
5858
"xformers",

src/autora/doc/pipelines/main.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from timeit import default_timer as timer
33
from typing import List
44

5-
import jsonlines
6-
import mlflow
75
import torch
86
import typer
97

@@ -19,9 +17,12 @@
1917

2018

2119
@app.command()
22-
def predict(
23-
data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts
24-
) -> List[str]:
20+
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
21+
import jsonlines
22+
import mlflow
23+
24+
mlflow.autolog()
25+
2526
run = mlflow.active_run()
2627

2728
sys_prompt = SYS[sys_id]
@@ -33,7 +34,6 @@ def predict(
3334
logger.info(f"running predict with {data_file}")
3435
logger.info(f"model path: {model_path}")
3536

36-
# predictions = []
3737
with jsonlines.open(data_file) as reader:
3838
items = [item for item in reader]
3939
inputs = [item["instruction"] for item in items]
@@ -57,6 +57,26 @@ def predict(
5757
return predictions
5858

5959

60+
@app.command()
61+
def generate(
62+
python_file: str,
63+
model_path: str = "meta-llama/llama-2-7b-chat-hf",
64+
output: str = "output.txt",
65+
sys_id: SystemPrompts = SystemPrompts.SYS_1,
66+
instruc_id: InstructionPrompts = InstructionPrompts.INSTR_SWEETP_1,
67+
) -> None:
68+
with open(python_file, "r") as f:
69+
inputs = [f.read()]
70+
sys_prompt = SYS[sys_id]
71+
instr_prompt = INSTR[instruc_id]
72+
pred = Predictor(model_path)
73+
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
74+
assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
75+
logger.info(f"Writing output to {output}")
76+
with open(output, "w") as f:
77+
f.write(predictions[0])
78+
79+
6080
@app.command()
6181
def import_model(model_name: str) -> None:
6282
pass
@@ -65,5 +85,4 @@ def import_model(model_name: str) -> None:
6585
if __name__ == "__main__":
6686
logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
6787

68-
mlflow.autolog()
6988
app()

src/autora/doc/runtime/predict_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
3939
top_k=40,
4040
num_return_sequences=1,
4141
eos_token_id=self.tokenizer.eos_token_id,
42-
max_length=1000,
42+
max_length=2048,
4343
)
4444

4545
results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]

src/autora/doc/runtime/prompts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
paragraph should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""
2525

2626

27-
class SystemPrompts(Enum):
27+
class SystemPrompts(str, Enum):
2828
SYS_1 = "SYS_1"
2929

3030

31-
class InstructionPrompts(Enum):
31+
class InstructionPrompts(str, Enum):
3232
INSTR_SWEETP_1 = "INSTR_SWEETP_1"
3333

3434

tests/test_main.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22

3-
from autora.doc.pipelines.main import predict
3+
from autora.doc.pipelines.main import eval, generate
44
from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts
55

66
# dummy HF model for testing
@@ -9,7 +9,17 @@
99

1010
def test_predict() -> None:
1111
data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
12-
outputs = predict(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
12+
outputs = eval(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
1313
assert len(outputs) == 3, "Expected 3 outputs"
1414
for output in outputs:
1515
assert len(output) > 0, "Expected non-empty output"
16+
17+
18+
def test_generate() -> None:
19+
python_file = __file__
20+
output = Path("output.txt")
21+
output.unlink(missing_ok=True)
22+
generate(python_file, TEST_HF_MODEL, str(output), SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
23+
assert output.exists(), f"Expected output file {output} to exist"
24+
with open(str(output), "r") as f:
25+
assert len(f.read()) > 0, f"Expected non-empty output file {output}"

0 commit comments

Comments
 (0)