@@ -11,6 +11,7 @@
import io
import logging
import gdown
+import pprint

from sklearn.metrics import accuracy_score, confusion_matrix
from collections import defaultdict
@@ -27,6 +28,10 @@
output_dir_name = 'output'
base_dir = 'lesson_data'
cache_dir_name = 'cached_responses'
+accuracy_threshold_file = 'accuracy_thresholds.json'
+accuracy_threshold_dir = 'tests/data'
+
+pp = pprint.PrettyPrinter(indent=2)

def command_line_options():
    parser = argparse.ArgumentParser(description='Usage')
@@ -51,6 +56,8 @@ def command_line_options():
                        help='Temperature of the LLM. Defaults to 0.0.')
    parser.add_argument('-d', '--download', action='store_true',
                        help='re-download lesson files, overwriting previous files')
+    parser.add_argument('-a', '--accuracy', action='store_true',
+                        help='Run against accuracy thresholds')

    args = parser.parse_args()
@@ -107,6 +114,13 @@ def get_actual_labels(actual_labels_file, prefix):
        actual_labels[student_id] = dict(row)
    return actual_labels

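+# Load per-lesson accuracy thresholds from a JSON file under tests/data; returns None if the file is absent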
+def get_accuracy_thresholds(accuracy_threshold_file=accuracy_threshold_file, prefix=accuracy_threshold_dir):
+    thresholds = None
+    if os.path.exists(os.path.join(prefix, accuracy_threshold_file)):
+        with open(os.path.join(prefix, accuracy_threshold_file), 'r') as f:
+            thresholds = json.load(f)
+    return thresholds
+

def get_examples(prefix):
    example_js_files = sorted(glob.glob(os.path.join(prefix, 'examples', '*.js')))
@@ -167,11 +181,11 @@ def compute_accuracy(actual_labels, predicted_labels, passing_labels):
        actual = actual_by_criteria[criteria]

        confusion_by_criteria[criteria] = confusion_matrix(actual, predicted, labels=label_names)
-        accuracy_by_criteria[criteria] = accuracy_score(actual, predicted) * 100
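+        # accuracy_score returns a fraction in [0, 1]; percent conversion now happens at report time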
+        accuracy_by_criteria[criteria] = accuracy_score(actual, predicted)
        overall_predicted.extend(predicted)
        overall_actual.extend(actual)

-    overall_accuracy = accuracy_score(overall_actual, overall_predicted) * 100
+    overall_accuracy = accuracy_score(overall_actual, overall_predicted)
    overall_confusion = confusion_matrix(overall_actual, overall_predicted, labels=label_names)

    return accuracy_by_criteria, overall_accuracy, confusion_by_criteria, overall_confusion, label_names
@@ -205,13 +219,25 @@ def main():
    command_line = " ".join(os.sys.argv)
    options = command_line_options()
    main_start_time = time.time()
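+    # State for --accuracy threshold checking, accumulated across all lessons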
+    accuracy_failures = {}
+    accuracy_pass = True
+    accuracy_thresholds = None
+
+    print(options)
+
+    if options.accuracy:
+        accuracy_thresholds = get_accuracy_thresholds()

    for lesson in options.lesson_names:
        prefix = os.path.join(base_dir, lesson)

        # download lesson files
        if not os.path.exists(prefix) or options.download:
-            gdown.download_folder(id=LESSONS[lesson], output=prefix)
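+            # Log download failures rather than aborting the whole run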
+            try:
+                gdown.download_folder(id=LESSONS[lesson], output=prefix)
+            except Exception as e:
+                print(f"Could not download lesson {lesson}")
+                logging.error(e)

        # read in lesson files, validate them
        prompt, standard_rubric = read_inputs(prompt_file, standard_rubric_file, prefix)
@@ -244,16 +270,18 @@ def main():

        # calculate accuracy and generate report
        accuracy_by_criteria, overall_accuracy, confusion_by_criteria, overall_confusion, label_names = compute_accuracy(actual_labels, predicted_labels, options.passing_labels)
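+        # The HTML report displays percentages; the threshold checks below use the raw fractions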
+        overall_accuracy_percent = overall_accuracy * 100
+        accuracy_by_criteria_percent = {k: v * 100 for k, v in accuracy_by_criteria.items()}

        report = Report()
        report.generate_html_output(
            output_file,
            prompt,
            rubric,
-            accuracy=overall_accuracy,
+            accuracy=overall_accuracy_percent,
            predicted_labels=predicted_labels,
            actual_labels=actual_labels,
            passing_labels=options.passing_labels,
-            accuracy_by_criteria=accuracy_by_criteria,
+            accuracy_by_criteria=accuracy_by_criteria_percent,
            errors=errors,
            command_line=command_line,
            confusion_by_criteria=confusion_by_criteria,
@@ -263,8 +291,30 @@ def main():
        )
        logging.info(f"main finished in {int(time.time() - main_start_time)} seconds")

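+        # Compare this lesson's overall and per-concept accuracy against its thresholds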
+        if options.accuracy and accuracy_thresholds is not None:
+            if overall_accuracy < accuracy_thresholds[lesson]['overall']:
+                accuracy_pass = False
+                accuracy_failures[lesson] = {}
+                accuracy_failures[lesson]['overall'] = {}
+                accuracy_failures[lesson]['overall']['accuracy_score'] = overall_accuracy
+                accuracy_failures[lesson]['overall']['threshold'] = accuracy_thresholds[lesson]['overall']
+            for key_concept in accuracy_by_criteria:
+                if accuracy_by_criteria[key_concept] < accuracy_thresholds[lesson]['key_concepts'][key_concept]:
+                    accuracy_pass = False
+                    if lesson not in accuracy_failures: accuracy_failures[lesson] = {}
+                    if 'key_concepts' not in accuracy_failures[lesson]: accuracy_failures[lesson]['key_concepts'] = {}
+                    if key_concept not in accuracy_failures[lesson]['key_concepts']: accuracy_failures[lesson]['key_concepts'][key_concept] = {}
+                    accuracy_failures[lesson]['key_concepts'][key_concept]['accuracy_score'] = accuracy_by_criteria[key_concept]
+                    accuracy_failures[lesson]['key_concepts'][key_concept]['threshold'] = accuracy_thresholds[lesson]['key_concepts'][key_concept]
+
        os.system(f"open {output_file}")

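+    # After all lessons: log any threshold misses and report overall PASS/FAIL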
+    if not accuracy_pass and len(accuracy_failures) > 0:
+        logging.error(f"The following thresholds were not met:\n{pp.pformat(accuracy_failures)}")
+    print("PASS" if accuracy_pass else "FAIL")
+
+    return accuracy_pass
+

def init():
if __name__ == '__main__':