@@ -61,13 +61,11 @@ use kernels::{
6161} ;
6262use kernels_arrow:: {
6363 add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
64- eq_decimal_scalar, gt_decimal_scalar, gt_eq_decimal_scalar, is_distinct_from,
65- is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_null,
66- is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool,
67- is_not_distinct_from_decimal, is_not_distinct_from_null, is_not_distinct_from_utf8,
68- lt_decimal_scalar, lt_eq_decimal_scalar, modulus_decimal, modulus_decimal_scalar,
69- multiply_decimal, multiply_decimal_scalar, neq_decimal_scalar, subtract_decimal,
70- subtract_decimal_scalar,
64+ is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
65+ is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
66+ is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_null,
67+ is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, multiply_decimal,
68+ multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
7169} ;
7270
7371use arrow:: datatypes:: { DataType , Schema , TimeUnit } ;
@@ -124,11 +122,8 @@ impl std::fmt::Display for BinaryExpr {
124122macro_rules! compute_decimal_op_dyn_scalar {
125123 ( $LEFT: expr, $RIGHT: expr, $OP: ident, $OP_TYPE: expr) => { {
126124 let ll = as_decimal128_array( $LEFT) . unwrap( ) ;
127- if let ScalarValue :: Decimal128 ( Some ( _) , _, _) = $RIGHT {
128- Ok ( Arc :: new( paste:: expr! { [ <$OP _decimal_scalar>] } (
129- ll,
130- $RIGHT. try_into( ) ?,
131- ) ?) )
125+ if let ScalarValue :: Decimal128 ( Some ( v_i128) , _, _) = $RIGHT {
126+ Ok ( Arc :: new( paste:: expr! { [ <$OP _dyn_scalar>] } ( ll, v_i128) ?) )
132127 } else {
133128 // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type
134129 Ok ( Arc :: new( new_null_array( $OP_TYPE, $LEFT. len( ) ) ) )
@@ -2304,6 +2299,82 @@ mod tests {
23042299
23052300 #[ test]
23062301 fn comparison_decimal_expr_test ( ) -> Result < ( ) > {
2302+ // scalar of decimal compare with decimal array
2303+ let value_i128 = 123 ;
2304+ let decimal_scalar = ScalarValue :: Decimal128 ( Some ( value_i128) , 25 , 3 ) ;
2305+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
2306+ "a" ,
2307+ DataType :: Decimal128 ( 25 , 3 ) ,
2308+ true ,
2309+ ) ] ) ) ;
2310+ let decimal_array = Arc :: new ( create_decimal_array (
2311+ & [
2312+ Some ( value_i128) ,
2313+ None ,
2314+ Some ( value_i128 - 1 ) ,
2315+ Some ( value_i128 + 1 ) ,
2316+ ] ,
2317+ 25 ,
2318+ 3 ,
2319+ ) ) as ArrayRef ;
2320+ // array = scalar
2321+ apply_logic_op_arr_scalar (
2322+ & schema,
2323+ & decimal_array,
2324+ & decimal_scalar,
2325+ Operator :: Eq ,
2326+ & BooleanArray :: from ( vec ! [ Some ( true ) , None , Some ( false ) , Some ( false ) ] ) ,
2327+ )
2328+ . unwrap ( ) ;
2329+ // array != scalar
2330+ apply_logic_op_arr_scalar (
2331+ & schema,
2332+ & decimal_array,
2333+ & decimal_scalar,
2334+ Operator :: NotEq ,
2335+ & BooleanArray :: from ( vec ! [ Some ( false ) , None , Some ( true ) , Some ( true ) ] ) ,
2336+ )
2337+ . unwrap ( ) ;
2338+ // array < scalar
2339+ apply_logic_op_arr_scalar (
2340+ & schema,
2341+ & decimal_array,
2342+ & decimal_scalar,
2343+ Operator :: Lt ,
2344+ & BooleanArray :: from ( vec ! [ Some ( false ) , None , Some ( true ) , Some ( false ) ] ) ,
2345+ )
2346+ . unwrap ( ) ;
2347+
2348+ // array <= scalar
2349+ apply_logic_op_arr_scalar (
2350+ & schema,
2351+ & decimal_array,
2352+ & decimal_scalar,
2353+ Operator :: LtEq ,
2354+ & BooleanArray :: from ( vec ! [ Some ( true ) , None , Some ( true ) , Some ( false ) ] ) ,
2355+ )
2356+ . unwrap ( ) ;
2357+ // array > scalar
2358+ apply_logic_op_arr_scalar (
2359+ & schema,
2360+ & decimal_array,
2361+ & decimal_scalar,
2362+ Operator :: Gt ,
2363+ & BooleanArray :: from ( vec ! [ Some ( false ) , None , Some ( false ) , Some ( true ) ] ) ,
2364+ )
2365+ . unwrap ( ) ;
2366+
2367+ // array >= scalar
2368+ apply_logic_op_arr_scalar (
2369+ & schema,
2370+ & decimal_array,
2371+ & decimal_scalar,
2372+ Operator :: GtEq ,
2373+ & BooleanArray :: from ( vec ! [ Some ( true ) , None , Some ( false ) , Some ( true ) ] ) ,
2374+ )
2375+ . unwrap ( ) ;
2376+
2377+ // scalar of different data type with decimal array
23072378 let decimal_scalar = ScalarValue :: Decimal128 ( Some ( 123_456 ) , 10 , 3 ) ;
23082379 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] ) ) ;
23092380 // scalar == array
0 commit comments