Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ All notable changes to this project will be documented in this file.
- Allow metric history to tolerate metrics that appear or disappear between epochs so dynamic metric sets no longer raise during training (@tmattio)
- Make `Optimizer.clip_by_global_norm` robust to zero gradients and empty parameter trees to avoid NaNs during training (@tmattio)
- Split CSV loader into `from_csv` and `from_csv_with_labels` to retain labels when requested (#114, @Satarupa22-SD)
- Implement AUC-ROC in Kaun metrics and simplify its API (#109 @Shocker444)

### Talon

Expand Down
96 changes: 90 additions & 6 deletions kaun/lib/kaun/metrics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,100 @@ let f1_score ?(threshold = 0.5) ?(averaging = Micro) ?(beta = 1.0) () =
~reset:(fun _ -> [])

(* Placeholder implementations for complex metrics *)
let auc_roc ?(num_thresholds = 200) ?(curve = false) () =
let _ = num_thresholds in
let _ = curve in
let auc_roc () =
create_custom ~name:"auc_roc"
~init:(fun () -> [])
~update:(fun state ~predictions:_ ~targets:_ ?weights:_ () -> state)
~compute:(fun _ ->
failwith "AUC-ROC not yet implemented - requires trapezoid integration")
~update:(fun state ~predictions ~targets ?weights () ->
let predictions = Rune.reshape [| -1 |] predictions in
let targets = Rune.reshape [| -1 |] targets in
let dtype =
match state with
| [ preds_acc; _; _ ] -> Rune.dtype preds_acc
| _ -> Rune.dtype predictions
in
let predictions = Rune.cast dtype predictions in
let targets = Rune.cast dtype targets in
let weights =
match weights with
| Some w -> Rune.cast dtype (Rune.reshape [| -1 |] w)
| None -> Rune.ones dtype (Rune.shape predictions)
in
match state with
| [] -> [ predictions; targets; weights ]
| [ preds_acc; targets_acc; weights_acc ] ->
let preds_acc =
Rune.concatenate ~axis:0 [ preds_acc; predictions ]
in
let targets_acc =
Rune.concatenate ~axis:0 [ targets_acc; targets ]
in
let weights_acc =
Rune.concatenate ~axis:0 [ weights_acc; weights ]
in
[ preds_acc; targets_acc; weights_acc ]
| _ -> failwith "Invalid auc_roc state")
~compute:(fun state ->
match state with
| [ preds; targets; weights ] ->
let dtype = Rune.dtype preds in
let ones = Rune.ones dtype (Rune.shape targets) in
let sorted_idx =
Rune.argsort ~axis:0 ~descending:true preds
in
let sorted_targets =
Rune.take_along_axis ~axis:0 sorted_idx targets
in
let sorted_weights =
Rune.take_along_axis ~axis:0 sorted_idx weights
in

let positives = Rune.mul sorted_targets sorted_weights in
let negatives =
Rune.mul (Rune.sub ones sorted_targets) sorted_weights
in

let cum_tp = Rune.cumsum ~axis:0 positives in
let cum_fp = Rune.cumsum ~axis:0 negatives in
let zero = scalar_tensor dtype 0.0 in
let cum_tp =
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_tp ]
in
let cum_fp =
Rune.concatenate ~axis:0 [ Rune.reshape [| 1 |] zero; cum_fp ]
in

let total_pos = Rune.item [] (Rune.sum positives) in
let total_neg = Rune.item [] (Rune.sum negatives) in

let ratio cumulative total =
if Float.abs total < 1e-12 then
Rune.zeros dtype (Rune.shape cumulative)
else
let total_tensor = scalar_tensor dtype total in
Rune.div cumulative total_tensor
in

let tpr = ratio cum_tp total_pos in
let fpr = ratio cum_fp total_neg in

let n = Rune.size tpr in
if n < 2 then scalar_tensor dtype 0.0
else
let tail_fpr = Rune.slice [ Rune.R (1, n) ] fpr in
let head_fpr = Rune.slice [ Rune.R (0, n - 1) ] fpr in
let dx = Rune.sub tail_fpr head_fpr in

let tail_tpr = Rune.slice [ Rune.R (1, n) ] tpr in
let head_tpr = Rune.slice [ Rune.R (0, n - 1) ] tpr in
let avg_tpr =
Rune.mul (scalar_tensor dtype 0.5)
(Rune.add tail_tpr head_tpr)
in
Rune.sum (Rune.mul dx avg_tpr)
| _ -> failwith "Invalid auc_roc state")
~reset:(fun _ -> [])


let auc_pr ?(num_thresholds = 200) ?(curve = false) () =
let _ = num_thresholds in
let _ = curve in
Expand Down
9 changes: 4 additions & 5 deletions kaun/lib/kaun/metrics.mli
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,13 @@ val f1_score :
@param averaging Multi-class averaging strategy (default: Micro)
@param beta Weight of recall vs precision (default: 1.0 for F1) *)

val auc_roc : ?num_thresholds:int -> ?curve:bool -> unit -> 'layout t
(** [auc_roc ?num_thresholds ?curve ()] creates an AUC-ROC metric.
val auc_roc : unit -> 'layout t
(** [auc_roc ()] creates an AUC-ROC metric.

Area Under the Receiver Operating Characteristic Curve.

@param num_thresholds
Number of thresholds for curve computation (default: 200)
@param curve If true, also return the ROC curve points (default: false) *)
Computes the exact ROC integral by sorting predictions and accumulating
true/false positive rates across all seen batches. *)

val auc_pr : ?num_thresholds:int -> ?curve:bool -> unit -> 'layout t
(** [auc_pr ?num_thresholds ?curve ()] creates an AUC-PR metric.
Expand Down
37 changes: 37 additions & 0 deletions kaun/test/test_metrics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,41 @@ let test_f1_score () =
let expected = Rune.scalar dtype 0.8 in
check (tensor_testable 1e-5) "f1 score" expected result

let test_auc_roc () =
let dtype = Rune.float32 in

let predictions = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
let targets = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in

let auc = Metrics.auc_roc () in
Metrics.update auc ~predictions ~targets ();
let result = Metrics.compute auc in
(* For perfectly separable predictions, AUC should be 1.0 *)
let expected = Rune.scalar dtype 1.0 in
check (tensor_testable 1e-5) "auc roc" expected result

let test_auc_roc_multiple_updates () =
let dtype = Rune.float32 in

let predictions_full = Rune.create dtype [| 4 |] [| 0.8; 0.7; 0.6; 0.3 |] in
let targets_full = Rune.create dtype [| 4 |] [| 1.; 1.; 0.; 0. |] in

let auc_single = Metrics.auc_roc () in
Metrics.update auc_single ~predictions:predictions_full ~targets:targets_full ();
let full_result = Metrics.compute auc_single in

let auc_chunked = Metrics.auc_roc () in
let predictions_1 = Rune.create dtype [| 2 |] [| 0.8; 0.7 |] in
let targets_1 = Rune.create dtype [| 2 |] [| 1.; 1. |] in
Metrics.update auc_chunked ~predictions:predictions_1 ~targets:targets_1 ();
let predictions_2 = Rune.create dtype [| 2 |] [| 0.6; 0.3 |] in
let targets_2 = Rune.create dtype [| 2 |] [| 0.; 0. |] in
Metrics.update auc_chunked ~predictions:predictions_2 ~targets:targets_2 ();
let chunked_result = Metrics.compute auc_chunked in

check (tensor_testable 1e-5) "auc roc incremental"
full_result chunked_result

let test_confusion_matrix () =
let dtype = Rune.float32 in

Expand Down Expand Up @@ -376,6 +411,8 @@ let () =
test_case "accuracy" `Quick test_accuracy;
test_case "precision_recall" `Quick test_precision_recall;
test_case "f1_score" `Quick test_f1_score;
test_case "auc_roc" `Quick test_auc_roc;
test_case "auc_roc_multiple_updates" `Quick test_auc_roc_multiple_updates;
test_case "confusion_matrix" `Quick test_confusion_matrix;
] );
( "regression",
Expand Down
Loading