@@ -19,25 +19,21 @@ use std::any::Any;
1919use std:: mem:: size_of;
2020use std:: sync:: Arc ;
2121
22- use arrow:: array:: { Array , ArrayRef , ArrowPrimitiveType , AsArray , PrimitiveArray } ;
23- use arrow:: compute:: try_binary;
24- use arrow:: datatypes:: DataType :: {
25- Int16 , Int32 , Int64 , Int8 , UInt16 , UInt32 , UInt64 , UInt8 ,
26- } ;
27- use arrow:: datatypes:: {
28- ArrowNativeType , DataType , Int16Type , Int32Type , Int64Type , Int8Type , UInt16Type ,
29- UInt32Type , UInt64Type , UInt8Type ,
22+ use arrow:: array:: {
23+ downcast_integer_array, Array , ArrayRef , ArrowPrimitiveType , AsArray , Int32Array ,
24+ Int8Array , PrimitiveArray ,
3025} ;
31- use datafusion_common:: { exec_err, Result } ;
26+ use arrow:: compute:: try_binary;
27+ use arrow:: datatypes:: { ArrowNativeType , DataType , Int32Type , Int8Type } ;
28+ use datafusion_common:: types:: { logical_int32, NativeType } ;
29+ use datafusion_common:: utils:: take_function_args;
30+ use datafusion_common:: { internal_err, Result } ;
3231use datafusion_expr:: {
33- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
32+ Coercion , ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
33+ TypeSignatureClass , Volatility ,
3434} ;
3535use datafusion_functions:: utils:: make_scalar_function;
3636
37- use crate :: function:: error_utils:: {
38- invalid_arg_count_exec_err, unsupported_data_type_exec_err,
39- } ;
40-
4137#[ derive( Debug , PartialEq , Eq , Hash ) ]
4238pub struct SparkBitGet {
4339 signature : Signature ,
@@ -53,7 +49,17 @@ impl Default for SparkBitGet {
5349impl SparkBitGet {
5450 pub fn new ( ) -> Self {
5551 Self {
56- signature : Signature :: user_defined ( Volatility :: Immutable ) ,
52+ signature : Signature :: coercible (
53+ vec ! [
54+ Coercion :: new_exact( TypeSignatureClass :: Integer ) ,
55+ Coercion :: new_implicit(
56+ TypeSignatureClass :: Native ( logical_int32( ) ) ,
57+ vec![ TypeSignatureClass :: Integer ] ,
58+ NativeType :: Int32 ,
59+ ) ,
60+ ] ,
61+ Volatility :: Immutable ,
62+ ) ,
5763 aliases : vec ! [ "getbit" . to_string( ) ] ,
5864 }
5965 }
@@ -64,34 +70,6 @@ impl ScalarUDFImpl for SparkBitGet {
6470 self
6571 }
6672
67- fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
68- if arg_types. len ( ) != 2 {
69- return Err ( invalid_arg_count_exec_err (
70- "bit_get" ,
71- ( 2 , 2 ) ,
72- arg_types. len ( ) ,
73- ) ) ;
74- }
75- if !arg_types[ 0 ] . is_integer ( ) && !arg_types[ 0 ] . is_null ( ) {
76- return Err ( unsupported_data_type_exec_err (
77- "bit_get" ,
78- "Integer Type" ,
79- & arg_types[ 0 ] ,
80- ) ) ;
81- }
82- if !arg_types[ 1 ] . is_integer ( ) && !arg_types[ 1 ] . is_null ( ) {
83- return Err ( unsupported_data_type_exec_err (
84- "bit_get" ,
85- "Integer Type" ,
86- & arg_types[ 1 ] ,
87- ) ) ;
88- }
89- if arg_types[ 0 ] . is_null ( ) {
90- return Ok ( vec ! [ Int8 , Int32 ] ) ;
91- }
92- Ok ( vec ! [ arg_types[ 0 ] . clone( ) , Int32 ] )
93- }
94-
9573 fn name ( & self ) -> & str {
9674 "bit_get"
9775 }
@@ -105,7 +83,7 @@ impl ScalarUDFImpl for SparkBitGet {
10583 }
10684
10785 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
108- Ok ( Int8 )
86+ Ok ( DataType :: Int8 )
10987 }
11088
11189 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
@@ -115,8 +93,8 @@ impl ScalarUDFImpl for SparkBitGet {
11593
11694fn spark_bit_get_inner < T : ArrowPrimitiveType > (
11795 value : & PrimitiveArray < T > ,
118- pos : & PrimitiveArray < Int32Type > ,
119- ) -> Result < PrimitiveArray < Int8Type > > {
96+ pos : & Int32Array ,
97+ ) -> Result < Int8Array > {
12098 let bit_length = ( size_of :: < T :: Native > ( ) * 8 ) as i32 ;
12199
122100 let result: PrimitiveArray < Int8Type > = try_binary ( value, pos, |value, pos| {
@@ -130,164 +108,13 @@ fn spark_bit_get_inner<T: ArrowPrimitiveType>(
130108 Ok ( result)
131109}
132110
133- pub fn spark_bit_get ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
134- if args. len ( ) != 2 {
135- return exec_err ! ( "`bit_get` expects exactly two arguments" ) ;
136- }
137-
138- if args[ 1 ] . data_type ( ) != & Int32 {
139- return exec_err ! ( "`bit_get` expects Int32 as the second argument" ) ;
140- }
141-
142- let pos_arg = args[ 1 ] . as_primitive :: < Int32Type > ( ) ;
143-
144- let ret = match & args[ 0 ] . data_type ( ) {
145- Int64 => {
146- let value_arg = args[ 0 ] . as_primitive :: < Int64Type > ( ) ;
147- spark_bit_get_inner ( value_arg, pos_arg)
148- }
149- Int32 => {
150- let value_arg = args[ 0 ] . as_primitive :: < Int32Type > ( ) ;
151- spark_bit_get_inner ( value_arg, pos_arg)
152- }
153- Int16 => {
154- let value_arg = args[ 0 ] . as_primitive :: < Int16Type > ( ) ;
155- spark_bit_get_inner ( value_arg, pos_arg)
156- }
157- Int8 => {
158- let value_arg = args[ 0 ] . as_primitive :: < Int8Type > ( ) ;
159- spark_bit_get_inner ( value_arg, pos_arg)
160- }
161- UInt64 => {
162- let value_arg = args[ 0 ] . as_primitive :: < UInt64Type > ( ) ;
163- spark_bit_get_inner ( value_arg, pos_arg)
164- }
165- UInt32 => {
166- let value_arg = args[ 0 ] . as_primitive :: < UInt32Type > ( ) ;
167- spark_bit_get_inner ( value_arg, pos_arg)
168- }
169- UInt16 => {
170- let value_arg = args[ 0 ] . as_primitive :: < UInt16Type > ( ) ;
171- spark_bit_get_inner ( value_arg, pos_arg)
172- }
173- UInt8 => {
174- let value_arg = args[ 0 ] . as_primitive :: < UInt8Type > ( ) ;
175- spark_bit_get_inner ( value_arg, pos_arg)
176- }
177- _ => {
178- exec_err ! (
179- "`bit_get` expects Int64, Int32, Int16, or Int8 as the first argument"
180- )
181- }
182- } ?;
111+ fn spark_bit_get ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
112+ let [ value, position] = take_function_args ( "bit_get" , args) ?;
113+ let pos_arg = position. as_primitive :: < Int32Type > ( ) ;
114+ let ret = downcast_integer_array ! (
115+ value => spark_bit_get_inner( value, pos_arg) ,
116+ DataType :: Null => Ok ( Int8Array :: new_null( value. len( ) ) ) ,
117+ d => internal_err!( "Unsupported datatype for bit_get: {d}" ) ,
118+ ) ?;
183119 Ok ( Arc :: new ( ret) )
184120}
185-
186- #[ cfg( test) ]
187- mod tests {
188- use arrow:: array:: { Int32Array , Int64Array } ;
189-
190- use super :: * ;
191-
192- #[ test]
193- fn test_bit_get_basic ( ) {
194- // Test bit_get(11, 0) - 11 = 1011 in binary, bit 0 = 1
195- let result = spark_bit_get ( & [
196- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
197- Arc :: new ( Int32Array :: from ( vec ! [ 0 ] ) ) ,
198- ] )
199- . unwrap ( ) ;
200-
201- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 1 ) ;
202-
203- // Test bit_get(11, 2) - 11 = 1011 in binary, bit 2 = 0
204- let result = spark_bit_get ( & [
205- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
206- Arc :: new ( Int32Array :: from ( vec ! [ 2 ] ) ) ,
207- ] )
208- . unwrap ( ) ;
209-
210- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 0 ) ;
211-
212- // Test bit_get(11, 3) - 11 = 1011 in binary, bit 3 = 1
213- let result = spark_bit_get ( & [
214- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
215- Arc :: new ( Int32Array :: from ( vec ! [ 3 ] ) ) ,
216- ] )
217- . unwrap ( ) ;
218-
219- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 1 ) ;
220- }
221-
222- #[ test]
223- fn test_bit_get_edge_cases ( ) {
224- // Test with 0
225- let result = spark_bit_get ( & [
226- Arc :: new ( Int64Array :: from ( vec ! [ 0 ] ) ) ,
227- Arc :: new ( Int32Array :: from ( vec ! [ 0 ] ) ) ,
228- ] )
229- . unwrap ( ) ;
230-
231- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 0 ) ;
232-
233- let result = spark_bit_get ( & [
234- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
235- Arc :: new ( Int32Array :: from ( vec ! [ -1 ] ) ) ,
236- ] ) ;
237- assert_eq ! (
238- result. unwrap_err( ) . message( ) . lines( ) . next( ) . unwrap( ) ,
239- "Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0"
240- ) ;
241-
242- let result = spark_bit_get ( & [
243- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
244- Arc :: new ( Int32Array :: from ( vec ! [ 64 ] ) ) ,
245- ] ) ;
246-
247- assert_eq ! (
248- result. unwrap_err( ) . message( ) . lines( ) . next( ) . unwrap( ) ,
249- "Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0"
250- ) ;
251- }
252-
253- #[ test]
254- fn test_bit_get_null_inputs ( ) {
255- // Test with NULL value
256- let result = spark_bit_get ( & [
257- Arc :: new ( Int64Array :: from ( vec ! [ None ] ) ) ,
258- Arc :: new ( Int32Array :: from ( vec ! [ 0 ] ) ) ,
259- ] )
260- . unwrap ( ) ;
261-
262- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 0 ) ;
263-
264- // Test with NULL position
265- let result = spark_bit_get ( & [
266- Arc :: new ( Int64Array :: from ( vec ! [ 11 ] ) ) ,
267- Arc :: new ( Int32Array :: from ( vec ! [ None ] ) ) ,
268- ] )
269- . unwrap ( ) ;
270-
271- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 0 ) ;
272- }
273-
274- #[ test]
275- fn test_bit_get_large_numbers ( ) {
276- // Test with larger number
277- let result = spark_bit_get ( & [
278- Arc :: new ( Int64Array :: from ( vec ! [ 255 ] ) ) , // 11111111 in binary
279- Arc :: new ( Int32Array :: from ( vec ! [ 7 ] ) ) , // bit 7 = 1
280- ] )
281- . unwrap ( ) ;
282-
283- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 1 ) ;
284-
285- let result = spark_bit_get ( & [
286- Arc :: new ( Int64Array :: from ( vec ! [ 255 ] ) ) , // 11111111 in binary
287- Arc :: new ( Int32Array :: from ( vec ! [ 8 ] ) ) , // bit 8 = 0
288- ] )
289- . unwrap ( ) ;
290-
291- assert_eq ! ( result. as_primitive:: <Int8Type >( ) . value( 0 ) , 0 ) ;
292- }
293- }
0 commit comments