Skip to content

Commit 417d111

Browse files
scwfmarmbrus
authored andcommitted
[SPARK-5367][SQL] Support star expression in udfs
A follow up for #4163: support `select array(key, *) from src` Since array(key, *) will not go into this case ``` case Alias(f UnresolvedFunction(_, args), name) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } ``` here added a case to cover the corner case of array. /cc liancheng Author: wangfei <wangfei1@huawei.com> Author: scwf <wangfei1@huawei.com> Closes #4353 from scwf/udf-star1 and squashes the following commits: 4350d17 [wangfei] minor fix a7cd191 [wangfei] minor fix 0942fb1 [wangfei] follow up: support select array(key, *) from src 6ae00db [wangfei] also fix problem with array da1da09 [scwf] minor fix f87b5f9 [scwf] added test case 587bf7e [wangfei] compile fix eb93c16 [wangfei] fix star resolve issue in udf
1 parent 424cb69 commit 417d111

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,12 @@ class Analyzer(catalog: Catalog,
257257
case o => o :: Nil
258258
}
259259
Alias(child = f.copy(children = expandedArgs), name)() :: Nil
260+
case Alias(c @ CreateArray(args), name) if containsStar(args) =>
261+
val expandedArgs = args.flatMap {
262+
case s: Star => s.expand(child.output, resolver)
263+
case o => o :: Nil
264+
}
265+
Alias(c.copy(children = expandedArgs), name)() :: Nil
260266
case o => o :: Nil
261267
},
262268
child)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
620620
test("SPARK-5367: resolve star expression in udf") {
621621
assert(sql("select concat(*) from src limit 5").collect().size == 5)
622622
assert(sql("select array(*) from src limit 5").collect().size == 5)
623+
assert(sql("select concat(key, *) from src limit 5").collect().size == 5)
624+
assert(sql("select array(key, *) from src limit 5").collect().size == 5)
623625
}
624626

625627
test("Query Hive native command execution result") {

0 commit comments

Comments
 (0)