@@ -32,6 +32,43 @@ type 'layout metric_fn =
3232let scalar_tensor dtype value = Rune. scalar dtype value
3333let ones_like t = Rune. ones (Rune. dtype t) (Rune. shape t)
3434
35+ let accumulate_rank_metric_state metric_name state ~predictions ~targets
36+ ?weights () =
37+ let predictions = Rune. reshape [| - 1 |] predictions in
38+ let targets = Rune. reshape [| - 1 |] targets in
39+ let dtype =
40+ match state with
41+ | [ preds_acc; _; _ ] -> Rune. dtype preds_acc
42+ | _ -> Rune. dtype predictions
43+ in
44+ let predictions = Rune. cast dtype predictions in
45+ let targets = Rune. cast dtype targets in
46+ let weights =
47+ match weights with
48+ | Some w -> Rune. cast dtype (Rune. reshape [| - 1 |] w)
49+ | None -> ones_like predictions
50+ in
51+ match state with
52+ | [] -> [ predictions; targets; weights ]
53+ | [ preds_acc; targets_acc; weights_acc ] ->
54+ let preds_acc = Rune. concatenate ~axis: 0 [ preds_acc; predictions ] in
55+ let targets_acc = Rune. concatenate ~axis: 0 [ targets_acc; targets ] in
56+ let weights_acc = Rune. concatenate ~axis: 0 [ weights_acc; weights ] in
57+ [ preds_acc; targets_acc; weights_acc ]
58+ | _ -> failwith (Printf. sprintf " Invalid %s state" metric_name)
59+
60+ let prepare_rank_curve_inputs preds targets weights =
61+ let dtype = Rune. dtype preds in
62+ let sorted_idx = Rune. argsort ~axis: 0 ~descending: true preds in
63+ let sorted_targets = Rune. take_along_axis ~axis: 0 sorted_idx targets in
64+ let sorted_weights = Rune. take_along_axis ~axis: 0 sorted_idx weights in
65+ let positives = Rune. mul sorted_targets sorted_weights in
66+ let negatives =
67+ let ones = Rune. ones dtype (Rune. shape sorted_targets) in
68+ Rune. mul (Rune. sub ones sorted_targets) sorted_weights
69+ in
70+ (positives, negatives)
71+
3572(* * Core metric operations *)
3673
3774let update metric ~predictions ~targets ?weights () =
@@ -300,45 +337,15 @@ let auc_roc () =
300337 create_custom ~name: " auc_roc"
301338 ~init: (fun () -> [] )
302339 ~update: (fun state ~predictions ~targets ?weights () ->
303- let predictions = Rune. reshape [| - 1 |] predictions in
304- let targets = Rune. reshape [| - 1 |] targets in
305- let dtype =
306- match state with
307- | [ preds_acc; _; _ ] -> Rune. dtype preds_acc
308- | _ -> Rune. dtype predictions
309- in
310- let predictions = Rune. cast dtype predictions in
311- let targets = Rune. cast dtype targets in
312- let weights =
313- match weights with
314- | Some w -> Rune. cast dtype (Rune. reshape [| - 1 |] w)
315- | None -> Rune. ones dtype (Rune. shape predictions)
316- in
317- match state with
318- | [] -> [ predictions; targets; weights ]
319- | [ preds_acc; targets_acc; weights_acc ] ->
320- let preds_acc = Rune. concatenate ~axis: 0 [ preds_acc; predictions ] in
321- let targets_acc = Rune. concatenate ~axis: 0 [ targets_acc; targets ] in
322- let weights_acc = Rune. concatenate ~axis: 0 [ weights_acc; weights ] in
323- [ preds_acc; targets_acc; weights_acc ]
324- | _ -> failwith " Invalid auc_roc state" )
340+ accumulate_rank_metric_state " auc_roc" state ~predictions ~targets
341+ ?weights () )
325342 ~compute: (fun state ->
326343 match state with
327344 | [ preds; targets; weights ] ->
328- let dtype = Rune. dtype preds in
329- let ones = Rune. ones dtype (Rune. shape targets) in
330- let sorted_idx = Rune. argsort ~axis: 0 ~descending: true preds in
331- let sorted_targets =
332- Rune. take_along_axis ~axis: 0 sorted_idx targets
333- in
334- let sorted_weights =
335- Rune. take_along_axis ~axis: 0 sorted_idx weights
336- in
337-
338- let positives = Rune. mul sorted_targets sorted_weights in
339- let negatives =
340- Rune. mul (Rune. sub ones sorted_targets) sorted_weights
345+ let positives, negatives =
346+ prepare_rank_curve_inputs preds targets weights
341347 in
348+ let dtype = Rune. dtype positives in
342349
343350 let cum_tp = Rune. cumsum ~axis: 0 positives in
344351 let cum_fp = Rune. cumsum ~axis: 0 negatives in
@@ -380,13 +387,54 @@ let auc_roc () =
380387 | _ -> failwith " Invalid auc_roc state" )
381388 ~reset: (fun _ -> [] )
382389
383- let auc_pr ?(num_thresholds = 200 ) ?(curve = false ) () =
384- let _ = num_thresholds in
385- let _ = curve in
390+ let auc_pr () =
386391 create_custom ~name: " auc_pr"
387392 ~init: (fun () -> [] )
388- ~update: (fun state ~predictions :_ ~targets :_ ?weights :_ () -> state)
389- ~compute: (fun _ -> failwith " AUC-PR not yet implemented" )
393+ ~update: (fun state ~predictions ~targets ?weights () ->
394+ accumulate_rank_metric_state " auc_pr" state ~predictions ~targets ?weights
395+ () )
396+ ~compute: (fun state ->
397+ match state with
398+ | [ preds; targets; weights ] ->
399+ let positives, negatives =
400+ prepare_rank_curve_inputs preds targets weights
401+ in
402+ let dtype = Rune. dtype positives in
403+
404+ let cum_tp = Rune. cumsum ~axis: 0 positives in
405+ let cum_fp = Rune. cumsum ~axis: 0 negatives in
406+
407+ let cum_fn = Rune. sub (Rune. sum positives) cum_tp in
408+
409+ let zero = scalar_tensor dtype 0.0 in
410+ let cum_tp =
411+ Rune. concatenate ~axis: 0 [ Rune. reshape [| 1 |] zero; cum_tp ]
412+ in
413+ let cum_fp =
414+ Rune. concatenate ~axis: 0 [ Rune. reshape [| 1 |] zero; cum_fp ]
415+ in
416+ let cum_fn =
417+ Rune. concatenate ~axis: 0 [ Rune. reshape [| 1 |] zero; cum_fn ]
418+ in
419+
420+ let precision_denom = Rune. add cum_tp cum_fp in
421+ let recall_denom = Rune. add cum_tp cum_fn in
422+ let eps = scalar_tensor dtype 1e-7 in
423+
424+ let precision = Rune. div cum_tp (Rune. add precision_denom eps) in
425+ let recall = Rune. div cum_tp (Rune. add recall_denom eps) in
426+
427+ let n = Rune. size precision in
428+ if n < 2 then scalar_tensor dtype 0.0
429+ else
430+ let tail_recall = Rune. slice [ Rune. R (1 , n) ] recall in
431+ let head_recall = Rune. slice [ Rune. R (0 , n - 1 ) ] recall in
432+ let dx = Rune. sub tail_recall head_recall in
433+
434+ let precision_k = Rune. slice [ Rune. R (1 , n) ] precision in
435+
436+ Rune. sum (Rune. mul dx precision_k)
437+ | _ -> failwith " Invalid auc_pr state" )
390438 ~reset: (fun _ -> [] )
391439
392440let confusion_matrix ~num_classes ?(normalize = `None ) () =
0 commit comments