@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
4040use crate :: numbers:: basenum:: Number ;
4141#[ cfg( feature = "serde" ) ]
4242use serde:: { Deserialize , Serialize } ;
43- use std:: marker:: PhantomData ;
43+ use std:: { cmp :: Ordering , marker:: PhantomData } ;
4444
4545/// Distribution used in the Naive Bayes classifier.
4646pub ( crate ) trait NBDistribution < X : Number , Y : Number > : Clone {
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
9292 /// Returns a vector of size N with class estimates.
9393 pub fn predict ( & self , x : & X ) -> Result < Y , Failed > {
9494 let y_classes = self . distribution . classes ( ) ;
95- let ( rows, _) = x. shape ( ) ;
96- let predictions = ( 0 ..rows)
97- . map ( |row_index| {
98- let row = x. get_row ( row_index) ;
99- let ( prediction, _probability) = y_classes
95+ let predictions = x
96+ . row_iter ( )
97+ . map ( |row| {
98+ y_classes
10099 . iter ( )
101100 . enumerate ( )
102101 . map ( |( class_index, class) | {
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
106105 + self . distribution . prior ( class_index) . ln ( ) ,
107106 )
108107 } )
109- . max_by ( |( _, p1) , ( _, p2) | p1. partial_cmp ( p2) . unwrap ( ) )
110- . unwrap ( ) ;
111- * prediction
108+ // `partial_cmp` returns `None` when either value is NaN, so calling `Option::unwrap` on it inside `max_by` panics.
109+ // NaNs must be treated as the minimum value,
110+ // so that they are effectively ignored when choosing the maximum.
111+ // We therefore handle the `None` case explicitly instead of unwrapping.
112+ . max_by ( |( _, p1) , ( _, p2) | match p1. partial_cmp ( p2) {
113+ Some ( ordering) => ordering,
114+ None => {
115+ if p1. is_nan ( ) {
116+ Ordering :: Less
117+ } else if p2. is_nan ( ) {
118+ Ordering :: Greater
119+ } else {
120+ Ordering :: Equal
121+ }
122+ }
123+ } )
124+ . map ( |( prediction, _probability) | * prediction)
125+ . ok_or_else ( || Failed :: predict ( "Failed to predict, there is no result" ) )
112126 } )
113- . collect :: < Vec < TY > > ( ) ;
127+ . collect :: < Result < Vec < TY > , Failed > > ( ) ? ;
114128 let y_hat = Y :: from_vec_slice ( & predictions) ;
115129 Ok ( y_hat)
116130 }
@@ -119,3 +133,63 @@ pub mod bernoulli;
119133pub mod categorical;
120134pub mod gaussian;
121135pub mod multinomial;
136+
// Unit tests for `predict`, driven by a stub distribution whose scores are
// fully controlled by the list of class labels it wraps. This lets the tests
// exercise the empty-classes error path and the NaN-score path.
137+ #[ cfg( test) ]
138+ mod tests {
139+ use super :: * ;
140+ use crate :: linalg:: basic:: arrays:: Array ;
141+ use crate :: linalg:: basic:: matrix:: DenseMatrix ;
142+ use num_traits:: float:: Float ;
143+
// Model under test: i32 features and i32 labels, a DenseMatrix as input,
// a Vec as output, and the stub distribution defined below.
144+ type Model < ' d > = BaseNaiveBayes < i32 , i32 , DenseMatrix < i32 > , Vec < i32 > , TestDistribution < ' d > > ;
145+
// Stub distribution: a borrowed list of class labels and nothing else.
146+ #[ derive( Debug , PartialEq , Clone ) ]
147+ struct TestDistribution < ' d > ( & ' d Vec < i32 > ) ;
148+
149+ impl < ' d > NBDistribution < i32 , i32 > for TestDistribution < ' d > {
// Constant prior of 1 for every class; `predict` adds its natural log,
// so the prior contributes 0 to each score.
150+ fn prior ( & self , _class_index : usize ) -> f64 {
151+ 1.
152+ }
153+
// Ignores the sample row entirely. The score is the class label itself
// (as f64) when the label is 2, 10 or 20; any other label yields NaN.
154+ fn log_likelihood < ' a > (
155+ & ' a self ,
156+ class_index : usize ,
157+ _j : & ' a Box < dyn ArrayView1 < i32 > + ' a > ,
158+ ) -> f64 {
159+ match self . 0 . get ( class_index) {
160+ & v @ 2 | & v @ 10 | & v @ 20 => v as f64 ,
161+ _ => f64:: nan ( ) ,
162+ }
163+ }
164+
// Exposes the wrapped label list as the set of classes.
165+ fn classes ( & self ) -> & Vec < i32 > {
166+ & self . 0
167+ }
168+ }
169+
170+ #[ test]
171+ fn test_predict ( ) {
// Three rows, so every successful prediction has three entries. The
// cell values are irrelevant because the stub ignores the row.
172+ let matrix = DenseMatrix :: from_2d_array ( & [ & [ 1 , 2 , 3 ] , & [ 4 , 5 , 6 ] , & [ 7 , 8 , 9 ] ] ) ;
173+
// Case 1: no classes at all. There is nothing to maximise over, so
// `predict` must return an error rather than panic.
174+ let val = vec ! [ ] ;
175+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
176+ Ok ( _) => panic ! ( "Should return error in case of empty classes" ) ,
177+ Err ( err) => assert_eq ! (
178+ err. to_string( ) ,
179+ "Predict failed: Failed to predict, there is no result"
180+ ) ,
181+ }
182+
// Case 2: labels 1 and 3 score NaN and only label 2 has a real score,
// so every row must be predicted as 2.
183+ let val = vec ! [ 1 , 2 , 3 ] ;
184+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
185+ Ok ( r) => assert_eq ! ( r, vec![ 2 , 2 , 2 ] ) ,
186+ Err ( _) => panic ! ( "Should success in normal case with NaNs" ) ,
187+ }
188+
// Case 3: all three labels have finite scores (20, 2, 10); the largest
// is 20, so every row must be predicted as 20.
189+ let val = vec ! [ 20 , 2 , 10 ] ;
190+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
191+ Ok ( r) => assert_eq ! ( r, vec![ 20 , 20 , 20 ] ) ,
192+ Err ( _) => panic ! ( "Should success in normal case without NaNs" ) ,
193+ }
194+ }
195+ }
0 commit comments