Commit 61d2480

Add metrics calculations to the inference pipeline (#23)
1 parent fdcce7e · commit 61d2480

6 files changed: +96 -3 lines


.mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -8,3 +8,6 @@ ignore_missing_imports = True
 
 [mypy-mlflow.*]
 ignore_missing_imports = True
+
+[mypy-nltk.*]
+ignore_missing_imports = True
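Note: nltk ships without type stubs, so this override mirrors the existing mlflow entry; without it, a strict mypy run would typically flag the new import in the pipeline (the exact diagnostic text depends on the mypy version). A minimal illustration:

# src/autora/doc/pipelines/main.py
# With the [mypy-nltk.*] override above, mypy no longer reports a missing-stubs/import error here.
import nltk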

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ repos:
         # supported by your project here, or alternatively use
         # pre-commit's default_language_version, see
         # https://pre-commit.com/#top_level-default_language_version
-        language_version: python3.10
+        language_version: python3
 
 
 

azureml/conda.yml

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ dependencies:
   - transformers>=4.35.2
   - xformers
   - scipy
+  - nltk
   # This works, while installing from pytorch and cuda from conda does not
   - torch==2.0.1
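Note: installing the nltk package alone is not enough for the METEOR metric; the WordNet corpus it relies on for synonym matching is fetched at runtime, which is what the pipeline change below does. A minimal sketch:

import nltk

# single_meteor_score consults WordNet, so the corpus must be available at runtime;
# the pipeline downloads it on first use.
nltk.download("wordnet")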

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies = [
     # This works, while installing from pytorch and cuda from conda does not",
     "torch==2.0.1",
     "transformers>=4.35.2",
+    "nltk",
 ]
 
 # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes)

src/autora/doc/pipelines/main.py

Lines changed: 37 additions & 1 deletion
@@ -1,10 +1,13 @@
 import itertools
 import logging
 from timeit import default_timer as timer
-from typing import List
+from typing import List, Tuple
 
+import nltk
 import torch
 import typer
+from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
+from nltk.translate.meteor_score import single_meteor_score
 
 from autora.doc.runtime.predict_hf import Predictor
 from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts
@@ -17,6 +20,33 @@
 logger = logging.getLogger(__name__)
 
 
+def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
+    nltk.download("wordnet")
+
+    # Tokenize references
+    tokenized_references = [ref.split() for ref in references]
+    # Currently there is only 1 prediction for 1 reference, need to avg in future
+    tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
+
+    # Calculate BLEU score with smoothing function
+    # SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
+    bleu = corpus_bleu(
+        # Wrap each reference list in another list
+        [[tokenized_ref] for tokenized_ref in tokenized_references],
+        tokenized_predictions,
+        smoothing_function=SmoothingFunction().method1,
+    )
+
+    # Calculate METEOR scores
+    meteor_scores = [
+        single_meteor_score(tokenized_ref, tokenized_pred)
+        for tokenized_ref, tokenized_pred in zip(tokenized_references, tokenized_predictions)
+    ]
+    meteor = sum(meteor_scores) / len(predictions) if predictions else 0
+
+    return (bleu, meteor)
+
+
 @app.command(help="Evaluate model on a data file")
 def eval(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
@@ -55,6 +85,8 @@ def eval(
     pred = Predictor(model_path)
     timer_start = timer()
     predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
+    bleu, meteor = evaluate_documentation(predictions, labels)
+
     timer_end = timer()
     pred_time = timer_end - timer_start
     mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
@@ -63,13 +95,17 @@ def eval(
         mlflow.log_text(inputs[i], f"input_{i}.py")
         for j in range(len(predictions[i])):
             mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
+        mlflow.log_text("bleu_score is ", str(bleu))
+        mlflow.log_text("meteor_score is ", str(meteor))
 
     # flatten predictions for counting tokens
     predictions_flat = list(itertools.chain.from_iterable(predictions))
     tokens = pred.tokenize(predictions_flat)["input_ids"]
     total_tokens = sum([len(token) for token in tokens])
     mlflow.log_metric("total_tokens", total_tokens)
     mlflow.log_metric("tokens/sec", total_tokens / pred_time)
+    mlflow.log_metric("bleu_score", round(bleu, 5))
+    mlflow.log_metric("meteor_score", round(meteor, 5))
     return predictions
 
 
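For reference, a minimal usage sketch of the new helper (the input strings below are hypothetical and not taken from the repository's data files):

from autora.doc.pipelines.main import evaluate_documentation

references = ["create a two-level factorial design"]     # ground-truth documentation
predictions = [["create a two-level factorial design"]]  # one candidate list per reference

bleu, meteor = evaluate_documentation(predictions, references)
print(f"bleu={bleu:.3f} meteor={meteor:.3f}")  # identical text scores close to 1.0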

tests/test_main.py

Lines changed: 53 additions & 1 deletion
@@ -1,6 +1,9 @@
 from pathlib import Path
 
-from autora.doc.pipelines.main import eval, generate, import_data
+import jsonlines
+import pytest
+
+from autora.doc.pipelines.main import eval, evaluate_documentation, generate, import_data
 from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts
 
 # dummy HF model for testing
@@ -15,6 +18,55 @@ def test_predict() -> None:
     assert len(output[0]) > 0, "Expected non-empty output"
 
 
+def test_evaluation() -> None:
+    # Test case: METEOR and BLEU scores are close to 1 when predictions match the references exactly
+    data = Path(__file__).parent.joinpath("../data/sweetpea/data.jsonl").resolve()
+    with jsonlines.open(data) as reader:
+        items = [item for item in reader]
+        labels = [item["output"] for item in items]
+        predictions = [[item["output"]] for item in items]
+
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert bleu == pytest.approx(1, 0.01), f"BLEU Score is {bleu}"
+    assert meteor == pytest.approx(1, 0.01), f"METEOR Score is {meteor}"
+
+
+def test_extra_token_in_prediction() -> None:
+    # Test case: the unmatched extra token lowers the BLEU score; METEOR is robust to small mistakes
+    labels = ["this is a test"]
+    predictions = [["this is a test extra"]]
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert 0.6 <= bleu <= 0.8, f"BLEU Score is {bleu}"
+    assert 0.8 <= meteor <= 1, f"METEOR Score is {meteor}"
+
+
+def test_missing_token_in_prediction() -> None:
+    # BLEU score is lower; METEOR stays higher
+    labels = ["this is a test"]
+    predictions = [["this is a"]]
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert 0.4 <= bleu <= 0.6, f"BLEU Score is {bleu}"
+    assert 0.6 <= meteor <= 0.8, f"METEOR Score is {meteor}"
+
+
+def test_completely_different_tokens() -> None:
+    # Both scores are low, as there are no common tokens
+    labels = ["this is a test"]
+    predictions = [["completely different sentence"]]
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert bleu <= 0.1, f"BLEU Score is {bleu}"
+    assert meteor <= 0.1, f"METEOR Score is {meteor}"
+
+
+def test_partially_matching_tokens() -> None:
+    # The extra token in the middle breaks the higher-order n-grams, so BLEU drops sharply; METEOR still scores well
+    labels = ["this is a test"]
+    predictions = [["this is a different test"]]
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert 0.25 <= bleu <= 0.4, f"BLEU Score is {bleu}"
+    assert 0.8 <= meteor <= 0.95, f"METEOR Score is {meteor}"
+
+
 def test_generate() -> None:
     python_file = __file__
     output = Path("output.txt")
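As a sanity check on the 0.6-0.8 window in test_extra_token_in_prediction (assuming NLTK's default 4-gram BLEU with uniform weights): against the reference "this is a test", the candidate "this is a test extra" has modified n-gram precisions of 4/5, 3/4, 2/3 and 1/2, and no brevity penalty applies because the candidate is longer than the reference, so the expected score is roughly the geometric mean of those precisions:

# rough hand computation of the expected BLEU for the extra-token test case
expected_bleu = (4 / 5 * 3 / 4 * 2 / 3 * 1 / 2) ** 0.25
print(round(expected_bleu, 3))  # ~0.669, inside the asserted [0.6, 0.8] range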
