Skip to content

Commit

Permalink
Add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
frozenlib committed Dec 29, 2023
1 parent 44a6b49 commit 6ff163a
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions lgbm/src/booster/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,63 @@ fn binary_classification() -> Result<()> {
Ok(())
}

#[test]
fn binary_classification_categorical() -> Result<()> {
let num_category = 3;

let mut p = Parameters::new();
p.push("boosting_type", Boosting::Gbdt);
p.push("objective", Objective::Binary);
p.push("metric", [Metric::BinaryLogloss, Metric::Auc]);
p.push("min_data_in_leaf", 20);
p.push("verbosity", Verbosity::Fatal);

println!("make train dataset");
let train_feature = make_features_categorycal(128, num_category);
let train_label = make_labels_categorycal(128, num_category);
let mut train = Dataset::from_mat(&train_feature, None, &p)?;
train.set_field(Field::LABEL, &train_label)?;

println!("make test dataset");
let test_feature = make_features_categorycal(50, num_category);
let test_label = make_labels_categorycal(50, num_category);
let mut test = Dataset::from_mat(&test_feature, Some(&train), &p)?;
test.set_field(Field::LABEL, &test_label)?;

println!("crate booster");
let mut b = Booster::new(Arc::new(train), &p)?;
b.add_valid_data(Arc::new(test))?;
for n in 0..100 {
println!("iter {n}");
let is_finish = b.update_one_iter()?;
let eval_names = b.get_eval_names()?;
let evals = b.get_eval(0)?;
for i in 0..eval_names.len() {
println!("training {}: {}", eval_names[i], evals[i]);
}
let evals = b.get_eval(1)?;
for i in 0..eval_names.len() {
println!("valid {}: {}", eval_names[i], evals[i]);
}
if is_finish {
break;
}
}

let p = Parameters::new();
let rs = b.predict_for_mat(&test_feature, PredictType::Normal, 0, None, &p)?;
println!("\n{rs}");
for i in 0..test_label.len() {
let r = rs[i];
if test_label[i] == 0.0 {
assert!(r < 0.1);
} else {
assert!(r > 0.9);
}
}
Ok(())
}

#[test]
fn multiclass_classification() -> Result<()> {
let num_class = 3;
Expand Down Expand Up @@ -412,6 +469,27 @@ fn make_labels(num_row: usize, num_class: usize) -> Vec<f32> {
(0..num_row).map(|x| (x % num_class) as f32).collect()
}

fn make_features_categorycal(num_row: usize, num_category: usize) -> MatBuf<f64, RowMajor> {
MatBuf::from_rows((0..num_row).map(|x| {
let category = x % num_category;
let value = (x / num_category) % ((category + 1) * 2);
[category as f64, value as f64]
}))
}
fn make_labels_categorycal(num_row: usize, num_category: usize) -> Vec<f32> {
(0..num_row)
.map(|x| {
let category = x % num_category;
let value = (x / num_category) % ((category + 1) * 2);
if value > category {
1.0
} else {
0.0
}
})
.collect()
}

fn make_dataset(
num_row: usize,
num_class: usize,
Expand Down

0 comments on commit 6ff163a

Please sign in to comment.