Commit 7ca7a63

hvanhovell authored and rxin committed
[SPARK-15214][SQL] Code-generation for Generate
## What changes were proposed in this pull request?

This PR adds code generation to `Generate`. It supports two code paths:
- General `TraversableOnce` based iteration. This is used for regular `Generator` (code generation supporting) expressions. This code path expects the expression to return a `TraversableOnce[InternalRow]` and it will iterate over the returned collection. This PR adds code generation for the `stack` generator.
- Specialized `ArrayData/MapData` based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns.

### Benchmarks
I have added some benchmarks and it seems we can create a nice speedup for explode:

#### Environment
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
Intel(R) Core(TM) i7-4980HQ CPU 2.80GHz
```

#### Explode Array
##### Before
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7377 / 7607          2.3         439.7       1.0X
generate explode array wholestage on          6055 / 6086          2.8         360.9       1.2X
```
##### After
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7432 / 7696          2.3         443.0       1.0X
generate explode array wholestage on           631 /  646         26.6          37.6      11.8X
```

#### Explode Map
##### Before
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off          12792 / 12848         1.3         762.5       1.0X
generate explode map wholestage on           11181 / 11237         1.5         666.5       1.1X
```
##### After
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off          10949 / 10972         1.5         652.6       1.0X
generate explode map wholestage on             870 /  913         19.3          51.9      12.6X
```

#### Posexplode
##### Before
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7547 / 7580          2.2         449.8       1.0X
generate posexplode array wholestage on       5786 / 5838          2.9         344.9       1.3X
```
##### After
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7535 / 7548          2.2         449.1       1.0X
generate posexplode array wholestage on        620 /  624         27.1          37.0      12.1X
```

#### Inline
##### Before
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6935 / 6978          2.4         413.3       1.0X
generate inline array wholestage on           6360 / 6400          2.6         379.1       1.1X
```
##### After
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6940 / 6966          2.4         413.6       1.0X
generate inline array wholestage on           1002 / 1012         16.7          59.7       6.9X
```

#### Stack
##### Before
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off                12980 / 13104         1.3         773.7       1.0X
generate stack wholestage on                 11566 / 11580         1.5         689.4       1.1X
```
##### After
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off                12875 / 12949         1.3         767.4       1.0X
generate stack wholestage on                   840 /  845         20.0          50.0      15.3X
```

## How was this patch tested?

Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>
Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #13065 from hvanhovell/SPARK-15214.
1 parent a64f25d commit 7ca7a63
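
As a rough illustration of the two code paths described in the commit message, the sketch below contrasts iterating over a returned collection of rows with indexing directly into an array. It is plain Scala with no Spark dependency and is not the code this commit generates: `IterationPaths`, `consumeRows`, and `consumeArray` are hypothetical names, and `Iterator[Seq[Any]]` and `Array[Any]` merely stand in for `TraversableOnce[InternalRow]` and `ArrayData`.

```scala
// Hypothetical stand-ins for illustration only; none of these names exist in Spark.
object IterationPaths {

  // General path: the generator returns a collection of rows and the
  // consumer iterates over whatever it is handed.
  def consumeRows(rows: Iterator[Seq[Any]])(emit: Seq[Any] => Unit): Unit =
    rows.foreach(emit)

  // Specialized path: walk the array by index and build each output row
  // in place, optionally prefixing the element position (as posexplode does).
  def consumeArray(data: Array[Any], position: Boolean)(emit: Seq[Any] => Unit): Unit = {
    var i = 0
    while (i < data.length) {
      emit(if (position) Seq(i, data(i)) else Seq(data(i)))
      i += 1
    }
  }
}
```

In this sketch the specialized path never materializes an intermediate collection of rows. That is in the spirit of the second code path, which operates directly on the `ArrayData`/`MapData` result of the generator's child.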

7 files changed: +463 -37 lines changed
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 90 additions & 20 deletions
@@ -17,10 +17,12 @@

 package org.apache.spark.sql.catalyst.expressions

+import scala.collection.mutable
+
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._

@@ -60,6 +62,26 @@ trait Generator extends Expression {
    * rows can be made here.
    */
   def terminate(): TraversableOnce[InternalRow] = Nil
+
+  /**
+   * Check if this generator supports code generation.
+   */
+  def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
+}
+
+/**
+ * A collection producing [[Generator]]. This trait provides a different path for code generation,
+ * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
+ */
+trait CollectionGenerator extends Generator {
+  /** The position of an element within the collection should also be returned. */
+  def position: Boolean
+
+  /** Rows will be inlined during generation. */
+  def inline: Boolean
+
+  /** The type of the returned collection object. */
+  def collectionType: DataType = dataType
 }

 /**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
   private def initializeConverters(): Unit = {
     inputRow = new InterpretedProjection(children)
     convertToScala = {
-      val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+      val inputSchema = StructType(children.map { e =>
+        StructField(e.simpleString, e.dataType, nullable = true)
+      })
       CatalystTypeConverters.createToScalaConverter(inputSchema)
     }.asInstanceOf[InternalRow => Row]
   }
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
        1 2
        3 NULL
   """)
-case class Stack(children: Seq[Expression])
-    extends Expression with Generator with CodegenFallback {
+case class Stack(children: Seq[Expression]) extends Generator {

   private lazy val numRows = children.head.eval().asInstanceOf[Int]
   private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,29 +172,58 @@ case class Stack(children: Seq[Expression])
       InternalRow(fields: _*)
     }
   }
+
+
+  /**
+   * Only support code generation when stack produces 50 rows or less.
+   */
+  override def supportCodegen: Boolean = numRows <= 50
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // Rows - we write these into an array.
+    val rowData = ctx.freshName("rows")
+    ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
+    val values = children.tail
+    val dataTypes = values.take(numFields).map(_.dataType)
+    val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+      val fields = Seq.tabulate(numFields) { col =>
+        val index = row * numFields + col
+        if (index < values.length) values(index) else Literal(null, dataTypes(col))
+      }
+      val eval = CreateStruct(fields).genCode(ctx)
+      s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
+    })
+
+    // Create the collection.
+    val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+    ctx.addMutableState(
+      s"$wrapperClass<InternalRow>",
+      ev.value,
+      s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
+    ev.copy(code = code, isNull = "false")
+  }
 }

 /**
- * A base class for Explode and PosExplode
+ * A base class for [[Explode]] and [[PosExplode]].
  */
-abstract class ExplodeBase(child: Expression, position: Boolean)
-    extends UnaryExpression with Generator with CodegenFallback with Serializable {
+abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
+  override val inline: Boolean = false

-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case _: ArrayType | _: MapType =>
       TypeCheckResult.TypeCheckSuccess
-    } else {
+    case _ =>
       TypeCheckResult.TypeCheckFailure(
         s"input to function explode should be array or map type, not ${child.dataType}")
-    }
   }

   // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
   override def elementSchema: StructType = child.dataType match {
     case ArrayType(et, containsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
+          .add("pos", IntegerType, nullable = false)
           .add("col", et, containsNull)
       } else {
         new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
     case MapType(kt, vt, valueContainsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
-          .add("key", kt, false)
+          .add("pos", IntegerType, nullable = false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       } else {
         new StructType()
-          .add("key", kt, false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       }
   }
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
         }
     }
   }
+
+  override def collectionType: DataType = child.dataType
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }

 /**
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
        20
   """)
 // scalastyle:on line.size.limit
-case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+case class Explode(child: Expression) extends ExplodeBase {
+  override val position: Boolean = false
+}

 /**
  * Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
        1 20
   """)
 // scalastyle:on line.size.limit
-case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+case class PosExplode(child: Expression) extends ExplodeBase {
+  override val position = true
+}

 /**
  * Explodes an array of structs into a table.
@@ -273,20 +335,24 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
        1 a
        2 b
   """)
-case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
+  override val inline: Boolean = true
+  override val position: Boolean = false

   override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+    case ArrayType(st: StructType, _) =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
       TypeCheckResult.TypeCheckFailure(
         s"input to function $prettyName should be array of struct type, not ${child.dataType}")
   }

   override def elementSchema: StructType = child.dataType match {
-    case ArrayType(et : StructType, _) => et
+    case ArrayType(st: StructType, _) => st
   }

+  override def collectionType: DataType = child.dataType
+
   private lazy val numFields = elementSchema.fields.length

   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
         yield inputArray.getStruct(i, numFields)
     }
   }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }
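
The row layout that `Stack.doGenCode` builds in the diff above (`index = row * numFields + col`, with a null literal once the values run out) can be sketched without any Spark classes. This is an illustration under stated assumptions, not the generated code: `StackLayout` and `rows` are hypothetical names, and plain `Seq[Any]` values stand in for Catalyst expressions and `InternalRow`.

```scala
// Hypothetical helper for illustration only; it mirrors the index arithmetic
// in Stack.doGenCode but operates on plain Scala values.
object StackLayout {

  // Distribute `values` over `numRows` rows, filling row by row and padding
  // the trailing cells with null when the values do not divide evenly.
  def rows(numRows: Int, values: Seq[Any]): Seq[Seq[Any]] = {
    // `values` plays the role of `children.tail`, so this matches
    // ceil((children.length - 1.0) / numRows) in the diff.
    val numFields = math.ceil(values.length.toDouble / numRows).toInt
    Seq.tabulate(numRows) { row =>
      Seq.tabulate(numFields) { col =>
        val index = row * numFields + col
        if (index < values.length) values(index) else null
      }
    }
  }
}
```

For example, `StackLayout.rows(2, Seq(1, 2, 3))` yields the rows `(1, 2)` and `(3, null)`, which matches the `1 2` / `3 NULL` example in the `Stack` documentation string.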

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 11 additions & 5 deletions
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{DataType, IntegerType}

 class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Children of CodegenFallback") {
     val one = Literal(1)
     val two = Add(one, one)
-    val explode = Explode(two)
-    val add = Add(two, explode)
+    val fallback = CodegenFallbackExpression(two)
+    val add = Add(two, fallback)

-    var equivalence = new EquivalentExpressions
+    val equivalence = new EquivalentExpressions
     equivalence.addExprTree(add, true)
-    // the `two` inside `explode` should not be added
+    // the `two` inside `fallback` should not be added
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
   }
 }
+
+case class CodegenFallbackExpression(child: Expression)
+  extends UnaryExpression with CodegenFallback {
+  override def dataType: DataType = child.dataType
+}
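
The test above needs its own `CodegenFallbackExpression` because `Explode` no longer mixes in `CodegenFallback` after this commit. The marker-trait check behind the new `supportCodegen` default, and the way `Stack` overrides it with a row limit, can be sketched in isolation. All names below (`CodegenCheck`, `Expr`, `Fallback`, `Plain`, `Interpreted`, `Bounded`) are hypothetical stand-ins rather than Spark classes.

```scala
// Hypothetical miniature of the supportCodegen pattern; not Spark code.
object CodegenCheck {

  // Marker trait: mixing it in signals "interpret me, do not generate code".
  trait Fallback

  trait Expr {
    // Default mirrors `!isInstanceOf[CodegenFallback]` from the diff.
    def supportCodegen: Boolean = !isInstanceOf[Fallback]
  }

  // No marker, so code generation is supported by default.
  case class Plain(n: Int) extends Expr

  // Marker mixed in, so the default check reports no support.
  case class Interpreted(n: Int) extends Expr with Fallback

  // Overrides the default with a size limit, as Stack does with 50 rows.
  case class Bounded(numRows: Int) extends Expr {
    override def supportCodegen: Boolean = numRows <= 50
  }
}
```

In this sketch `Plain(1)` supports code generation, `Interpreted(1)` does not, and `Bounded` supports it only up to 50 rows.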
