Skip to content

Handle multiclass precision/recall #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/math/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ pub trait RealNumber:
self * self
}

/// Raw transmutation to u64
/// Raw transmutation to u32
fn to_f32_bits(self) -> u32;

/// Raw transmutation to u64
fn to_f64_bits(self) -> u64;
}

impl RealNumber for f64 {
Expand Down Expand Up @@ -89,6 +92,10 @@ impl RealNumber for f64 {
fn to_f32_bits(self) -> u32 {
self.to_bits() as u32
}

fn to_f64_bits(self) -> u64 {
self.to_bits()
}
}

impl RealNumber for f32 {
Expand Down Expand Up @@ -130,6 +137,10 @@ impl RealNumber for f32 {
fn to_f32_bits(self) -> u32 {
self.to_bits()
}

fn to_f64_bits(self) -> u64 {
self.to_bits() as u64
}
}

#[cfg(test)]
Expand Down
64 changes: 42 additions & 22 deletions src/metrics/precision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashSet;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -42,34 +44,33 @@ impl Precision {
);
}

let mut tp = 0;
let mut p = 0;
let n = y_true.len();
for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_true.get(i)
);
}

if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_pred.get(i)
);
}

if y_pred.get(i) == T::one() {
p += 1;
let mut classes = HashSet::new();
for i in 0..y_true.len() {
classes.insert(y_true.get(i).to_f64_bits());
}
let classes = classes.len();

if y_true.get(i) == T::one() {
let mut tp = 0;
let mut fp = 0;
for i in 0..y_true.len() {
if y_pred.get(i) == y_true.get(i) {
if classes == 2 {
if y_true.get(i) == T::one() {
tp += 1;
}
} else {
tp += 1;
}
} else if classes == 2 {
if y_true.get(i) == T::one() {
fp += 1;
}
} else {
fp += 1;
}
}

T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
}
}

Expand All @@ -88,5 +89,24 @@ mod tests {

assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);

let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];

let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
assert!((score3 - 0.5).abs() < 1e-8);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn precision_multiclass() {
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];

let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);

assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
}
66 changes: 43 additions & 23 deletions src/metrics/recall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashSet;
use std::convert::TryInto;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -42,34 +45,32 @@ impl Recall {
);
}

let mut tp = 0;
let mut p = 0;
let n = y_true.len();
for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!(
"Recall can only be applied to binary classification: {}",
y_true.get(i)
);
}

if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
panic!(
"Recall can only be applied to binary classification: {}",
y_pred.get(i)
);
}

if y_true.get(i) == T::one() {
p += 1;
let mut classes = HashSet::new();
for i in 0..y_true.len() {
classes.insert(y_true.get(i).to_f64_bits());
}
let classes: i64 = classes.len().try_into().unwrap();

if y_pred.get(i) == T::one() {
let mut tp = 0;
let mut fne = 0;
for i in 0..y_true.len() {
if y_pred.get(i) == y_true.get(i) {
if classes == 2 {
if y_true.get(i) == T::one() {
tp += 1;
}
} else {
tp += 1;
}
} else if classes == 2 {
if y_true.get(i) != T::one() {
fne += 1;
}
} else {
fne += 1;
}
}

T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
}
}

Expand All @@ -88,5 +89,24 @@ mod tests {

assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);

let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];

let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
assert!((score3 - 0.66666666).abs() < 1e-8);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn recall_multiclass() {
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];

let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);

assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
}