Skip to content

Commit 58ddf0d

Browse files
rafafrdzJefffreyalamb
authored
feat(spark): implement Spark try_parse_url function (#17485)
* impl try_parse_url spark function * suggestions * fix parse_url * fix parse_url * fix parse_url * suggestions * suggestions * suggestions * suggestions * edit * tests and clippy * suggestions and tests * fixing parse_url * fixing parse_url --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent ed5f8e7 commit 58ddf0d

File tree

5 files changed

+504
-53
lines changed

5 files changed

+504
-53
lines changed

datafusion/spark/src/function/url/mod.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,26 @@ use datafusion_functions::make_udf_function;
2020
use std::sync::Arc;
2121

2222
pub mod parse_url;
23+
pub mod try_parse_url;
2324

2425
make_udf_function!(parse_url::ParseUrl, parse_url);
26+
make_udf_function!(try_parse_url::TryParseUrl, try_parse_url);
2527

2628
pub mod expr_fn {
2729
use datafusion_functions::export_functions;
2830

29-
export_functions!((parse_url, "Extracts a part from a URL.", args));
31+
export_functions!((
32+
parse_url,
33+
"Extracts a part from a URL, throwing an error if an invalid URL is provided.",
34+
args
35+
));
36+
export_functions!((
37+
try_parse_url,
38+
"Same as parse_url but returns NULL if an invalid URL is provided.",
39+
args
40+
));
3041
}
3142

3243
pub fn functions() -> Vec<Arc<ScalarUDF>> {
33-
vec![parse_url()]
44+
vec![parse_url(), try_parse_url()]
3445
}

datafusion/spark/src/function/url/parse_url.rs

Lines changed: 179 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ use arrow::datatypes::DataType;
2626
use datafusion_common::cast::{
2727
as_large_string_array, as_string_array, as_string_view_array,
2828
};
29-
use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result};
29+
use datafusion_common::{exec_datafusion_err, exec_err, Result};
3030
use datafusion_expr::{
3131
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
3232
Volatility,
3333
};
3434
use datafusion_functions::utils::make_scalar_function;
35-
use url::Url;
35+
use url::{ParseError, Url};
3636

3737
#[derive(Debug, PartialEq, Eq, Hash)]
3838
pub struct ParseUrl {
@@ -49,20 +49,7 @@ impl ParseUrl {
4949
pub fn new() -> Self {
5050
Self {
5151
signature: Signature::one_of(
52-
vec![
53-
TypeSignature::Uniform(
54-
1,
55-
vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
56-
),
57-
TypeSignature::Uniform(
58-
2,
59-
vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
60-
),
61-
TypeSignature::Uniform(
62-
3,
63-
vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
64-
),
65-
],
52+
vec![TypeSignature::String(2), TypeSignature::String(3)],
6653
Volatility::Immutable,
6754
),
6855
}
@@ -95,11 +82,22 @@ impl ParseUrl {
9582
/// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed
9683
///
9784
fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
98-
Url::parse(value)
99-
.map_err(|e| exec_datafusion_err!("{e:?}"))
85+
let url: std::result::Result<Url, ParseError> = Url::parse(value);
86+
if let Err(ParseError::RelativeUrlWithoutBase) = url {
87+
return if !value.contains("://") {
88+
Ok(None)
89+
} else {
90+
Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"))
91+
};
92+
};
93+
url.map_err(|e| exec_datafusion_err!("{e:?}"))
10094
.map(|url| match part {
10195
"HOST" => url.host_str().map(String::from),
102-
"PATH" => Some(url.path().to_string()),
96+
"PATH" => {
97+
let path: String = url.path().to_string();
98+
let path: String = if path == "/" { "".to_string() } else { path };
99+
Some(path)
100+
}
103101
"QUERY" => match key {
104102
None => url.query().map(String::from),
105103
Some(key) => url
@@ -146,35 +144,7 @@ impl ScalarUDFImpl for ParseUrl {
146144
}
147145

148146
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
149-
if arg_types.len() < 2 || arg_types.len() > 3 {
150-
return plan_err!(
151-
"{} expects 2 or 3 arguments, but got {}",
152-
self.name(),
153-
arg_types.len()
154-
);
155-
}
156-
match arg_types.len() {
157-
2 | 3 => {
158-
if arg_types
159-
.iter()
160-
.any(|arg| matches!(arg, DataType::LargeUtf8))
161-
{
162-
Ok(DataType::LargeUtf8)
163-
} else if arg_types
164-
.iter()
165-
.any(|arg| matches!(arg, DataType::Utf8View))
166-
{
167-
Ok(DataType::Utf8View)
168-
} else {
169-
Ok(DataType::Utf8)
170-
}
171-
}
172-
_ => plan_err!(
173-
"`{}` expects 2 or 3 arguments, got {}",
174-
&self.name(),
175-
arg_types.len()
176-
),
177-
}
147+
Ok(arg_types[0].clone())
178148
}
179149

180150
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -200,6 +170,13 @@ impl ScalarUDFImpl for ParseUrl {
200170
/// - The output array type (StringArray or LargeStringArray) is determined by input types
201171
///
202172
fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
173+
spark_handled_parse_url(args, |x| x)
174+
}
175+
176+
pub fn spark_handled_parse_url(
177+
args: &[ArrayRef],
178+
handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
179+
) -> Result<ArrayRef> {
203180
if args.len() < 2 || args.len() > 3 {
204181
return exec_err!(
205182
"{} expects 2 or 3 arguments, but got {}",
@@ -212,6 +189,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
212189
let part = &args[1];
213190

214191
let result = if args.len() == 3 {
192+
// In this case, the 'key' argument is passed
215193
let key = &args[2];
216194

217195
match (url.data_type(), part.data_type(), key.data_type()) {
@@ -220,20 +198,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
220198
as_string_array(url)?,
221199
as_string_array(part)?,
222200
as_string_array(key)?,
201+
handler_err,
223202
)
224203
}
225204
(DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
226205
process_parse_url::<_, _, _, StringViewArray>(
227206
as_string_view_array(url)?,
228207
as_string_view_array(part)?,
229208
as_string_view_array(key)?,
209+
handler_err,
230210
)
231211
}
232212
(DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
233213
process_parse_url::<_, _, _, LargeStringArray>(
234214
as_large_string_array(url)?,
235215
as_large_string_array(part)?,
236216
as_large_string_array(key)?,
217+
handler_err,
237218
)
238219
}
239220
_ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
@@ -253,20 +234,23 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
253234
as_string_array(url)?,
254235
as_string_array(part)?,
255236
&key,
237+
handler_err,
256238
)
257239
}
258240
(DataType::Utf8View, DataType::Utf8View) => {
259241
process_parse_url::<_, _, _, StringViewArray>(
260242
as_string_view_array(url)?,
261243
as_string_view_array(part)?,
262244
&key,
245+
handler_err,
263246
)
264247
}
265248
(DataType::LargeUtf8, DataType::LargeUtf8) => {
266249
process_parse_url::<_, _, _, LargeStringArray>(
267250
as_large_string_array(url)?,
268251
as_large_string_array(part)?,
269252
&key,
253+
handler_err,
270254
)
271255
}
272256
_ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
@@ -279,6 +263,7 @@ fn process_parse_url<'a, A, B, C, T>(
279263
url_array: &'a A,
280264
part_array: &'a B,
281265
key_array: &'a C,
266+
handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
282267
) -> Result<ArrayRef>
283268
where
284269
&'a A: StringArrayType<'a>,
@@ -292,11 +277,156 @@ where
292277
.zip(key_array.iter())
293278
.map(|((url, part), key)| {
294279
if let (Some(url), Some(part), key) = (url, part, key) {
295-
ParseUrl::parse(url, part, key)
280+
handle(ParseUrl::parse(url, part, key))
296281
} else {
297282
Ok(None)
298283
}
299284
})
300285
.collect::<Result<T>>()
301286
.map(|array| Arc::new(array) as ArrayRef)
302287
}
288+
289+
#[cfg(test)]
290+
mod tests {
291+
use super::*;
292+
use arrow::array::{ArrayRef, Int32Array, StringArray};
293+
use datafusion_common::Result;
294+
use std::array::from_ref;
295+
use std::sync::Arc;
296+
297+
fn sa(vals: &[Option<&str>]) -> ArrayRef {
298+
Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
299+
}
300+
301+
#[test]
302+
fn test_parse_host() -> Result<()> {
303+
let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
304+
assert_eq!(got, Some("example.com".to_string()));
305+
Ok(())
306+
}
307+
308+
#[test]
309+
fn test_parse_query_no_key_vs_with_key() -> Result<()> {
310+
let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
311+
assert_eq!(got_all, Some("a=1&b=2".to_string()));
312+
313+
let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
314+
assert_eq!(got_a, Some("1".to_string()));
315+
316+
let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
317+
assert_eq!(got_c, None);
318+
Ok(())
319+
}
320+
321+
#[test]
322+
fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
323+
let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
324+
assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
325+
assert_eq!(
326+
ParseUrl::parse(url, "PROTOCOL", None)?,
327+
Some("ftp".to_string())
328+
);
329+
assert_eq!(
330+
ParseUrl::parse(url, "USERINFO", None)?,
331+
Some("user:pwd".to_string())
332+
);
333+
assert_eq!(
334+
ParseUrl::parse(url, "FILE", None)?,
335+
Some("/files?x=1".to_string())
336+
);
337+
assert_eq!(
338+
ParseUrl::parse(url, "AUTHORITY", None)?,
339+
Some("user:pwd@ftp.example.com".to_string())
340+
);
341+
Ok(())
342+
}
343+
344+
#[test]
345+
fn test_parse_path_root_is_empty_string() -> Result<()> {
346+
let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
347+
assert_eq!(got, Some("".to_string()));
348+
Ok(())
349+
}
350+
351+
#[test]
352+
fn test_parse_malformed_url_returns_error() -> Result<()> {
353+
let got = ParseUrl::parse("notaurl", "HOST", None)?;
354+
assert_eq!(got, None);
355+
Ok(())
356+
}
357+
358+
#[test]
359+
fn test_spark_utf8_two_args() -> Result<()> {
360+
let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
361+
let parts = sa(&[Some("HOST"), Some("PATH")]);
362+
363+
let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
364+
let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
365+
366+
assert_eq!(out_sa.len(), 2);
367+
assert_eq!(out_sa.value(0), "example.com");
368+
assert_eq!(out_sa.value(1), "");
369+
Ok(())
370+
}
371+
372+
#[test]
373+
fn test_spark_utf8_three_args_query_key() -> Result<()> {
374+
let urls = sa(&[
375+
Some("https://example.com/a?x=1&y=2"),
376+
Some("https://ex.com/?a=1"),
377+
]);
378+
let parts = sa(&[Some("QUERY"), Some("QUERY")]);
379+
let keys = sa(&[Some("y"), Some("b")]);
380+
381+
let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
382+
let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
383+
384+
assert_eq!(out_sa.len(), 2);
385+
assert_eq!(out_sa.value(0), "2");
386+
assert!(out_sa.is_null(1));
387+
Ok(())
388+
}
389+
390+
#[test]
391+
fn test_spark_userinfo_and_nulls() -> Result<()> {
392+
let urls = sa(&[
393+
Some("ftp://user:pwd@ftp.example.com:21/files"),
394+
Some("https://example.com"),
395+
None,
396+
]);
397+
let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);
398+
399+
let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
400+
let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
401+
402+
assert_eq!(out_sa.len(), 3);
403+
assert_eq!(out_sa.value(0), "user:pwd");
404+
assert!(out_sa.is_null(1));
405+
assert!(out_sa.is_null(2));
406+
Ok(())
407+
}
408+
409+
#[test]
410+
fn test_invalid_arg_count() {
411+
let urls = sa(&[Some("https://example.com")]);
412+
let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
413+
assert!(format!("{err}").contains("expects 2 or 3 arguments"));
414+
415+
let parts = sa(&[Some("HOST")]);
416+
let keys = sa(&[Some("x")]);
417+
let err =
418+
spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
419+
.unwrap_err();
420+
assert!(format!("{err}").contains("expects 2 or 3 arguments"));
421+
}
422+
423+
#[test]
424+
fn test_non_string_types_error() {
425+
let urls = sa(&[Some("https://example.com")]);
426+
let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;
427+
428+
let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
429+
let msg = format!("{err}");
430+
assert!(msg.contains("expects STRING arguments"));
431+
}
432+
}

0 commit comments

Comments
 (0)