
Commit 3a8e24f

Merge pull request #38 from code-dot-org/ceara/AITT-408-accuracy-threshold

Ceara/aitt 408 accuracy threshold

2 parents b3dc1df + 0c29478

8 files changed: +152 −17 lines changed

TESTING.md

Lines changed: 17 additions & 1 deletion
@@ -2,10 +2,11 @@
 
 ## Unit Tests
 
-The `./tests` directory contains two categories of test:
+The `./tests` directory contains three categories of test:
 
 * `unit`: Unit tests for library functions in the `./lib` path.
 * `routes`: Tests routes and their helpers in the `./src` as a unit.
+* `accuracy`: Tests accuracy against thresholds by calling OpenAI. Not run by default.
 
 All tests are using [pytest](https://docs.pytest.org/en/7.4.x/).
 
@@ -26,6 +27,21 @@ just run `pytest` within a running container's shell session by using the
 PYTHONPATH=/app pytest
 ```
 
+## Accuracy Tests
+
+**Running the Accuracy test hits the OpenAI endpoint and is expensive! Only run this test infrequently**
+
+To run the accuracy threshold test, follow directions in `README.md` to set up your local
+environment for running the Rubric Tester. You can then run `./bin/test_accuracy.sh` to run
+tests locally, including the accuracy threshold test.
+
+You can pass any arguments to pytest with this script. For instance, the `-k` argument can filter tests by name:
+
+```
+# Run only tests with 'accuracy' in the name:
+./bin/test_accuracy.sh -k accuracy
+```
+
 ## Scripted
 
 This assumes you have built and are running the container as depicted in the main `README.md`.

bin/test_accuracy.sh

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+set -eu
+
+echo "Running: \`coverage run -m pytest --accuracy $@ && coverage report -m\`"
+coverage run -m pytest --accuracy $@ && coverage report -m

lib/assessment/config.py

Lines changed: 3 additions & 3 deletions
@@ -3,11 +3,11 @@
 SUPPORTED_MODELS = ['gpt-4-0314', 'gpt-4-32k-0314', 'gpt-4-0613', 'gpt-4-32k-0613', 'gpt-4-1106-preview']
 DEFAULT_MODEL = 'gpt-4-0613'
 LESSONS = {
-    "U3-2022-L10" : "1ROCbvHb3yWGVoQqzKAjwdaF0dSRPUjy_",
-    "U3-2022-L13" : "1kGHeY5LRpFJ9xVRoBEWbyOJyKm4wClqw",
+    # "U3-2022-L10" : "1ROCbvHb3yWGVoQqzKAjwdaF0dSRPUjy_",
+    # "U3-2022-L13" : "1kGHeY5LRpFJ9xVRoBEWbyOJyKm4wClqw",
     "U3-2022-L17" : "1WirJLIFgo-anxAz-kZXDVQ2Tl_8OuX22",
     "U3-2022-L20" : "115BHvZ1kJC2xhUSOBkLiE8DC1YgcjyRd",
-    "U3-2022-L23" : "12OJex4l9OhWrnbLenpvZAibtfiFWWdzx",
+    # "U3-2022-L23" : "12OJex4l9OhWrnbLenpvZAibtfiFWWdzx",
     "New-U3-2022-L10" : "15xAUFVeGkXeG18mDWBOKN6yJPpI185tg",
     "New-U3-2022-L13" : "14LI9eRRgxL5rRQK6FoUI0ow_YIb5V0mg",
 }
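
The ids in `LESSONS` are Google Drive folder ids; `rubric_tester.py` passes them straight to `gdown` when a lesson directory is missing (see the `try`/`except` added in `main()` below). A minimal sketch of that download step, using one of the ids above; the output path mirrors `base_dir` in `rubric_tester.py`:

```
# Sketch of how a LESSONS entry is consumed; mirrors the download step
# in rubric_tester.main(). Folder id copied from LESSONS above.
import gdown

lesson, folder_id = "U3-2022-L17", "1WirJLIFgo-anxAz-kZXDVQ2Tl_8OuX22"
gdown.download_folder(id=folder_id, output=f"lesson_data/{lesson}")
```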

lib/assessment/rubric_tester.py

Lines changed: 55 additions & 5 deletions
@@ -11,6 +11,7 @@
 import io
 import logging
 import gdown
+import pprint
 
 from sklearn.metrics import accuracy_score, confusion_matrix
 from collections import defaultdict
@@ -27,6 +28,10 @@
 output_dir_name = 'output'
 base_dir = 'lesson_data'
 cache_dir_name = 'cached_responses'
+accuracy_threshold_file = 'accuracy_thresholds.json'
+accuracy_threshold_dir = 'tests/data'
+
+pp = pprint.PrettyPrinter(indent=2)
 
 def command_line_options():
     parser = argparse.ArgumentParser(description='Usage')
@@ -51,6 +56,8 @@ def command_line_options():
                         help='Temperature of the LLM. Defaults to 0.0.')
     parser.add_argument('-d', '--download', action='store_true',
                         help='re-download lesson files, overwriting previous files')
+    parser.add_argument('-a', '--accuracy', action='store_true',
+                        help='Run against accuracy thresholds')
 
     args = parser.parse_args()
 
@@ -107,6 +114,13 @@ def get_actual_labels(actual_labels_file, prefix):
         actual_labels[student_id] = dict(row)
     return actual_labels
 
+def get_accuracy_thresholds(accuracy_threshold_file=accuracy_threshold_file, prefix=accuracy_threshold_dir):
+    thresholds = None
+    if os.path.exists(os.path.join(prefix, accuracy_threshold_file)):
+        with open(os.path.join(prefix, accuracy_threshold_file), 'r') as f:
+            thresholds = json.load(f)
+    return thresholds
+
 
 def get_examples(prefix):
     example_js_files = sorted(glob.glob(os.path.join(prefix, 'examples', '*.js')))
@@ -167,11 +181,11 @@ def compute_accuracy(actual_labels, predicted_labels, passing_labels):
         actual = actual_by_criteria[criteria]
 
         confusion_by_criteria[criteria] = confusion_matrix(actual, predicted, labels=label_names)
-        accuracy_by_criteria[criteria] = accuracy_score(actual, predicted) * 100
+        accuracy_by_criteria[criteria] = accuracy_score(actual, predicted)
         overall_predicted.extend(predicted)
         overall_actual.extend(actual)
 
-    overall_accuracy = accuracy_score(overall_actual, overall_predicted) * 100
+    overall_accuracy = accuracy_score(overall_actual, overall_predicted)
     overall_confusion = confusion_matrix(overall_actual, overall_predicted, labels=label_names)
 
     return accuracy_by_criteria, overall_accuracy, confusion_by_criteria, overall_confusion, label_names
@@ -205,13 +219,25 @@ def main():
     command_line = " ".join(os.sys.argv)
     options = command_line_options()
     main_start_time = time.time()
+    accuracy_failures = {}
+    accuracy_pass = True
+    accuracy_thresholds = None
+
+    print(options)
+
+    if options.accuracy:
+        accuracy_thresholds = get_accuracy_thresholds()
 
     for lesson in options.lesson_names:
         prefix = os.path.join(base_dir, lesson)
 
         # download lesson files
         if not os.path.exists(prefix) or options.download:
-            gdown.download_folder(id=LESSONS[lesson], output=prefix)
+            try:
+                gdown.download_folder(id=LESSONS[lesson], output=prefix)
+            except Exception as e:
+                print(f"Could not download lesson {lesson}")
+                logging.error(e)
 
         # read in lesson files, validate them
         prompt, standard_rubric = read_inputs(prompt_file, standard_rubric_file, prefix)
@@ -244,16 +270,18 @@ def main():
 
         # calculate accuracy and generate report
         accuracy_by_criteria, overall_accuracy, confusion_by_criteria, overall_confusion, label_names = compute_accuracy(actual_labels, predicted_labels, options.passing_labels)
+        overall_accuracy_percent = overall_accuracy * 100
+        accuracy_by_criteria_percent = {k:v*100 for k,v in accuracy_by_criteria.items()}
         report = Report()
        report.generate_html_output(
            output_file,
            prompt,
            rubric,
-            accuracy=overall_accuracy,
+            accuracy=overall_accuracy_percent,
            predicted_labels=predicted_labels,
            actual_labels=actual_labels,
            passing_labels=options.passing_labels,
-            accuracy_by_criteria=accuracy_by_criteria,
+            accuracy_by_criteria=accuracy_by_criteria_percent,
            errors=errors,
            command_line=command_line,
            confusion_by_criteria=confusion_by_criteria,
@@ -263,8 +291,30 @@ def main():
         )
         logging.info(f"main finished in {int(time.time() - main_start_time)} seconds")
 
+        if options.accuracy and accuracy_thresholds is not None:
+            if overall_accuracy < accuracy_thresholds[lesson]['overall']:
+                accuracy_pass = False
+                accuracy_failures[lesson] = {}
+                accuracy_failures[lesson]['overall'] = {}
+                accuracy_failures[lesson]['overall']['accuracy_score'] = overall_accuracy
+                accuracy_failures[lesson]['overall']['threshold'] = accuracy_thresholds[lesson]['overall']
+            for key_concept in accuracy_by_criteria:
+                if accuracy_by_criteria[key_concept] < accuracy_thresholds[lesson]['key_concepts'][key_concept]:
+                    accuracy_pass = False
+                    if lesson not in accuracy_failures.keys(): accuracy_failures[lesson] = {}
+                    if 'key_concepts' not in accuracy_failures[lesson].keys(): accuracy_failures[lesson]['key_concepts'] = {}
+                    if key_concept not in accuracy_failures[lesson]['key_concepts'].keys(): accuracy_failures[lesson]['key_concepts'][key_concept] = {}
+                    accuracy_failures[lesson]['key_concepts'][key_concept]['accuracy_score'] = accuracy_by_criteria[key_concept]
+                    accuracy_failures[lesson]['key_concepts'][key_concept]['threshold'] = accuracy_thresholds[lesson]['key_concepts'][key_concept]
+
         os.system(f"open {output_file}")
 
+    if not accuracy_pass and len(accuracy_failures.keys()) > 0:
+        logging.error(f"The following thresholds were not met:\n{pp.pformat(accuracy_failures)}")
+    print(("PASS" if accuracy_pass else "FAIL"))
+
+    return accuracy_pass
+
 
 def init():
     if __name__ == '__main__':
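
The threshold bookkeeping in `main()` builds `accuracy_failures` with repeated key-existence checks. A behavior-equivalent sketch using `collections.defaultdict` (already imported at the top of this file); the lesson name, scores, and thresholds below are invented for illustration:

```
from collections import defaultdict

# Hypothetical stand-ins for compute_accuracy() output and the loaded
# accuracy_thresholds.json entry for one lesson.
lesson = 'U3-2022-L20'
overall_accuracy = 0.62
accuracy_by_criteria = {'Variables': 0.55, 'Program Development 2': 0.9}
thresholds = {'overall': 0.7,
              'key_concepts': {'Variables': 0.7, 'Program Development 2': 0.7}}

accuracy_pass = True
accuracy_failures = defaultdict(lambda: defaultdict(dict))

if overall_accuracy < thresholds['overall']:
    accuracy_pass = False
    accuracy_failures[lesson]['overall'] = {
        'accuracy_score': overall_accuracy,
        'threshold': thresholds['overall'],
    }
for key_concept, score in accuracy_by_criteria.items():
    if score < thresholds['key_concepts'][key_concept]:
        accuracy_pass = False
        accuracy_failures[lesson]['key_concepts'][key_concept] = {
            'accuracy_score': score,
            'threshold': thresholds['key_concepts'][key_concept],
        }

print('PASS' if accuracy_pass else 'FAIL')  # FAIL: overall and 'Variables' miss 0.7
```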

tests/accuracy/conftest.py

Whitespace-only changes.

tests/accuracy/test_accuracy.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import pytest
+import os
+
+from unittest import mock
+
+from lib.assessment.rubric_tester import (
+    main,
+)
+
+accuracy = pytest.mark.skipif("not config.getoption('accuracy')")
+
+@accuracy
+@pytest.mark.accuracy_setup
+class TestAccuracy:
+    def test_accuracy(self):
+        assert "OPENAI_API_KEY" in os.environ
+        with mock.patch('sys.argv', ['rubric_tester.py', '-a']):
+            ret = main()
+        assert ret == True
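
`pytest.mark.skipif` with a string condition is evaluated by pytest with `config` in scope, so `TestAccuracy` is collected but skipped unless `--accuracy` was passed (the option registered in `tests/conftest.py` below). An equivalent, more explicit gate could use the standard `pytest_collection_modifyitems` hook; a sketch, assuming it lived in `tests/conftest.py`:

```
# Sketch of explicit gating equivalent to the string-condition skipif
# above; pytest_collection_modifyitems is a standard pytest hook.
import pytest

def pytest_collection_modifyitems(config, items):
    if config.getoption('accuracy'):
        return  # --accuracy given: leave accuracy tests runnable
    skip_marker = pytest.mark.skip(reason='need --accuracy option to run')
    for item in items:
        if 'accuracy_setup' in item.keywords:
            item.add_marker(skip_marker)
```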

tests/conftest.py

Lines changed: 17 additions & 8 deletions
@@ -9,7 +9,15 @@
 import contextlib
 import os
 
-
+def pytest_addoption(parser):
+    parser.addoption('--accuracy', action='store_true', dest="accuracy",
+                     default=False, help="enable accuracy tests that run openai")
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "accuracy_setup"
+    )
+
 @pytest.fixture()
 def app():
     app = create_app()
@@ -37,17 +45,18 @@ def configured_app():
 
     # clean up / reset resources here
 
-
 @pytest.fixture(autouse=True)
-def mock_env_vars():
+def mock_env_vars(request):
     """ Ensures env vars are not touched by tests.
     """
-
-    from unittest.mock import patch
-
-    # Ensure the os.environ passes out a new dictionary
-    with patch.dict(os.environ, {}, clear=True):
+    if 'accuracy_setup' in request.keywords:
         yield
+    else:
+        from unittest.mock import patch
+        print("no env vars")
+        # Ensure the os.environ passes out a new dictionary
+        with patch.dict(os.environ, {}, clear=True):
+            yield
 
 
 @pytest.fixture()
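
With this change, the autouse `mock_env_vars` fixture consults `request.keywords`: tests marked `accuracy_setup` see the real environment (so `OPENAI_API_KEY` survives for the accuracy run), while every other test still gets `os.environ` patched to an empty dict. A pair of hypothetical tests illustrating the two branches:

```
# Hypothetical tests showing both branches of mock_env_vars.
import os
import pytest

def test_env_cleared_by_default():
    # No marker: the autouse fixture patches os.environ to {}.
    assert 'OPENAI_API_KEY' not in os.environ

@pytest.mark.accuracy_setup
def test_env_passes_through():
    # Marker present: the fixture yields without patching, so real
    # variables (e.g. PATH on most systems) remain visible.
    assert 'PATH' in os.environ
```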

tests/data/accuracy_thresholds.json

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+{
+  "U3-2022-L17": {
+    "overall": 0.7,
+    "key_concepts": {
+      "Algorithms and Control - Conditionals": 0.7,
+      "Algorithms and Control - User Input": 0.7,
+      "Modularity - Multiple Sprites": 0.7,
+      "Position and Movement": 0.7
+    }
+  },
+  "U3-2022-L20": {
+    "overall": 0.7,
+    "key_concepts": {
+      "Algorithms and Control Structures": 0.7,
+      "Program Development 2": 0.7,
+      "Variables": 0.7
+    }
+  },
+  "New-U3-2022-L10": {
+    "overall": 0.7,
+    "key_concepts": {
+      "Modularity - Sprites and Sprite Properties": 0.7,
+      "Position - Elements and the Coordinate System": 0.7,
+      "Program Development - Program Sequence": 0.7
+    }
+  },
+  "New-U3-2022-L13": {
+    "overall": 0.7,
+    "key_concepts": {
+      "Modularity - Sprites and Sprite Properties": 0.7,
+      "Optional \u201cStretch\u201d Feature - Variables": 0.7,
+      "Position and Movement": 0.7
+    }
+  }
+}
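
Each lesson entry pairs an `overall` threshold with per-key-concept thresholds, all on the 0-1 scale that `compute_accuracy` now returns. `get_accuracy_thresholds()` loads this file from `tests/data`, and `main()` compares each score against it. A short sketch of that comparison for one lesson; the sample accuracy values are invented:

```
import json

# Load the thresholds the same way get_accuracy_thresholds() does.
with open('tests/data/accuracy_thresholds.json') as f:
    thresholds = json.load(f)

lesson = 'U3-2022-L17'
overall_accuracy = 0.75                                 # hypothetical score
accuracy_by_criteria = {'Position and Movement': 0.65}  # hypothetical score

passed = overall_accuracy >= thresholds[lesson]['overall'] and all(
    score >= thresholds[lesson]['key_concepts'][concept]
    for concept, score in accuracy_by_criteria.items()
)
print('PASS' if passed else 'FAIL')  # FAIL: 'Position and Movement' < 0.7
```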
