Skip to content

Commit e891193

Browse files
authored
Enhance error return to fulfill std::error::Error + 'static + Send + Sync (#10)
* let all Box error return with sync and send * fix return error to fufill Sync and Send * add static into error return * change to use GBDT it's own error and result
1 parent 87ac494 commit e891193

File tree

5 files changed

+103
-30
lines changed

5 files changed

+103
-30
lines changed

src/decision_tree.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ use crate::binary_tree::BinaryTree;
9999
use crate::binary_tree::BinaryTreeNode;
100100
use crate::binary_tree::TreeIndex;
101101
use crate::config::Loss;
102+
use crate::errors::{GbdtError, Result};
102103
#[cfg(feature = "enable_training")]
103104
use crate::fitness::almost_equal;
104-
use std::error::Error;
105105

106106
#[cfg(feature = "enable_training")]
107107
use rand::prelude::SliceRandom;
@@ -1746,7 +1746,7 @@ impl DecisionTree {
17461746
/// let node: Value = serde_json::from_str(data).unwrap();
17471747
/// let dt = DecisionTree::get_from_xgboost(&node);
17481748
/// ```
1749-
pub fn get_from_xgboost(node: &serde_json::Value) -> Result<Self, Box<dyn Error>> {
1749+
pub fn get_from_xgboost(node: &serde_json::Value) -> Result<Self> {
17501750
// Parameters are not used in prediction process, so we use default parameters.
17511751
let mut tree = DecisionTree::new();
17521752
let index = tree.tree.add_root(BinaryTreeNode::new(DTNode::new()));
@@ -1755,65 +1755,66 @@ impl DecisionTree {
17551755
}
17561756

17571757
/// Recursively build the tree node from the JSON value.
1758-
fn add_node_from_json(
1759-
&mut self,
1760-
index: TreeIndex,
1761-
node: &serde_json::Value,
1762-
) -> Result<(), Box<dyn Error>> {
1758+
fn add_node_from_json(&mut self, index: TreeIndex, node: &serde_json::Value) -> Result<()> {
17631759
{
17641760
let node_ref = self
17651761
.tree
17661762
.get_node_mut(index)
17671763
.expect("node should not be empty!");
17681764
// This is the leaf node
17691765
if let serde_json::Value::Number(pred) = &node["leaf"] {
1770-
let leaf_value = pred.as_f64().ok_or("parse 'leaf' error")?;
1766+
let leaf_value = pred.as_f64().ok_or_else(|| "parse 'leaf' error")?;
17711767
node_ref.value.pred = leaf_value as ValueType;
17721768
node_ref.value.is_leaf = true;
17731769
return Ok(());
17741770
} else {
17751771
// feature value
17761772
let feature_value = node["split_condition"]
17771773
.as_f64()
1778-
.ok_or("parse 'split condition' error")?;
1774+
.ok_or_else(|| "parse 'split condition' error")?;
17791775
node_ref.value.feature_value = feature_value as ValueType;
17801776

17811777
// feature index
17821778
let feature_index = match node["split"].as_i64() {
17831779
Some(v) => v,
17841780
None => {
1785-
let feature_name = node["split"].as_str().ok_or("parse 'split' error")?;
1781+
let feature_name = node["split"]
1782+
.as_str()
1783+
.ok_or_else(|| "parse 'split' error")?;
17861784
let feature_str: String = feature_name.chars().skip(3).collect();
17871785
feature_str.parse::<i64>()?
17881786
}
17891787
};
17901788
node_ref.value.feature_index = feature_index as usize;
17911789

17921790
// handle unknown feature
1793-
let missing = node["missing"].as_i64().ok_or("parse 'missing' error")?;
1794-
let left_child = node["yes"].as_i64().ok_or("parse 'yes' error")?;
1795-
let right_child = node["no"].as_i64().ok_or("parse 'no' error")?;
1791+
let missing = node["missing"]
1792+
.as_i64()
1793+
.ok_or_else(|| "parse 'missing' error")?;
1794+
let left_child = node["yes"].as_i64().ok_or_else(|| "parse 'yes' error")?;
1795+
let right_child = node["no"].as_i64().ok_or_else(|| "parse 'no' error")?;
17961796
if missing == left_child {
17971797
node_ref.value.missing = -1;
17981798
} else if missing == right_child {
17991799
node_ref.value.missing = 1;
18001800
} else {
1801-
let err: Box<dyn Error> = From::from("not support extra missing node".to_string());
1802-
return Err(err);
1801+
return Err(GbdtError::NotSupportExtraMissingNode);
18031802
}
18041803
}
18051804
}
18061805

18071806
// ids for children
1808-
let left_child = node["yes"].as_i64().ok_or("parse 'yes' error")?;
1809-
let right_child = node["no"].as_i64().ok_or("parse 'no' error")?;
1807+
let left_child = node["yes"].as_i64().ok_or_else(|| "parse 'yes' error")?;
1808+
let right_child = node["no"].as_i64().ok_or_else(|| "parse 'no' error")?;
18101809
let children = node["children"]
18111810
.as_array()
1812-
.ok_or("parse 'children' error")?;
1811+
.ok_or_else(|| "parse 'children' error")?;
18131812
let mut find_left = false;
18141813
let mut find_right = false;
18151814
for child in children.iter() {
1816-
let node_id = child["nodeid"].as_i64().ok_or("parse 'nodeid' error")?;
1815+
let node_id = child["nodeid"]
1816+
.as_i64()
1817+
.ok_or_else(|| "parse 'nodeid' error")?;
18171818

18181819
// build left child
18191820
if node_id == left_child {
@@ -1835,8 +1836,7 @@ impl DecisionTree {
18351836
}
18361837

18371838
if (!find_left) || (!find_right) {
1838-
let err: Box<dyn Error> = From::from("children not found".to_string());
1839-
return Err(err);
1839+
return Err(GbdtError::ChildrenNotFound);
18401840
}
18411841
Ok(())
18421842
}

src/errors.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use std::error::Error;
2+
use std::fmt::{Display, Formatter};
3+
use std::io;
4+
use std::num;
5+
6+
pub type Result<T> = std::result::Result<T, GbdtError>;
7+
8+
#[derive(Debug)]
9+
pub enum GbdtError {
10+
NotSupportExtraMissingNode,
11+
ChildrenNotFound,
12+
IO(io::Error),
13+
ParseInt(num::ParseIntError),
14+
ParseFloat(num::ParseFloatError),
15+
SerdeJson(serde_json::Error),
16+
}
17+
18+
impl From<&str> for GbdtError {
19+
fn from(err: &str) -> GbdtError {
20+
GbdtError::IO(io::Error::new(io::ErrorKind::Other, err))
21+
}
22+
}
23+
24+
impl From<serde_json::Error> for GbdtError {
25+
fn from(err: serde_json::Error) -> GbdtError {
26+
GbdtError::SerdeJson(err)
27+
}
28+
}
29+
30+
impl From<num::ParseFloatError> for GbdtError {
31+
fn from(err: num::ParseFloatError) -> GbdtError {
32+
GbdtError::ParseFloat(err)
33+
}
34+
}
35+
36+
impl From<num::ParseIntError> for GbdtError {
37+
fn from(err: num::ParseIntError) -> GbdtError {
38+
GbdtError::ParseInt(err)
39+
}
40+
}
41+
42+
impl From<io::Error> for GbdtError {
43+
fn from(err: io::Error) -> GbdtError {
44+
GbdtError::IO(err)
45+
}
46+
}
47+
48+
impl Display for GbdtError {
49+
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
50+
match *self {
51+
GbdtError::NotSupportExtraMissingNode => write!(f, "Not support extra missing node"),
52+
GbdtError::ChildrenNotFound => write!(f, "Children not found"),
53+
GbdtError::IO(ref e) => write!(f, "IO error: {}", e),
54+
GbdtError::ParseInt(ref e) => write!(f, "ParseInt error: {}", e),
55+
GbdtError::ParseFloat(ref e) => write!(f, "ParseFloat error: {}", e),
56+
GbdtError::SerdeJson(ref e) => write!(f, "SerdeJson error: {}", e),
57+
}
58+
}
59+
}
60+
61+
impl Error for GbdtError {
62+
fn source(&self) -> Option<&(dyn Error + 'static)> {
63+
match *self {
64+
GbdtError::NotSupportExtraMissingNode => None,
65+
GbdtError::ChildrenNotFound => None,
66+
GbdtError::IO(ref e) => Some(e),
67+
GbdtError::ParseInt(ref e) => Some(e),
68+
GbdtError::ParseFloat(ref e) => Some(e),
69+
GbdtError::SerdeJson(ref e) => Some(e),
70+
}
71+
}
72+
}

src/gradient_boost.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ use crate::decision_tree::DecisionTree;
8989
#[cfg(feature = "enable_training")]
9090
use crate::decision_tree::TrainingCache;
9191
use crate::decision_tree::{DataVec, PredVec, ValueType, VALUE_TYPE_MIN, VALUE_TYPE_UNKNOWN};
92+
use crate::errors::Result;
9293
#[cfg(feature = "enable_training")]
9394
use crate::fitness::{label_average, logit_loss_gradient, weighted_label_median, AUC, MAE, RMSE};
9495
#[cfg(feature = "enable_training")]
9596
use rand::prelude::SliceRandom;
9697
#[cfg(feature = "enable_training")]
9798
use rand::thread_rng;
98-
use std::error::Error;
9999

100100
#[cfg(not(feature = "mesalock_sgx"))]
101101
use std::fs::File;
@@ -707,7 +707,7 @@ impl GBDT {
707707
/// // Save model.
708708
/// // gbdt.save_model("gbdt.model");
709709
/// ```
710-
pub fn save_model(&self, filename: &str) -> Result<(), Box<dyn Error>> {
710+
pub fn save_model(&self, filename: &str) -> Result<()> {
711711
let mut file = File::create(filename)?;
712712
let serialized = serde_json::to_string(self)?;
713713
file.write_all(serialized.as_bytes())?;
@@ -726,7 +726,7 @@ impl GBDT {
726726
///
727727
/// # Error
728728
/// Error when get exception during model file parsing or deserialize.
729-
pub fn load_model(filename: &str) -> Result<Self, Box<dyn Error>> {
729+
pub fn load_model(filename: &str) -> Result<Self> {
730730
let mut file = File::open(filename)?;
731731
let mut contents = String::new();
732732
file.read_to_string(&mut contents)?;
@@ -746,7 +746,7 @@ impl GBDT {
746746
///
747747
/// # Error
748748
/// Error when get exception during model file parsing.
749-
pub fn from_xgoost_dump(model_file: &str, objective: &str) -> Result<Self, Box<dyn Error>> {
749+
pub fn from_xgoost_dump(model_file: &str, objective: &str) -> Result<Self> {
750750
let tree_file = File::open(&model_file)?;
751751
let reader = BufReader::new(tree_file);
752752
let mut all_lines: Vec<String> = Vec::new();
@@ -766,7 +766,7 @@ impl GBDT {
766766
let single_line = all_lines.join("");
767767
let json_obj: serde_json::Value = serde_json::from_str(&single_line)?;
768768

769-
let nodes = json_obj.as_array().ok_or("parse trees error")?;
769+
let nodes = json_obj.as_array().ok_or_else(|| "parse trees error")?;
770770

771771
let mut cfg = Config::new();
772772
cfg.set_loss(objective);

src/input.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
use std::prelude::v1::*;
2929

3030
use crate::decision_tree::{Data, DataVec, ValueType, VALUE_TYPE_UNKNOWN};
31+
use crate::errors::Result;
3132

3233
cfg_if! {
3334
if #[cfg(all(feature = "mesalock_sgx", not(target_env = "sgx")))] {
@@ -37,7 +38,6 @@ cfg_if! {
3738
use std::io::{BufRead, BufReader, Seek, SeekFrom};
3839
} else {
3940
use std::collections::HashMap;
40-
use std::error::Error;
4141
#[cfg(not(feature = "mesalock_sgx"))]
4242
use std::fs::File;
4343
#[cfg(feature = "mesalock_sgx")]
@@ -286,7 +286,7 @@ pub fn infer(file_name: &str) -> InputFormat {
286286
///
287287
/// # Error
288288
/// Raise error if file cannot be read correctly.
289-
pub fn load_csv(file: &mut File, input_format: InputFormat) -> Result<DataVec, Box<dyn Error>> {
289+
pub fn load_csv(file: &mut File, input_format: InputFormat) -> Result<DataVec> {
290290
file.seek(SeekFrom::Start(0))?;
291291
let mut dv = Vec::new();
292292

@@ -337,7 +337,7 @@ pub fn load_csv(file: &mut File, input_format: InputFormat) -> Result<DataVec, B
337337
///
338338
/// # Error
339339
/// Raise error if file cannot be read correctly.
340-
pub fn load_txt(file: &mut File, input_format: InputFormat) -> Result<DataVec, Box<dyn Error>> {
340+
pub fn load_txt(file: &mut File, input_format: InputFormat) -> Result<DataVec> {
341341
file.seek(SeekFrom::Start(0))?;
342342
let mut dv = Vec::new();
343343

@@ -414,7 +414,7 @@ pub fn load_txt(file: &mut File, input_format: InputFormat) -> Result<DataVec, B
414414
///
415415
/// # Error
416416
/// Raise error if file cannot be open correctly.
417-
pub fn load(file_name: &str, input_format: InputFormat) -> Result<DataVec, Box<dyn Error>> {
417+
pub fn load(file_name: &str, input_format: InputFormat) -> Result<DataVec> {
418418
let mut file = File::open(file_name.to_string())?;
419419
match input_format.ftype {
420420
FileFormat::CSV => load_csv(&mut file, input_format),

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ extern crate sgx_tstd as std;
4343
pub mod binary_tree;
4444
pub mod config;
4545
pub mod decision_tree;
46+
pub mod errors;
4647
pub mod fitness;
4748
pub mod gradient_boost;
4849

0 commit comments

Comments
 (0)