@@ -26,13 +26,13 @@ use arrow::datatypes::DataType;
2626use datafusion_common:: cast:: {
2727 as_large_string_array, as_string_array, as_string_view_array,
2828} ;
29- use datafusion_common:: { exec_datafusion_err, exec_err, plan_err , Result } ;
29+ use datafusion_common:: { exec_datafusion_err, exec_err, Result } ;
3030use datafusion_expr:: {
3131 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
3232 Volatility ,
3333} ;
3434use datafusion_functions:: utils:: make_scalar_function;
35- use url:: Url ;
35+ use url:: { ParseError , Url } ;
3636
3737#[ derive( Debug , PartialEq , Eq , Hash ) ]
3838pub struct ParseUrl {
@@ -49,20 +49,7 @@ impl ParseUrl {
4949 pub fn new ( ) -> Self {
5050 Self {
5151 signature : Signature :: one_of (
52- vec ! [
53- TypeSignature :: Uniform (
54- 1 ,
55- vec![ DataType :: Utf8View , DataType :: Utf8 , DataType :: LargeUtf8 ] ,
56- ) ,
57- TypeSignature :: Uniform (
58- 2 ,
59- vec![ DataType :: Utf8View , DataType :: Utf8 , DataType :: LargeUtf8 ] ,
60- ) ,
61- TypeSignature :: Uniform (
62- 3 ,
63- vec![ DataType :: Utf8View , DataType :: Utf8 , DataType :: LargeUtf8 ] ,
64- ) ,
65- ] ,
52+ vec ! [ TypeSignature :: String ( 2 ) , TypeSignature :: String ( 3 ) ] ,
6653 Volatility :: Immutable ,
6754 ) ,
6855 }
@@ -95,11 +82,22 @@ impl ParseUrl {
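    /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed. The one
    ///   exception: an input rejected as a relative URL that does not contain `://`
    ///   yields `Ok(None)` instead of an error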
9683 ///
9784 fn parse ( value : & str , part : & str , key : Option < & str > ) -> Result < Option < String > > {
98- Url :: parse ( value)
99- . map_err ( |e| exec_datafusion_err ! ( "{e:?}" ) )
85+ let url: std:: result:: Result < Url , ParseError > = Url :: parse ( value) ;
86+ if let Err ( ParseError :: RelativeUrlWithoutBase ) = url {
87+ return if !value. contains ( "://" ) {
88+ Ok ( None )
89+ } else {
90+ Err ( exec_datafusion_err ! ( "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02" ) )
91+ } ;
92+ } ;
93+ url. map_err ( |e| exec_datafusion_err ! ( "{e:?}" ) )
10094 . map ( |url| match part {
10195 "HOST" => url. host_str ( ) . map ( String :: from) ,
102- "PATH" => Some ( url. path ( ) . to_string ( ) ) ,
96+ "PATH" => {
97+ let path: String = url. path ( ) . to_string ( ) ;
98+ let path: String = if path == "/" { "" . to_string ( ) } else { path } ;
99+ Some ( path)
100+ }
103101 "QUERY" => match key {
104102 None => url. query ( ) . map ( String :: from) ,
105103 Some ( key) => url
@@ -146,35 +144,7 @@ impl ScalarUDFImpl for ParseUrl {
146144 }
147145
148146 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
149- if arg_types. len ( ) < 2 || arg_types. len ( ) > 3 {
150- return plan_err ! (
151- "{} expects 2 or 3 arguments, but got {}" ,
152- self . name( ) ,
153- arg_types. len( )
154- ) ;
155- }
156- match arg_types. len ( ) {
157- 2 | 3 => {
158- if arg_types
159- . iter ( )
160- . any ( |arg| matches ! ( arg, DataType :: LargeUtf8 ) )
161- {
162- Ok ( DataType :: LargeUtf8 )
163- } else if arg_types
164- . iter ( )
165- . any ( |arg| matches ! ( arg, DataType :: Utf8View ) )
166- {
167- Ok ( DataType :: Utf8View )
168- } else {
169- Ok ( DataType :: Utf8 )
170- }
171- }
172- _ => plan_err ! (
173- "`{}` expects 2 or 3 arguments, got {}" ,
174- & self . name( ) ,
175- arg_types. len( )
176- ) ,
177- }
147+ Ok ( arg_types[ 0 ] . clone ( ) )
178148 }
179149
180150 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
@@ -200,6 +170,13 @@ impl ScalarUDFImpl for ParseUrl {
/// - The output array type (StringArray, StringViewArray or LargeStringArray) is determined by input types
///
202172fn spark_parse_url ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
173+ spark_handled_parse_url ( args, |x| x)
174+ }
175+
176+ pub fn spark_handled_parse_url (
177+ args : & [ ArrayRef ] ,
178+ handler_err : impl Fn ( Result < Option < String > > ) -> Result < Option < String > > ,
179+ ) -> Result < ArrayRef > {
203180 if args. len ( ) < 2 || args. len ( ) > 3 {
204181 return exec_err ! (
205182 "{} expects 2 or 3 arguments, but got {}" ,
@@ -212,6 +189,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
212189 let part = & args[ 1 ] ;
213190
214191 let result = if args. len ( ) == 3 {
192+ // In this case, the 'key' argument is passed
215193 let key = & args[ 2 ] ;
216194
217195 match ( url. data_type ( ) , part. data_type ( ) , key. data_type ( ) ) {
@@ -220,20 +198,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
220198 as_string_array ( url) ?,
221199 as_string_array ( part) ?,
222200 as_string_array ( key) ?,
201+ handler_err,
223202 )
224203 }
225204 ( DataType :: Utf8View , DataType :: Utf8View , DataType :: Utf8View ) => {
226205 process_parse_url :: < _ , _ , _ , StringViewArray > (
227206 as_string_view_array ( url) ?,
228207 as_string_view_array ( part) ?,
229208 as_string_view_array ( key) ?,
209+ handler_err,
230210 )
231211 }
232212 ( DataType :: LargeUtf8 , DataType :: LargeUtf8 , DataType :: LargeUtf8 ) => {
233213 process_parse_url :: < _ , _ , _ , LargeStringArray > (
234214 as_large_string_array ( url) ?,
235215 as_large_string_array ( part) ?,
236216 as_large_string_array ( key) ?,
217+ handler_err,
237218 )
238219 }
239220 _ => exec_err ! ( "{} expects STRING arguments, got {:?}" , "`parse_url`" , args) ,
@@ -253,20 +234,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
253234 as_string_array ( url) ?,
254235 as_string_array ( part) ?,
255236 & key,
237+ handler_err,
256238 )
257239 }
258240 ( DataType :: Utf8View , DataType :: Utf8View ) => {
259241 process_parse_url :: < _ , _ , _ , StringViewArray > (
260242 as_string_view_array ( url) ?,
261243 as_string_view_array ( part) ?,
262244 & key,
245+ handler_err,
263246 )
264247 }
265248 ( DataType :: LargeUtf8 , DataType :: LargeUtf8 ) => {
266249 process_parse_url :: < _ , _ , _ , LargeStringArray > (
267250 as_large_string_array ( url) ?,
268251 as_large_string_array ( part) ?,
269252 & key,
253+ handler_err,
270254 )
271255 }
272256 _ => exec_err ! ( "{} expects STRING arguments, got {:?}" , "`parse_url`" , args) ,
@@ -279,6 +263,7 @@ fn process_parse_url<'a, A, B, C, T>(
279263 url_array : & ' a A ,
280264 part_array : & ' a B ,
281265 key_array : & ' a C ,
266+ handle : impl Fn ( Result < Option < String > > ) -> Result < Option < String > > ,
282267) -> Result < ArrayRef >
283268where
284269 & ' a A : StringArrayType < ' a > ,
@@ -292,11 +277,156 @@ where
292277 . zip ( key_array. iter ( ) )
293278 . map ( |( ( url, part) , key) | {
294279 if let ( Some ( url) , Some ( part) , key) = ( url, part, key) {
295- ParseUrl :: parse ( url, part, key)
280+ handle ( ParseUrl :: parse ( url, part, key) )
296281 } else {
297282 Ok ( None )
298283 }
299284 } )
300285 . collect :: < Result < T > > ( )
301286 . map ( |array| Arc :: new ( array) as ArrayRef )
302287}
288+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use datafusion_common::Result;
    use std::array::from_ref;
    use std::sync::Arc;

    /// Builds a `StringArray` column from optional string slices.
    fn sa(vals: &[Option<&str>]) -> ArrayRef {
        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
    }

    /// `HOST` extracts the host name of a well-formed URL.
    #[test]
    fn test_parse_host() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
        assert_eq!(got, Some("example.com".to_string()));
        Ok(())
    }

    /// `QUERY` returns the whole query string without a key, the matching
    /// value with a key, and `None` for a key that is absent.
    #[test]
    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
        assert_eq!(got_all, Some("a=1&b=2".to_string()));

        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
        assert_eq!(got_a, Some("1".to_string()));

        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
        assert_eq!(got_c, None);
        Ok(())
    }

    /// Covers the remaining parts on a URL that carries every component.
    #[test]
    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
        assert_eq!(
            ParseUrl::parse(url, "PROTOCOL", None)?,
            Some("ftp".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "USERINFO", None)?,
            Some("user:pwd".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "FILE", None)?,
            Some("/files?x=1".to_string())
        );
        assert_eq!(
            ParseUrl::parse(url, "AUTHORITY", None)?,
            Some("user:pwd@ftp.example.com".to_string())
        );
        Ok(())
    }

    /// A root-only path (`/`) is reported as the empty string.
    #[test]
    fn test_parse_path_root_is_empty_string() -> Result<()> {
        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
        assert_eq!(got, Some("".to_string()));
        Ok(())
    }

    /// A bare string with no `://` is not an error: it yields `None`.
    #[test]
    fn test_parse_string_without_scheme_returns_none() -> Result<()> {
        let got = ParseUrl::parse("notaurl", "HOST", None)?;
        assert_eq!(got, None);
        Ok(())
    }

    /// An unparsable input that still contains `://` is reported as an error.
    #[test]
    fn test_parse_invalid_url_with_scheme_separator_errors() {
        let err = ParseUrl::parse("://example.com", "HOST", None).unwrap_err();
        assert!(format!("{err}").contains("The url is invalid"));
    }

    /// Two-argument form over `Utf8` arrays, one part per row.
    #[test]
    fn test_spark_utf8_two_args() -> Result<()> {
        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
        let parts = sa(&[Some("HOST"), Some("PATH")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "example.com");
        assert_eq!(out_sa.value(1), "");
        Ok(())
    }

    /// Three-argument form: a present key yields its value, a missing key NULL.
    #[test]
    fn test_spark_utf8_three_args_query_key() -> Result<()> {
        let urls = sa(&[
            Some("https://example.com/a?x=1&y=2"),
            Some("https://ex.com/?a=1"),
        ]);
        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
        let keys = sa(&[Some("y"), Some("b")]);

        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 2);
        assert_eq!(out_sa.value(0), "2");
        assert!(out_sa.is_null(1));
        Ok(())
    }

    /// Missing user info and NULL input rows both produce NULL output rows.
    #[test]
    fn test_spark_userinfo_and_nulls() -> Result<()> {
        let urls = sa(&[
            Some("ftp://user:pwd@ftp.example.com:21/files"),
            Some("https://example.com"),
            None,
        ]);
        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);

        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();

        assert_eq!(out_sa.len(), 3);
        assert_eq!(out_sa.value(0), "user:pwd");
        assert!(out_sa.is_null(1));
        assert!(out_sa.is_null(2));
        Ok(())
    }

    /// Fewer than two or more than three argument arrays are rejected.
    #[test]
    fn test_invalid_arg_count() {
        let urls = sa(&[Some("https://example.com")]);
        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));

        let parts = sa(&[Some("HOST")]);
        let keys = sa(&[Some("x")]);
        let err =
            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
                .unwrap_err();
        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
    }

    /// A non-string argument array is rejected with a type error.
    #[test]
    fn test_non_string_types_error() {
        let urls = sa(&[Some("https://example.com")]);
        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;

        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects STRING arguments"));
    }
}
0 commit comments