Skip to content

Commit 82d0daa

Browse files
Shocker444tmattio
andauthored
Implement AUC-PR in Kaun metrics (#131)
Co-authored-by: Thibaut Mattio <thibaut.mattio@gmail.com>
1 parent 60d20c3 commit 82d0daa

File tree

4 files changed

+133
-44
lines changed

4 files changed

+133
-44
lines changed

CHANGES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ All notable changes to this project will be documented in this file.
5656
- Allow metric history to tolerate metrics that appear or disappear between epochs so dynamic metric sets no longer raise during training (@tmattio)
5757
- Make `Optimizer.clip_by_global_norm` robust to zero gradients and empty parameter trees to avoid NaNs during training (@tmattio)
5858
- Split CSV loader into `from_csv` and `from_csv_with_labels` to retain labels when requested (#114, @Satarupa22-SD)
59-
- Implement AUC-ROC in Kaun metrics and simplify its API (#109 @Shocker444)
59+
- Implement AUC-ROC and AUC-PR in Kaun metrics and simplify their signatures (#109, #131, @Shocker444)
6060

6161
### Talon
6262

kaun/lib/kaun/metrics.ml

Lines changed: 88 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,43 @@ type 'layout metric_fn =
3232
let scalar_tensor dtype value = Rune.scalar dtype value
3333
let ones_like t = Rune.ones (Rune.dtype t) (Rune.shape t)
3434

35+
let accumulate_rank_metric_state metric_name state ~predictions ~targets
36+
?weights () =
37+
let predictions = Rune.reshape [| -1 |] predictions in
38+
let targets = Rune.reshape [| -1 |] targets in
39+
let dtype =
40+
match state with
41+
| [ preds_acc; _; _ ] -> Rune.dtype preds_acc
42+
| _ -> Rune.dtype predictions
43+
in
44+
let predictions = Rune.cast dtype predictions in
45+
let targets = Rune.cast dtype targets in
46+
let weights =
47+
match weights with
48+
| Some w -> Rune.cast dtype (Rune.reshape [| -1 |] w)
49+
| None -> ones_like predictions
50+
in
51+
match state with
52+
| [] -> [ predictions; targets; weights ]
53+
| [ preds_acc; targets_acc; weights_acc ] ->
54+
let preds_acc = Rune.concatenate ~axis:0 [ preds_acc; predictions ] in
55+
let targets_acc = Rune.concatenate ~axis:0 [ targets_acc; targets ] in
56+
let weights_acc = Rune.concatenate ~axis:0 [ weights_acc; weights ] in
57+
[ preds_acc; targets_acc; weights_acc ]
58+
| _ -> failwith (Printf.sprintf "Invalid %s state" metric_name)
59+
60+
let prepare_rank_curve_inputs preds targets weights =
61+
let dtype = Rune.dtype preds in
62+
let sorted_idx = Rune.argsort ~axis:0 ~descending:true preds in
63+
let sorted_targets = Rune.take_along_axis ~axis:0 sorted_idx targets in
64+
let sorted_weights = Rune.take_along_axis ~axis:0 sorted_idx weights in
65+
let positives = Rune.mul sorted_targets sorted_weights in
66+
let negatives =
67+
let ones = Rune.ones dtype (Rune.shape sorted_targets) in
68+
Rune.mul (Rune.sub ones sorted_targets) sorted_weights
69+
in
70+
(positives, negatives)
71+
3572
(** Core metric operations *)
3673

3774
let update metric ~predictions ~targets ?weights () =
@@ -300,45 +337,15 @@ let auc_roc () =
300337
create_custom ~name:"auc_roc"
301338
~init:(fun () -> [])
302339
~update:(fun state ~predictions ~targets ?weights () ->
303-
let predictions = Rune.reshape [| -1 |] predictions in
304-
let targets = Rune.reshape [| -1 |] targets in
305-
let dtype =
306-
match state with
307-
| [ preds_acc; _; _ ] -> Rune.dtype preds_acc
308-
| _ -> Rune.dtype predictions
309-
in
310-
let predictions = Rune.cast dtype predictions in
311-
let targets = Rune.cast dtype targets in
312-
let weights =
313-
match weights with
314-
| Some w -> Rune.cast dtype (Rune.reshape [| -1 |] w)
315-
| None -> Rune.ones dtype (Rune.shape predictions)
316-
in
317-
match state with
318-
| [] -> [ predictions; targets; weights ]
319-
| [ preds_acc; targets_acc; weights_acc ] ->
320-
let preds_acc = Rune.concatenate ~axis:0 [ preds_acc; predictions ] in
321-
let targets_acc = Rune.concatenate ~axis:0 [ targets_acc; targets ] in
322-
let weights_acc = Rune.concatenate ~axis:0 [ weights_acc; weights ] in
323-
[ preds_acc; targets_acc; weights_acc ]
324-
| _ -> failwith "Invalid auc_roc state")
340+
accumulate_rank_metric_state "auc_roc" state ~predictions ~targets
341+
?weights ())
325342
~compute:(fun state ->
326343
match state with
327344
| [ preds; targets; weights ] ->
328-
let dtype = Rune.dtype preds in
329-
let ones = Rune.ones dtype (Rune.shape targets) in
330-
let sorted_idx = Rune.argsort ~axis:0 ~descending:true preds in
331-
let sorted_targets =
332-
Rune.take_along_axis ~axis:0 sorted_idx targets
333-
in
334-
let sorted_weights =
335-
Rune.take_along_axis ~axis:0 sorted_idx weights
336-
in
337-
338-
let positives = Rune.mul sorted_targets sorted_weights in
339-
let negatives =
340-
Rune.mul (Rune.sub ones sorted_targets) sorted_weights
345+
let positives, negatives =
346+
prepare_rank_curve_inputs preds targets weights
341347
in
348+
let dtype = Rune.dtype positives in
342349

343350
let cum_tp = Rune.cumsum ~axis:0 positives in
344351
let cum_fp = Rune.cumsum ~axis:0 negatives in
@@ -380,13 +387,54 @@ let auc_roc () =
380387
| _ -> failwith "Invalid auc_roc state")
381388
~reset:(fun _ -> [])
382389

383-
let auc_pr ?(num_thresholds = 200) ?(curve = false) () =
384-
let _ = num_thresholds in
385-
let _ = curve in
390+
let auc_pr () =
386391
create_custom ~name:"auc_pr"
387392
~init:(fun () -> [])
388-
~update:(fun state ~predictions:_ ~targets:_ ?weights:_ () -> state)
389-
~compute:(fun _ -> failwith "AUC-PR not yet implemented")
393+
~update:(fun state ~predictions ~targets ?weights () ->
394+
accumulate_rank_metric_state "auc_pr" state ~predictions ~targets ?weights
395+
())
396+
~compute:(fun state ->
397+
match state with
398+
| [ preds; targets; weights ] ->
399+
let positives, negatives =
400+
prepare_rank_curve_inputs preds targets weights
401+
in
402+
let dtype = Rune.dtype positives in
403+
404+
let cum_tp = Rune.cumsum ~axis:0 positives in
405+
let cum_fp = Rune.cumsum ~axis:0 negatives in
406+
407+
let cum_fn = Rune.sub (Rune.sum positives) cum_tp in
408+
409+
let zero = scalar_tensor dtype 0.0 in
410+
let cum_tp =
411+
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_tp ]
412+
in
413+
let cum_fp =
414+
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_fp ]
415+
in
416+
let cum_fn =
417+
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_fn ]
418+
in
419+
420+
let precision_denom = Rune.add cum_tp cum_fp in
421+
let recall_denom = Rune.add cum_tp cum_fn in
422+
let eps = scalar_tensor dtype 1e-7 in
423+
424+
let precision = Rune.div cum_tp (Rune.add precision_denom eps) in
425+
let recall = Rune.div cum_tp (Rune.add recall_denom eps) in
426+
427+
let n = Rune.size precision in
428+
if n < 2 then scalar_tensor dtype 0.0
429+
else
430+
let tail_recall = Rune.slice [ Rune.R (1, n) ] recall in
431+
let head_recall = Rune.slice [ Rune.R (0, n - 1) ] recall in
432+
let dx = Rune.sub tail_recall head_recall in
433+
434+
let precision_k = Rune.slice [ Rune.R (1, n) ] precision in
435+
436+
Rune.sum (Rune.mul dx precision_k)
437+
| _ -> failwith "Invalid auc_pr state")
390438
~reset:(fun _ -> [])
391439

392440
let confusion_matrix ~num_classes ?(normalize = `None) () =

kaun/lib/kaun/metrics.mli

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,13 @@ val auc_roc : unit -> 'layout t
122122
Computes the exact ROC integral by sorting predictions and accumulating
123123
true/false positive rates across all seen batches. *)
124124

125-
val auc_pr : ?num_thresholds:int -> ?curve:bool -> unit -> 'layout t
126-
(** [auc_pr ?num_thresholds ?curve ()] creates an AUC-PR metric.
125+
val auc_pr : unit -> 'layout t
126+
(** [auc_pr ()] creates an AUC-PR metric.
127127
128-
Area Under the Precision-Recall Curve. *)
128+
Area Under the Precision-Recall Curve.
129+
130+
Computes the exact precision-recall integral by sorting predictions and
131+
accumulating precision/recall scores across all seen batches. *)
129132

130133
val confusion_matrix :
131134
num_classes:int ->

kaun/test/test_metrics.ml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,41 @@ let test_auc_roc_multiple_updates () =
113113

114114
check (tensor_testable 1e-5) "auc roc incremental" full_result chunked_result
115115

116+
let test_auc_pr () =
117+
let dtype = Rune.float32 in
118+
119+
let predictions = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
120+
let targets = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in
121+
122+
let auc = Metrics.auc_pr () in
123+
Metrics.update auc ~predictions ~targets ();
124+
let result = Metrics.compute auc in
125+
(* For perfectly separable predictions, AUC should be 1.0 *)
126+
let expected = Rune.scalar dtype 1.0 in
127+
check (tensor_testable 1e-5) "auc pr" expected result
128+
129+
let test_auc_pr_multiple_updates () =
130+
let dtype = Rune.float32 in
131+
132+
let predictions_full = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
133+
let targets_full = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in
134+
135+
let auc_single = Metrics.auc_pr () in
136+
Metrics.update auc_single ~predictions:predictions_full ~targets:targets_full
137+
();
138+
let full_result = Metrics.compute auc_single in
139+
140+
let auc_chunked = Metrics.auc_pr () in
141+
let predictions_1 = Rune.create dtype [| 2 |] [| 0.8; 0.7 |] in
142+
let targets_1 = Rune.create dtype [| 2 |] [| 1.; 1. |] in
143+
Metrics.update auc_chunked ~predictions:predictions_1 ~targets:targets_1 ();
144+
let predictions_2 = Rune.create dtype [| 2 |] [| 0.6; 0.3 |] in
145+
let targets_2 = Rune.create dtype [| 2 |] [| 0.; 0. |] in
146+
Metrics.update auc_chunked ~predictions:predictions_2 ~targets:targets_2 ();
147+
let chunked_result = Metrics.compute auc_chunked in
148+
149+
check (tensor_testable 1e-5) "auc pr incremental" full_result chunked_result
150+
116151
let test_confusion_matrix () =
117152
let dtype = Rune.float32 in
118153

@@ -414,6 +449,9 @@ let () =
414449
test_case "auc_roc" `Quick test_auc_roc;
415450
test_case "auc_roc_multiple_updates" `Quick
416451
test_auc_roc_multiple_updates;
452+
test_case "auc_pr" `Quick test_auc_pr;
453+
test_case "auc_pr_multiple_updates" `Quick
454+
test_auc_pr_multiple_updates;
417455
test_case "confusion_matrix" `Quick test_confusion_matrix;
418456
] );
419457
( "regression",

0 commit comments

Comments
 (0)