18
18
use std:: any:: Any ;
19
19
use std:: sync:: Arc ;
20
20
21
- use crate :: strings:: make_and_append_view;
22
- use crate :: utils:: make_scalar_function;
23
21
use arrow:: array:: {
24
- Array , ArrayIter , ArrayRef , AsArray , Int64Array , NullBufferBuilder , StringArrayType ,
25
- StringViewArray , StringViewBuilder ,
22
+ Array , ArrayIter , ArrayRef , AsArray , Int64Array , OffsetSizeTrait ,
23
+ StringArrayType , StringViewBuilder ,
26
24
} ;
27
- use arrow:: buffer:: ScalarBuffer ;
28
25
use arrow:: datatypes:: DataType ;
26
+
29
27
use datafusion_common:: cast:: as_int64_array;
30
28
use datafusion_common:: { exec_err, plan_err, Result } ;
31
29
use datafusion_expr:: {
32
30
ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
33
31
} ;
34
32
use datafusion_macros:: user_doc;
35
33
34
+ use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
35
+
36
36
#[ user_doc(
37
37
doc_section( label = "String Functions" ) ,
38
38
description = "Extracts a substring of a specified number of characters from a specific starting position in a string." ,
@@ -44,7 +44,7 @@ use datafusion_macros::user_doc;
44
44
| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
45
45
+----------------------------------------------+
46
46
| fus |
47
- +----------------------------------------------+
47
+ +----------------------------------------------+
48
48
```"# ,
49
49
standard_argument( name = "str" , prefix = "String" ) ,
50
50
argument(
@@ -90,9 +90,8 @@ impl ScalarUDFImpl for SubstrFunc {
90
90
& self . signature
91
91
}
92
92
93
- // `SubstrFunc` always generates `Utf8View` output for its efficiency.
94
- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
95
- Ok ( DataType :: Utf8View )
93
+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
94
+ utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
96
95
}
97
96
98
97
fn invoke_with_args (
@@ -177,28 +176,21 @@ impl ScalarUDFImpl for SubstrFunc {
177
176
}
178
177
}
179
178
180
- /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
181
- /// substr('alphabet', 3) = 'phabet'
182
- /// substr('alphabet', 3, 2) = 'ph'
183
- /// The implementation uses UTF-8 code points as characters
184
179
pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
185
180
match args[ 0 ] . data_type ( ) {
186
181
DataType :: Utf8 => {
187
182
let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
188
- string_substr :: < _ > ( string_array, & args[ 1 ..] )
183
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
189
184
}
190
185
DataType :: LargeUtf8 => {
191
186
let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
192
- string_substr :: < _ > ( string_array, & args[ 1 ..] )
187
+ calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
193
188
}
194
189
DataType :: Utf8View => {
195
190
let string_array = args[ 0 ] . as_string_view ( ) ;
196
- string_view_substr ( string_array, & args[ 1 ..] )
191
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
197
192
}
198
- other => exec_err ! (
199
- "Unsupported data type {other:?} for function substr,\
200
- expected Utf8View, Utf8 or LargeUtf8."
201
- ) ,
193
+ other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
202
194
}
203
195
}
204
196
@@ -312,120 +304,11 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
312
304
}
313
305
}
314
306
315
- // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
316
- // From<u128> for ByteView
317
- fn string_view_substr (
318
- string_view_array : & StringViewArray ,
319
- args : & [ ArrayRef ] ,
320
- ) -> Result < ArrayRef > {
321
- let mut views_buf = Vec :: with_capacity ( string_view_array. len ( ) ) ;
322
- let mut null_builder = NullBufferBuilder :: new ( string_view_array. len ( ) ) ;
323
-
324
- let start_array = as_int64_array ( & args[ 0 ] ) ?;
325
- let count_array_opt = if args. len ( ) == 2 {
326
- Some ( as_int64_array ( & args[ 1 ] ) ?)
327
- } else {
328
- None
329
- } ;
330
-
331
- let enable_ascii_fast_path =
332
- enable_ascii_fast_path ( & string_view_array, start_array, count_array_opt) ;
333
-
334
- // In either case of `substr(s, i)` or `substr(s, i, cnt)`
335
- // If any of input argument is `NULL`, the result is `NULL`
336
- match args. len ( ) {
337
- 1 => {
338
- for ( ( str_opt, raw_view) , start_opt) in string_view_array
339
- . iter ( )
340
- . zip ( string_view_array. views ( ) . iter ( ) )
341
- . zip ( start_array. iter ( ) )
342
- {
343
- if let ( Some ( str) , Some ( start) ) = ( str_opt, start_opt) {
344
- let ( start, end) =
345
- get_true_start_end ( str, start, None , enable_ascii_fast_path) ;
346
- let substr = & str[ start..end] ;
347
-
348
- make_and_append_view (
349
- & mut views_buf,
350
- & mut null_builder,
351
- raw_view,
352
- substr,
353
- start as u32 ,
354
- ) ;
355
- } else {
356
- null_builder. append_null ( ) ;
357
- views_buf. push ( 0 ) ;
358
- }
359
- }
360
- }
361
- 2 => {
362
- let count_array = count_array_opt. unwrap ( ) ;
363
- for ( ( ( str_opt, raw_view) , start_opt) , count_opt) in string_view_array
364
- . iter ( )
365
- . zip ( string_view_array. views ( ) . iter ( ) )
366
- . zip ( start_array. iter ( ) )
367
- . zip ( count_array. iter ( ) )
368
- {
369
- if let ( Some ( str) , Some ( start) , Some ( count) ) =
370
- ( str_opt, start_opt, count_opt)
371
- {
372
- if count < 0 {
373
- return exec_err ! (
374
- "negative substring length not allowed: substr(<str>, {start}, {count})"
375
- ) ;
376
- } else {
377
- if start == i64:: MIN {
378
- return exec_err ! (
379
- "negative overflow when calculating skip value"
380
- ) ;
381
- }
382
- let ( start, end) = get_true_start_end (
383
- str,
384
- start,
385
- Some ( count as u64 ) ,
386
- enable_ascii_fast_path,
387
- ) ;
388
- let substr = & str[ start..end] ;
389
-
390
- make_and_append_view (
391
- & mut views_buf,
392
- & mut null_builder,
393
- raw_view,
394
- substr,
395
- start as u32 ,
396
- ) ;
397
- }
398
- } else {
399
- null_builder. append_null ( ) ;
400
- views_buf. push ( 0 ) ;
401
- }
402
- }
403
- }
404
- other => {
405
- return exec_err ! (
406
- "substr was called with {other} arguments. It requires 2 or 3."
407
- )
408
- }
409
- }
410
-
411
- let views_buf = ScalarBuffer :: from ( views_buf) ;
412
- let nulls_buf = null_builder. finish ( ) ;
413
-
414
- // Safety:
415
- // (1) The blocks of the given views are all provided
416
- // (2) Each of the range `view.offset+start..end` of view in views_buf is within
417
- // the bounds of each of the blocks
418
- unsafe {
419
- let array = StringViewArray :: new_unchecked (
420
- views_buf,
421
- string_view_array. data_buffers ( ) . to_vec ( ) ,
422
- nulls_buf,
423
- ) ;
424
- Ok ( Arc :: new ( array) as ArrayRef )
425
- }
426
- }
427
-
428
- fn string_substr < ' a , V > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
307
+ /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
308
+ /// substr('alphabet', 3) = 'phabet'
309
+ /// substr('alphabet', 3, 2) = 'ph'
310
+ /// The implementation uses UTF-8 code points as characters
311
+ fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
429
312
where
430
313
V : StringArrayType < ' a > ,
431
314
{
@@ -507,8 +390,8 @@ where
507
390
508
391
#[ cfg( test) ]
509
392
mod tests {
510
- use arrow:: array:: { Array , StringViewArray } ;
511
- use arrow:: datatypes:: DataType :: Utf8View ;
393
+ use arrow:: array:: { Array , StringArray } ;
394
+ use arrow:: datatypes:: DataType :: Utf8 ;
512
395
513
396
use datafusion_common:: { exec_err, Result , ScalarValue } ;
514
397
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -526,8 +409,8 @@ mod tests {
526
409
] ,
527
410
Ok ( None ) ,
528
411
& str ,
529
- Utf8View ,
530
- StringViewArray
412
+ Utf8 ,
413
+ StringArray
531
414
) ;
532
415
test_function ! (
533
416
SubstrFunc :: new( ) ,
@@ -539,35 +422,8 @@ mod tests {
539
422
] ,
540
423
Ok ( Some ( "alphabet" ) ) ,
541
424
& str ,
542
- Utf8View ,
543
- StringViewArray
544
- ) ;
545
- test_function ! (
546
- SubstrFunc :: new( ) ,
547
- vec![
548
- ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
549
- "this és longer than 12B"
550
- ) ) ) ) ,
551
- ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
552
- ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
553
- ] ,
554
- Ok ( Some ( " é" ) ) ,
555
- & str ,
556
- Utf8View ,
557
- StringViewArray
558
- ) ;
559
- test_function ! (
560
- SubstrFunc :: new( ) ,
561
- vec![
562
- ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
563
- "this is longer than 12B"
564
- ) ) ) ) ,
565
- ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
566
- ] ,
567
- Ok ( Some ( " is longer than 12B" ) ) ,
568
- & str ,
569
- Utf8View ,
570
- StringViewArray
425
+ Utf8 ,
426
+ StringArray
571
427
) ;
572
428
test_function ! (
573
429
SubstrFunc :: new( ) ,
@@ -579,8 +435,8 @@ mod tests {
579
435
] ,
580
436
Ok ( Some ( "ésoj" ) ) ,
581
437
& str ,
582
- Utf8View ,
583
- StringViewArray
438
+ Utf8 ,
439
+ StringArray
584
440
) ;
585
441
test_function ! (
586
442
SubstrFunc :: new( ) ,
@@ -593,8 +449,8 @@ mod tests {
593
449
] ,
594
450
Ok ( Some ( "ph" ) ) ,
595
451
& str ,
596
- Utf8View ,
597
- StringViewArray
452
+ Utf8 ,
453
+ StringArray
598
454
) ;
599
455
test_function ! (
600
456
SubstrFunc :: new( ) ,
@@ -607,8 +463,8 @@ mod tests {
607
463
] ,
608
464
Ok ( Some ( "phabet" ) ) ,
609
465
& str ,
610
- Utf8View ,
611
- StringViewArray
466
+ Utf8 ,
467
+ StringArray
612
468
) ;
613
469
test_function ! (
614
470
SubstrFunc :: new( ) ,
0 commit comments