Skip to content

Commit 21395da

Browse files
Shocker444 and tmattio authored
Implement AUC-ROC in Kaun metrics (#124)
Co-authored-by: Thibaut Mattio <thibaut.mattio@gmail.com>
1 parent e0b9882 commit 21395da

File tree

4 files changed

+132
-11
lines changed

4 files changed

+132
-11
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ All notable changes to this project will be documented in this file.
5454
- Allow metric history to tolerate metrics that appear or disappear between epochs so dynamic metric sets no longer raise during training (@tmattio)
5555
- Make `Optimizer.clip_by_global_norm` robust to zero gradients and empty parameter trees to avoid NaNs during training (@tmattio)
5656
- Split CSV loader into `from_csv` and `from_csv_with_labels` to retain labels when requested (#114, @Satarupa22-SD)
57+
- Implement AUC-ROC in Kaun metrics and simplify its API (#109, @Shocker444)
5758

5859
### Talon
5960

kaun/lib/kaun/metrics.ml

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,16 +296,100 @@ let f1_score ?(threshold = 0.5) ?(averaging = Micro) ?(beta = 1.0) () =
296296
~reset:(fun _ -> [])
297297

298298
(* Placeholder implementations for complex metrics *)
299-
let auc_roc ?(num_thresholds = 200) ?(curve = false) () =
300-
let _ = num_thresholds in
301-
let _ = curve in
299+
(* [auc_roc ()] builds a custom metric named "auc_roc" via [create_custom].

   State is a list of tensors: either empty (initial / after reset) or
   exactly [predictions; targets; weights], each flattened to 1-D and
   concatenated across every [update] call.

   [compute] sorts the accumulated samples by descending prediction, builds
   cumulative weighted true-positive / false-positive rates starting from
   (0, 0), and integrates TPR over FPR with the trapezoid rule.

   Raises [Failure "Invalid auc_roc state"] from [update] or [compute] when
   the state is any other shape of list; in [compute] this includes the
   empty state, so presumably computing before any update fails -- confirm
   against how [create_custom] threads the state.

   NOTE(review): samples with equal prediction scores are not grouped
   before the cumulative sums, so the result looks dependent on the tie
   order produced by [Rune.argsort] -- confirm whether that is intended.
   NOTE(review): when the total positive or total negative weight is below
   1e-12 the corresponding rate is all zeros, which makes the integral 0
   rather than signalling an undefined AUC. *)
let auc_roc () =
302300
create_custom ~name:"auc_roc"
303301
~init:(fun () -> [])
304-
~update:(fun state ~predictions:_ ~targets:_ ?weights:_ () -> state)
305-
~compute:(fun _ ->
306-
failwith "AUC-ROC not yet implemented - requires trapezoid integration")
(* update: flatten the incoming batch, cast it to the accumulator dtype and
   append it to the running [predictions; targets; weights] state. *)
302+
~update:(fun state ~predictions ~targets ?weights () ->
303+
let predictions = Rune.reshape [| -1 |] predictions in
304+
let targets = Rune.reshape [| -1 |] targets in
(* Reuse the dtype of the already accumulated predictions when there are
   any; otherwise take it from the incoming predictions. *)
305+
let dtype =
306+
match state with
307+
| [ preds_acc; _; _ ] -> Rune.dtype preds_acc
308+
| _ -> Rune.dtype predictions
309+
in
310+
let predictions = Rune.cast dtype predictions in
311+
let targets = Rune.cast dtype targets in
(* Missing weights default to a tensor of ones shaped like the flattened
   predictions. *)
312+
let weights =
313+
match weights with
314+
| Some w -> Rune.cast dtype (Rune.reshape [| -1 |] w)
315+
| None -> Rune.ones dtype (Rune.shape predictions)
316+
in
317+
match state with
318+
| [] -> [ predictions; targets; weights ]
319+
| [ preds_acc; targets_acc; weights_acc ] ->
320+
let preds_acc =
321+
Rune.concatenate ~axis:0 [ preds_acc; predictions ]
322+
in
323+
let targets_acc =
324+
Rune.concatenate ~axis:0 [ targets_acc; targets ]
325+
in
326+
let weights_acc =
327+
Rune.concatenate ~axis:0 [ weights_acc; weights ]
328+
in
329+
[ preds_acc; targets_acc; weights_acc ]
330+
| _ -> failwith "Invalid auc_roc state")
(* compute: turn the accumulated samples into a ROC curve and integrate it. *)
331+
~compute:(fun state ->
332+
match state with
333+
| [ preds; targets; weights ] ->
334+
let dtype = Rune.dtype preds in
335+
let ones = Rune.ones dtype (Rune.shape targets) in
(* Reorder targets and weights by descending predicted score. *)
336+
let sorted_idx =
337+
Rune.argsort ~axis:0 ~descending:true preds
338+
in
339+
let sorted_targets =
340+
Rune.take_along_axis ~axis:0 sorted_idx targets
341+
in
342+
let sorted_weights =
343+
Rune.take_along_axis ~axis:0 sorted_idx weights
344+
in
345+
(* Per-sample weighted positive mass (target * weight) and negative mass
   ((1 - target) * weight); assumes targets are 0/1 valued -- TODO confirm. *)
346+
let positives = Rune.mul sorted_targets sorted_weights in
347+
let negatives =
348+
Rune.mul (Rune.sub ones sorted_targets) sorted_weights
349+
in
350+
351+
let cum_tp = Rune.cumsum ~axis:0 positives in
352+
let cum_fp = Rune.cumsum ~axis:0 negatives in
(* Prepend a zero to both cumulative sums so the curve starts at (0, 0). *)
353+
let zero = scalar_tensor dtype 0.0 in
354+
let cum_tp =
355+
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_tp ]
356+
in
357+
let cum_fp =
358+
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_fp ]
359+
in
360+
361+
let total_pos = Rune.item [] (Rune.sum positives) in
362+
let total_neg = Rune.item [] (Rune.sum negatives) in
363+
(* Normalise a cumulative sum into a rate; a near-zero total yields an
   all-zero rate instead of dividing by it. *)
364+
let ratio cumulative total =
365+
if Float.abs total < 1e-12 then
366+
Rune.zeros dtype (Rune.shape cumulative)
367+
else
368+
let total_tensor = scalar_tensor dtype total in
369+
Rune.div cumulative total_tensor
370+
in
371+
372+
let tpr = ratio cum_tp total_pos in
373+
let fpr = ratio cum_fp total_neg in
374+
(* Trapezoid rule: sum over consecutive curve points of
   (fpr step) * (mean of the two tpr values). Fewer than two curve points
   gives 0. The tail/head slices are the curve shifted by one position;
   the exact bound semantics of [Rune.R] are not visible here. *)
375+
let n = Rune.size tpr in
376+
if n < 2 then scalar_tensor dtype 0.0
377+
else
378+
let tail_fpr = Rune.slice [ Rune.R (1, n) ] fpr in
379+
let head_fpr = Rune.slice [ Rune.R (0, n - 1) ] fpr in
380+
let dx = Rune.sub tail_fpr head_fpr in
381+
382+
let tail_tpr = Rune.slice [ Rune.R (1, n) ] tpr in
383+
let head_tpr = Rune.slice [ Rune.R (0, n - 1) ] tpr in
384+
let avg_tpr =
385+
Rune.mul (scalar_tensor dtype 0.5)
386+
(Rune.add tail_tpr head_tpr)
387+
in
388+
Rune.sum (Rune.mul dx avg_tpr)
389+
| _ -> failwith "Invalid auc_roc state")
307390
~reset:(fun _ -> [])
308391

392+
309393
let auc_pr ?(num_thresholds = 200) ?(curve = false) () =
310394
let _ = num_thresholds in
311395
let _ = curve in

kaun/lib/kaun/metrics.mli

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,13 @@ val f1_score :
114114
@param averaging Multi-class averaging strategy (default: Micro)
115115
@param beta Weight of recall vs precision (default: 1.0 for F1) *)
116116

117-
val auc_roc : ?num_thresholds:int -> ?curve:bool -> unit -> 'layout t
118-
(** [auc_roc ?num_thresholds ?curve ()] creates an AUC-ROC metric.
117+
val auc_roc : unit -> 'layout t
118+
(** [auc_roc ()] creates an AUC-ROC metric.
119119
120120
Area Under the Receiver Operating Characteristic Curve.
121121
122-
@param num_thresholds
123-
Number of thresholds for curve computation (default: 200)
124-
@param curve If true, also return the ROC curve points (default: false) *)
122+
Computes the exact ROC integral by sorting predictions and accumulating
123+
true/false positive rates across all seen batches. *)
125124

126125
val auc_pr : ?num_thresholds:int -> ?curve:bool -> unit -> 'layout t
127126
(** [auc_pr ?num_thresholds ?curve ()] creates an AUC-PR metric.

kaun/test/test_metrics.ml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,41 @@ let test_f1_score () =
7878
let expected = Rune.scalar dtype 0.8 in
7979
check (tensor_testable 1e-5) "f1 score" expected result
8080

81+
(* Single-update check for [Metrics.auc_roc]: four float32 samples in which
   both positive targets carry higher scores than both negative targets.
   No weights are passed. The computed value is compared against 1.0 using
   [tensor_testable 1e-5]. *)
let test_auc_roc () =
82+
let dtype = Rune.float32 in
83+
84+
let predictions = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
85+
let targets = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in
86+
87+
let auc = Metrics.auc_roc () in
88+
Metrics.update auc ~predictions ~targets ();
89+
let result = Metrics.compute auc in
90+
(* For perfectly separable predictions, AUC should be 1.0 *)
91+
let expected = Rune.scalar dtype 1.0 in
92+
check (tensor_testable 1e-5) "auc roc" expected result
93+
94+
(* Accumulation check for [Metrics.auc_roc]: the same four samples are fed
   once as a single update and once as two updates of two samples each
   (all positives first, then all negatives). The two computed values are
   compared with [tensor_testable 1e-5]; the test does not pin the value
   itself, only that both paths agree. *)
let test_auc_roc_multiple_updates () =
95+
let dtype = Rune.float32 in
96+
97+
let predictions_full = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
98+
let targets_full = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in
99+
(* Reference metric: everything in one update. *)
100+
let auc_single = Metrics.auc_roc () in
101+
Metrics.update auc_single ~predictions:predictions_full ~targets:targets_full ();
102+
let full_result = Metrics.compute auc_single in
103+
(* Metric under test: the same data split across two updates. *)
104+
let auc_chunked = Metrics.auc_roc () in
105+
let predictions_1 = Rune.create dtype [| 2 |] [| 0.8; 0.7 |] in
106+
let targets_1 = Rune.create dtype [| 2 |] [| 1.; 1. |] in
107+
Metrics.update auc_chunked ~predictions:predictions_1 ~targets:targets_1 ();
108+
let predictions_2 = Rune.create dtype [| 2 |] [| 0.6; 0.3 |] in
109+
let targets_2 = Rune.create dtype [| 2 |] [| 0.; 0. |] in
110+
Metrics.update auc_chunked ~predictions:predictions_2 ~targets:targets_2 ();
111+
let chunked_result = Metrics.compute auc_chunked in
112+
113+
check (tensor_testable 1e-5) "auc roc incremental"
114+
full_result chunked_result
115+
81116
let test_confusion_matrix () =
82117
let dtype = Rune.float32 in
83118

@@ -376,6 +411,8 @@ let () =
376411
test_case "accuracy" `Quick test_accuracy;
377412
test_case "precision_recall" `Quick test_precision_recall;
378413
test_case "f1_score" `Quick test_f1_score;
414+
test_case "auc_roc" `Quick test_auc_roc;
415+
test_case "auc_roc_multiple_updates" `Quick test_auc_roc_multiple_updates;
379416
test_case "confusion_matrix" `Quick test_confusion_matrix;
380417
] );
381418
( "regression",

0 commit comments

Comments
 (0)