Skip to content

Commit e326de4

Browse files
alahvanhovell
authored andcommitted
[SPARK-20798] GenerateUnsafeProjection should check if a value is null before calling the getter
## What changes were proposed in this pull request? GenerateUnsafeProjection.writeStructToBuffer() did not honor the assumption that the caller must make sure that a value is not null before using the getter. This could lead to various errors. This change fixes that behavior. Example of code generated before: ```scala /* 059 */ final UTF8String fieldName = value.getUTF8String(0); /* 060 */ if (value.isNullAt(0)) { /* 061 */ rowWriter1.setNullAt(0); /* 062 */ } else { /* 063 */ rowWriter1.write(0, fieldName); /* 064 */ } ``` Example of code generated now: ```scala /* 060 */ boolean isNull1 = value.isNullAt(0); /* 061 */ UTF8String value1 = isNull1 ? null : value.getUTF8String(0); /* 062 */ if (isNull1) { /* 063 */ rowWriter1.setNullAt(0); /* 064 */ } else { /* 065 */ rowWriter1.write(0, value1); /* 066 */ } ``` ## How was this patch tested? Adds GenerateUnsafeProjectionSuite. Author: Ala Luszczak <ala@databricks.com> Closes #18030 from ala/fix-generate-unsafe-projection. (cherry picked from commit ce8edb8) Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
1 parent e06d936 commit e326de4

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5050
fieldTypes: Seq[DataType],
5151
bufferHolder: String): String = {
5252
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
53-
val fieldName = ctx.freshName("fieldName")
54-
val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
55-
val isNull = s"$input.isNullAt($i)"
56-
ExprCode(code, isNull, fieldName)
53+
val javaType = ctx.javaType(dt)
54+
val isNullVar = ctx.freshName("isNull")
55+
val valueVar = ctx.freshName("value")
56+
val defaultValue = ctx.defaultValue(dt)
57+
val readValue = ctx.getValue(input, dt, i.toString)
58+
val code =
59+
s"""
60+
boolean $isNullVar = $input.isNullAt($i);
61+
$javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
62+
"""
63+
ExprCode(code, isNullVar, valueVar)
5764
}
5865

5966
s"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.BoundReference
23+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
24+
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
25+
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
26+
27+
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
28+
test("Test unsafe projection string access pattern") {
29+
val dataType = (new StructType).add("a", StringType)
30+
val exprs = BoundReference(0, dataType, nullable = true) :: Nil
31+
val projection = GenerateUnsafeProjection.generate(exprs)
32+
val result = projection.apply(InternalRow(AlwaysNull))
33+
assert(!result.isNullAt(0))
34+
assert(result.getStruct(0, 1).isNullAt(0))
35+
}
36+
}
37+
38+
object AlwaysNull extends InternalRow {
39+
override def numFields: Int = 1
40+
override def setNullAt(i: Int): Unit = {}
41+
override def copy(): InternalRow = this
42+
override def anyNull: Boolean = true
43+
override def isNullAt(ordinal: Int): Boolean = true
44+
override def update(i: Int, value: Any): Unit = notSupported
45+
override def getBoolean(ordinal: Int): Boolean = notSupported
46+
override def getByte(ordinal: Int): Byte = notSupported
47+
override def getShort(ordinal: Int): Short = notSupported
48+
override def getInt(ordinal: Int): Int = notSupported
49+
override def getLong(ordinal: Int): Long = notSupported
50+
override def getFloat(ordinal: Int): Float = notSupported
51+
override def getDouble(ordinal: Int): Double = notSupported
52+
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
53+
override def getUTF8String(ordinal: Int): UTF8String = notSupported
54+
override def getBinary(ordinal: Int): Array[Byte] = notSupported
55+
override def getInterval(ordinal: Int): CalendarInterval = notSupported
56+
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
57+
override def getArray(ordinal: Int): ArrayData = notSupported
58+
override def getMap(ordinal: Int): MapData = notSupported
59+
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
60+
private def notSupported: Nothing = throw new UnsupportedOperationException
61+
}

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,33 +198,39 @@ public boolean anyNull() {
198198

199199
@Override
200200
public Decimal getDecimal(int ordinal, int precision, int scale) {
201+
if (columns[ordinal].isNullAt(rowId)) return null;
201202
return columns[ordinal].getDecimal(rowId, precision, scale);
202203
}
203204

204205
@Override
205206
public UTF8String getUTF8String(int ordinal) {
207+
if (columns[ordinal].isNullAt(rowId)) return null;
206208
return columns[ordinal].getUTF8String(rowId);
207209
}
208210

209211
@Override
210212
public byte[] getBinary(int ordinal) {
213+
if (columns[ordinal].isNullAt(rowId)) return null;
211214
return columns[ordinal].getBinary(rowId);
212215
}
213216

214217
@Override
215218
public CalendarInterval getInterval(int ordinal) {
219+
if (columns[ordinal].isNullAt(rowId)) return null;
216220
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
217221
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
218222
return new CalendarInterval(months, microseconds);
219223
}
220224

221225
@Override
222226
public InternalRow getStruct(int ordinal, int numFields) {
227+
if (columns[ordinal].isNullAt(rowId)) return null;
223228
return columns[ordinal].getStruct(rowId);
224229
}
225230

226231
@Override
227232
public ArrayData getArray(int ordinal) {
233+
if (columns[ordinal].isNullAt(rowId)) return null;
228234
return columns[ordinal].getArray(rowId);
229235
}
230236

0 commit comments

Comments
 (0)