@@ -296,16 +296,100 @@ let f1_score ?(threshold = 0.5) ?(averaging = Micro) ?(beta = 1.0) () =
296296 ~reset: (fun _ -> [] )
297297
298298(* Placeholder implementations for complex metrics *)
299- let auc_roc ?(num_thresholds = 200 ) ?(curve = false ) () =
300- let _ = num_thresholds in
301- let _ = curve in
299+ let auc_roc () =
302300 create_custom ~name: " auc_roc"
303301 ~init: (fun () -> [] )
304- ~update: (fun state ~predictions :_ ~targets :_ ?weights :_ () -> state)
305- ~compute: (fun _ ->
306- failwith " AUC-ROC not yet implemented - requires trapezoid integration" )
302+ ~update: (fun state ~predictions ~targets ?weights () ->
303+ let predictions = Rune. reshape [| - 1 |] predictions in
304+ let targets = Rune. reshape [| - 1 |] targets in
305+ let dtype =
306+ match state with
307+ | [ preds_acc; _; _ ] -> Rune. dtype preds_acc
308+ | _ -> Rune. dtype predictions
309+ in
310+ let predictions = Rune. cast dtype predictions in
311+ let targets = Rune. cast dtype targets in
312+ let weights =
313+ match weights with
314+ | Some w -> Rune. cast dtype (Rune. reshape [| - 1 |] w)
315+ | None -> Rune. ones dtype (Rune. shape predictions)
316+ in
317+ match state with
318+ | [] -> [ predictions; targets; weights ]
319+ | [ preds_acc; targets_acc; weights_acc ] ->
320+ let preds_acc =
321+ Rune. concatenate ~axis: 0 [ preds_acc; predictions ]
322+ in
323+ let targets_acc =
324+ Rune. concatenate ~axis: 0 [ targets_acc; targets ]
325+ in
326+ let weights_acc =
327+ Rune. concatenate ~axis: 0 [ weights_acc; weights ]
328+ in
329+ [ preds_acc; targets_acc; weights_acc ]
330+ | _ -> failwith " Invalid auc_roc state" )
331+ ~compute: (fun state ->
332+ match state with
333+ | [ preds; targets; weights ] ->
334+ let dtype = Rune. dtype preds in
335+ let ones = Rune. ones dtype (Rune. shape targets) in
336+ let sorted_idx =
337+ Rune. argsort ~axis: 0 ~descending: true preds
338+ in
339+ let sorted_targets =
340+ Rune. take_along_axis ~axis: 0 sorted_idx targets
341+ in
342+ let sorted_weights =
343+ Rune. take_along_axis ~axis: 0 sorted_idx weights
344+ in
345+
346+ let positives = Rune. mul sorted_targets sorted_weights in
347+ let negatives =
348+ Rune. mul (Rune. sub ones sorted_targets) sorted_weights
349+ in
350+
351+ let cum_tp = Rune. cumsum ~axis: 0 positives in
352+ let cum_fp = Rune. cumsum ~axis: 0 negatives in
353+ let zero = scalar_tensor dtype 0.0 in
354+ let cum_tp =
355+ Rune. concatenate ~axis: 0 [ Rune. reshape [| 1 |] zero; cum_tp ]
356+ in
357+ let cum_fp =
358+ Rune. concatenate ~axis: 0 [ Rune. reshape [| 1 |] zero; cum_fp ]
359+ in
360+
361+ let total_pos = Rune. item [] (Rune. sum positives) in
362+ let total_neg = Rune. item [] (Rune. sum negatives) in
363+
364+ let ratio cumulative total =
365+ if Float. abs total < 1e-12 then
366+ Rune. zeros dtype (Rune. shape cumulative)
367+ else
368+ let total_tensor = scalar_tensor dtype total in
369+ Rune. div cumulative total_tensor
370+ in
371+
372+ let tpr = ratio cum_tp total_pos in
373+ let fpr = ratio cum_fp total_neg in
374+
375+ let n = Rune. size tpr in
376+ if n < 2 then scalar_tensor dtype 0.0
377+ else
378+ let tail_fpr = Rune. slice [ Rune. R (1 , n) ] fpr in
379+ let head_fpr = Rune. slice [ Rune. R (0 , n - 1 ) ] fpr in
380+ let dx = Rune. sub tail_fpr head_fpr in
381+
382+ let tail_tpr = Rune. slice [ Rune. R (1 , n) ] tpr in
383+ let head_tpr = Rune. slice [ Rune. R (0 , n - 1 ) ] tpr in
384+ let avg_tpr =
385+ Rune. mul (scalar_tensor dtype 0.5 )
386+ (Rune. add tail_tpr head_tpr)
387+ in
388+ Rune. sum (Rune. mul dx avg_tpr)
389+ | _ -> failwith " Invalid auc_roc state" )
307390 ~reset: (fun _ -> [] )
308391
392+
309393let auc_pr ?(num_thresholds = 200 ) ?(curve = false ) () =
310394 let _ = num_thresholds in
311395 let _ = curve in
0 commit comments