Skip to content

Commit 495e932

Browse files
author
Davies Liu
committed
support wider table with more than 1k columns for generated projections
1 parent fe12277 commit 495e932

File tree

6 files changed

+185
-122
lines changed

6 files changed

+185
-122
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import scala.collection.mutable
21+
import scala.collection.mutable.ArrayBuffer
2122
import scala.language.existentials
2223

2324
import com.google.common.cache.{CacheBuilder, CacheLoader}
@@ -265,6 +266,43 @@ class CodeGenContext {
265266
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
266267

267268
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
269+
270+
/**
271+
* Splits the generated code of expressions into multiple functions, because a single method is
272+
* limited to 64KB of bytecode in the JVM.
273+
*/
274+
def splitExpressions(input: String, expressions: Seq[String]): String = {
275+
val blocks = new ArrayBuffer[String]()
276+
val blockBuilder = new StringBuilder()
277+
for (code <- expressions) {
278+
// We can't know how much bytecode will be generated, so use the source length in chars as the limit
279+
if (blockBuilder.length > 64 * 1000) {
280+
blocks.append(blockBuilder.toString())
281+
blockBuilder.clear()
282+
}
283+
blockBuilder.append(code)
284+
}
285+
blocks.append(blockBuilder.toString())
286+
287+
if (blocks.length == 1) {
288+
// inline execution if only one block
289+
blocks.head
290+
} else {
291+
val apply = freshName("apply")
292+
val functions = blocks.zipWithIndex.map { case (body, i) =>
293+
val name = s"${apply}_$i"
294+
val code = s"""
295+
|private void $name(InternalRow $input) {
296+
| $body
297+
|}
298+
""".stripMargin
299+
addNewFunction(name, code)
300+
name
301+
}
302+
303+
functions.map(name => s"$name($input);").mkString("\n")
304+
}
305+
}
268306
}
269307

270308
/**
@@ -289,15 +327,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
289327
protected def declareMutableStates(ctx: CodeGenContext): String = {
290328
ctx.mutableStates.map { case (javaType, variableName, _) =>
291329
s"private $javaType $variableName;"
292-
}.mkString
330+
}.mkString("\n")
293331
}
294332

295333
protected def initMutableStates(ctx: CodeGenContext): String = {
296-
ctx.mutableStates.map(_._3).mkString
334+
ctx.mutableStates.map(_._3).mkString("\n")
297335
}
298336

299337
protected def declareAddedFunctions(ctx: CodeGenContext): String = {
300-
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString
338+
ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
301339
}
302340

303341
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
4040

4141
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
4242
val ctx = newCodeGenContext()
43-
val projectionCode = expressions.zipWithIndex.map {
43+
val projectionCodes = expressions.zipWithIndex.map {
4444
case (NoOp, _) => ""
4545
case (e, i) =>
4646
val evaluationCode = e.gen(ctx)
@@ -65,35 +65,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
6565
"""
6666
}
6767
}
68-
// collect projections into blocks as function has 64kb codesize limit in JVM
69-
val projectionBlocks = new ArrayBuffer[String]()
70-
val blockBuilder = new StringBuilder()
71-
for (projection <- projectionCode) {
72-
if (blockBuilder.length > 16 * 1000) {
73-
projectionBlocks.append(blockBuilder.toString())
74-
blockBuilder.clear()
75-
}
76-
blockBuilder.append(projection)
77-
}
78-
projectionBlocks.append(blockBuilder.toString())
79-
80-
val (projectionFuns, projectionCalls) = {
81-
// inline execution if codesize limit was not broken
82-
if (projectionBlocks.length == 1) {
83-
("", projectionBlocks.head)
84-
} else {
85-
(
86-
projectionBlocks.zipWithIndex.map { case (body, i) =>
87-
s"""
88-
|private void apply$i(InternalRow i) {
89-
| $body
90-
|}
91-
""".stripMargin
92-
}.mkString,
93-
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
94-
)
95-
}
96-
}
68+
val allProjections = ctx.splitExpressions("i", projectionCodes)
9769

9870
val code = s"""
9971
public Object generate($exprType[] expr) {
@@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
12395
return (InternalRow) mutableRow;
12496
}
12597

126-
$projectionFuns
127-
12898
public Object apply(Object _i) {
12999
InternalRow i = (InternalRow) _i;
130-
$projectionCalls
131-
100+
$allProjections
132101
return mutableRow;
133102
}
134103
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5555
${genUpdater(ctx, rowTerm, dt, i, colTerm)};
5656
}
5757
"""
58-
}.mkString("\n")
58+
}
59+
val allUpdates = ctx.splitExpressions(value, updates)
5960
s"""
6061
$genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
61-
$updates
62+
$allUpdates
6263
$setter.update($ordinal, $rowTerm.copy());
6364
"""
6465
case _ =>
@@ -68,7 +69,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
6869

6970
protected def create(expressions: Seq[Expression]): Projection = {
7071
val ctx = newCodeGenContext()
71-
val projectionCode = expressions.zipWithIndex.map {
72+
val expressionCodes = expressions.zipWithIndex.map {
7273
case (NoOp, _) => ""
7374
case (e, i) =>
7475
val evaluationCode = e.gen(ctx)
@@ -81,36 +82,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
8182
}
8283
"""
8384
}
84-
// collect projections into blocks as function has 64kb codesize limit in JVM
85-
val projectionBlocks = new ArrayBuffer[String]()
86-
val blockBuilder = new StringBuilder()
87-
for (projection <- projectionCode) {
88-
if (blockBuilder.length > 16 * 1000) {
89-
projectionBlocks.append(blockBuilder.toString())
90-
blockBuilder.clear()
91-
}
92-
blockBuilder.append(projection)
93-
}
94-
projectionBlocks.append(blockBuilder.toString())
95-
96-
val (projectionFuns, projectionCalls) = {
97-
// inline it if we have only one block
98-
if (projectionBlocks.length == 1) {
99-
("", projectionBlocks.head)
100-
} else {
101-
(
102-
projectionBlocks.zipWithIndex.map { case (body, i) =>
103-
s"""
104-
|private void apply$i(InternalRow i) {
105-
| $body
106-
|}
107-
""".stripMargin
108-
}.mkString,
109-
projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n")
110-
)
111-
}
112-
}
113-
85+
val allExpressions = ctx.splitExpressions("i", expressionCodes)
11486
val code = s"""
11587
public Object generate($exprType[] expr) {
11688
return new SpecificSafeProjection(expr);
@@ -121,19 +93,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
12193
private $exprType[] expressions;
12294
private $mutableRowType mutableRow;
12395
${declareMutableStates(ctx)}
96+
${declareAddedFunctions(ctx)}
12497

12598
public SpecificSafeProjection($exprType[] expr) {
12699
expressions = expr;
127100
mutableRow = new $genericMutableRowType(${expressions.size});
128101
${initMutableStates(ctx)}
129102
}
130103

131-
$projectionFuns
132-
133104
public Object apply(Object _i) {
134105
InternalRow i = (InternalRow) _i;
135-
$projectionCalls
136-
106+
$allExpressions
137107
return mutableRow;
138108
}
139109
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5656

5757
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
5858
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
59-
s" + $DecimalWriter.getSize(${ev.primitive})"
59+
s"$DecimalWriter.getSize(${ev.primitive})"
6060
case StringType =>
61-
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
61+
s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
6262
case BinaryType =>
63-
s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
63+
s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
6464
case CalendarIntervalType =>
65-
s" + (${ev.isNull} ? 0 : 16)"
65+
s"${ev.isNull} ? 0 : 16"
6666
case _: StructType =>
67-
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
67+
s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
6868
case _: ArrayType =>
69-
s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
69+
s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
7070
case _: MapType =>
71-
s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
71+
s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
7272
case _ => ""
7373
}
7474

@@ -125,64 +125,68 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
125125
*/
126126
private def createCodeForStruct(
127127
ctx: CodeGenContext,
128+
row: String,
128129
inputs: Seq[GeneratedExpressionCode],
129130
inputTypes: Seq[DataType]): GeneratedExpressionCode = {
130131

132+
val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
133+
131134
val output = ctx.freshName("convertedStruct")
132-
ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
135+
ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();")
133136
val buffer = ctx.freshName("buffer")
134-
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
135-
val numBytes = ctx.freshName("numBytes")
137+
ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
136138
val cursor = ctx.freshName("cursor")
139+
ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
140+
val tmp = ctx.freshName("tmpBuffer")
137141

138-
val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
139-
createConvertCode(ctx, input, dt)
140-
}
141-
142-
val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
143-
val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
144-
genAdditionalSize(dt, ev)
145-
}.mkString("")
146-
147-
val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
148-
val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
149-
if (dt.isInstanceOf[DecimalType]) {
150-
// Can't call setNullAt() for DecimalType
142+
val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
143+
val ev = createConvertCode(ctx, input, dt)
144+
val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
145+
val numBytes = ctx.freshName("numBytes")
151146
s"""
147+
int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
148+
if ($buffer.length < $numBytes) {
149+
// This will not happen frequently, because the buffer is re-used.
150+
byte[] $tmp = new byte[$numBytes * 3 / 2];
151+
System.arraycopy($buffer, 0, $tmp, 0, $buffer.length);
152+
$buffer = $tmp;
153+
}
154+
$output.pointTo($buffer, $PlatformDependent.BYTE_ARRAY_OFFSET,
155+
${inputTypes.length}, $numBytes);
156+
"""
157+
} else {
158+
""
159+
}
160+
val update = dt match {
161+
case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
162+
// Can't call setNullAt() for DecimalType
163+
s"""
152164
if (${ev.isNull}) {
153-
$cursor += $DecimalWriter.write($output, $i, $cursor, null);
165+
$cursor += $DecimalWriter.write($output, $i, $cursor, null);
154166
} else {
155-
$update;
167+
${genFieldWriter(ctx, dt, ev, output, i, cursor)};
156168
}
157169
"""
158-
} else {
159-
s"""
170+
case _ =>
171+
s"""
160172
if (${ev.isNull}) {
161173
$output.setNullAt($i);
162174
} else {
163-
$update;
175+
${genFieldWriter(ctx, dt, ev, output, i, cursor)};
164176
}
165177
"""
166178
}
167-
}.mkString("\n")
179+
s"""
180+
${ev.code}
181+
$growBuffer
182+
$update
183+
"""
184+
}
168185

169186
val code = s"""
170-
${convertedFields.map(_.code).mkString("\n")}
171-
172-
final int $numBytes = $fixedSize $additionalSize;
173-
if ($numBytes > $buffer.length) {
174-
$buffer = new byte[$numBytes];
175-
}
176-
177-
$output.pointTo(
178-
$buffer,
179-
$PlatformDependent.BYTE_ARRAY_OFFSET,
180-
${inputTypes.length},
181-
$numBytes);
182-
183-
int $cursor = $fixedSize;
184-
185-
$fieldWriters
187+
$cursor = $fixedSize;
188+
$output.pointTo($buffer, $PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor);
189+
${ctx.splitExpressions(row, convertedFields)}
186190
"""
187191
GeneratedExpressionCode(code, "false", output)
188192
}
@@ -400,7 +404,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
400404
val fieldIsNull = s"$tmp.isNullAt($i)"
401405
GeneratedExpressionCode("", fieldIsNull, getFieldCode)
402406
}
403-
val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes)
407+
val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes)
404408
val code = s"""
405409
${input.code}
406410
UnsafeRow $output = null;
@@ -427,7 +431,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
427431
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
428432
val exprEvals = expressions.map(e => e.gen(ctx))
429433
val exprTypes = expressions.map(_.dataType)
430-
createCodeForStruct(ctx, exprEvals, exprTypes)
434+
createCodeForStruct(ctx, "i", exprEvals, exprTypes)
431435
}
432436

433437
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
155155
|$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
156156
""".stripMargin
157157
}
158-
}.mkString
158+
}
159159

160160
// ------------------------ Finally, put everything together --------------------------- //
161161
val code = s"""

0 commit comments

Comments
 (0)