Skip to content

Commit a05835b

Browse files
mengxr authored and liancheng committed
[SPARK-6542][SQL] add CreateStruct
Similar to `CreateArray`, we can add `CreateStruct` to create nested columns. (cc: marmbrus)

Author: Xiangrui Meng <meng@databricks.com>

Closes #5195 from mengxr/SPARK-6542 and squashes the following commits:

- 3795c57 [Xiangrui Meng] update error message
- ae7ac3e [Xiangrui Meng] move unit test to a separate suite
- 85dd559 [Xiangrui Meng] use NamedExpr
- c78e31a [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-6542
- 85f3106 [Xiangrui Meng] add CreateStruct
1 parent 314afd0 commit a05835b

File tree

3 files changed

+73
-23
lines changed

3 files changed

+73
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,12 @@ class Analyzer(catalog: Catalog,
212212
case o => o :: Nil
213213
}
214214
Alias(c.copy(children = expandedArgs), name)() :: Nil
215+
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
216+
val expandedArgs = args.flatMap {
217+
case s: Star => s.expand(child.output, resolver)
218+
case o => o :: Nil
219+
}
220+
Alias(c.copy(children = expandedArgs), name)() :: Nil
215221
case o => o :: Nil
216222
},
217223
child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co
120120
case class CreateArray(children: Seq[Expression]) extends Expression {
121121
override type EvaluatedType = Any
122122

123-
override def foldable: Boolean = !children.exists(!_.foldable)
123+
override def foldable: Boolean = children.forall(_.foldable)
124124

125125
lazy val childTypes = children.map(_.dataType).distinct
126126

@@ -142,3 +142,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
142142

143143
override def toString: String = s"Array(${children.mkString(",")})"
144144
}
145+
146+
/**
147+
* Returns a Row containing the evaluation of all children expressions.
148+
* TODO: [[CreateStruct]] does not support codegen.
149+
*/
150+
case class CreateStruct(children: Seq[NamedExpression]) extends Expression {
151+
override type EvaluatedType = Row
152+
153+
override def foldable: Boolean = children.forall(_.foldable)
154+
155+
override lazy val resolved: Boolean = childrenResolved
156+
157+
override lazy val dataType: StructType = {
158+
assert(resolved,
159+
s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
160+
val fields = children.map { child =>
161+
StructField(child.name, child.dataType, child.nullable, child.metadata)
162+
}
163+
StructType(fields)
164+
}
165+
166+
override def nullable: Boolean = false
167+
168+
override def eval(input: Row): EvaluatedType = {
169+
Row(children.map(_.eval(input)): _*)
170+
}
171+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 39 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,34 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
3030
import org.apache.spark.sql.types._
3131

3232

33-
class ExpressionEvaluationSuite extends FunSuite {
33+
class ExpressionEvaluationBaseSuite extends FunSuite {
34+
35+
def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
36+
expression.eval(inputRow)
37+
}
38+
39+
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
40+
val actual = try evaluate(expression, inputRow) catch {
41+
case e: Exception => fail(s"Exception evaluating $expression", e)
42+
}
43+
if(actual != expected) {
44+
val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
45+
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
46+
}
47+
}
48+
49+
def checkDoubleEvaluation(
50+
expression: Expression,
51+
expected: Spread[Double],
52+
inputRow: Row = EmptyRow): Unit = {
53+
val actual = try evaluate(expression, inputRow) catch {
54+
case e: Exception => fail(s"Exception evaluating $expression", e)
55+
}
56+
actual.asInstanceOf[Double] shouldBe expected
57+
}
58+
}
59+
60+
class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
3461

3562
test("literals") {
3663
checkEvaluation(Literal(1), 1)
@@ -134,27 +161,6 @@ class ExpressionEvaluationSuite extends FunSuite {
134161
}
135162
}
136163

137-
def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
138-
expression.eval(inputRow)
139-
}
140-
141-
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
142-
val actual = try evaluate(expression, inputRow) catch {
143-
case e: Exception => fail(s"Exception evaluating $expression", e)
144-
}
145-
if(actual != expected) {
146-
val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
147-
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
148-
}
149-
}
150-
151-
def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = {
152-
val actual = try evaluate(expression, inputRow) catch {
153-
case e: Exception => fail(s"Exception evaluating $expression", e)
154-
}
155-
actual.asInstanceOf[Double] shouldBe expected
156-
}
157-
158164
test("IN") {
159165
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
160166
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
@@ -1081,3 +1087,14 @@ class ExpressionEvaluationSuite extends FunSuite {
10811087
checkEvaluation(~c1, -2, row)
10821088
}
10831089
}
1090+
1091+
// TODO: Make the tests work with codegen.
1092+
class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {
1093+
1094+
test("CreateStruct") {
1095+
val row = Row(1, 2, 3)
1096+
val c1 = 'a.int.at(0).as("a")
1097+
val c3 = 'c.int.at(2).as("c")
1098+
checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row)
1099+
}
1100+
}

0 commit comments

Comments
 (0)