Skip to content

Commit 0e41a26

Browse files
committed
Fix GenerateUnsafeProjection.
1 parent a4cbf26 commit 0e41a26

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5050
fieldTypes: Seq[DataType],
5151
bufferHolder: String): String = {
5252
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
53-
val fieldName = ctx.freshName("fieldName")
54-
val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
55-
val isNull = s"$input.isNullAt($i)"
56-
ExprCode(code, isNull, fieldName)
53+
val javaType = ctx.javaType(dt)
54+
val isNullVar = ctx.freshName("isNull")
55+
val valueVar = ctx.freshName("value")
56+
val defaultValue = ctx.defaultValue(dt)
57+
val readValue = ctx.getValue(input, dt, i.toString)
58+
val code = s"""
59+
boolean $isNullVar = $input.isNullAt($i);
60+
$javaType $valueVar = $isNullVar ? $defaultValue : $readValue;"""
61+
ExprCode(code, isNullVar, valueVar)
5762
}
5863

5964
s"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.BoundReference
23+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
24+
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
25+
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
26+
27+
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
28+
test("Test unsafe projection string access pattern") {
29+
val dataType = (new StructType).add("a", StringType)
30+
val exprs = BoundReference(0, dataType, nullable = true) :: Nil
31+
val projection = GenerateUnsafeProjection.generate(exprs)
32+
val result = projection.apply(InternalRow(AlwaysNull))
33+
assert(!result.isNullAt(0))
34+
assert(result.getStruct(0, 1).isNullAt(0))
35+
}
36+
}
37+
38+
object AlwaysNull extends InternalRow {
39+
override def numFields: Int = 1
40+
override def setNullAt(i: Int): Unit = {}
41+
override def copy(): InternalRow = this
42+
override def anyNull: Boolean = true
43+
override def isNullAt(ordinal: Int): Boolean = true
44+
override def update(i: Int, value: Any): Unit = notSupported
45+
override def getBoolean(ordinal: Int): Boolean = notSupported
46+
override def getByte(ordinal: Int): Byte = notSupported
47+
override def getShort(ordinal: Int): Short = notSupported
48+
override def getInt(ordinal: Int): Int = notSupported
49+
override def getLong(ordinal: Int): Long = notSupported
50+
override def getFloat(ordinal: Int): Float = notSupported
51+
override def getDouble(ordinal: Int): Double = notSupported
52+
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
53+
override def getUTF8String(ordinal: Int): UTF8String = notSupported
54+
override def getBinary(ordinal: Int): Array[Byte] = notSupported
55+
override def getInterval(ordinal: Int): CalendarInterval = notSupported
56+
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
57+
override def getArray(ordinal: Int): ArrayData = notSupported
58+
override def getMap(ordinal: Int): MapData = notSupported
59+
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
60+
private def notSupported: Nothing = throw new UnsupportedOperationException
61+
}

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,33 +198,39 @@ public boolean anyNull() {
198198

199199
@Override
200200
public Decimal getDecimal(int ordinal, int precision, int scale) {
201+
if (columns[ordinal].isNullAt(rowId)) return null;
201202
return columns[ordinal].getDecimal(rowId, precision, scale);
202203
}
203204

204205
@Override
205206
public UTF8String getUTF8String(int ordinal) {
207+
if (columns[ordinal].isNullAt(rowId)) return null;
206208
return columns[ordinal].getUTF8String(rowId);
207209
}
208210

209211
@Override
210212
public byte[] getBinary(int ordinal) {
213+
if (columns[ordinal].isNullAt(rowId)) return null;
211214
return columns[ordinal].getBinary(rowId);
212215
}
213216

214217
@Override
215218
public CalendarInterval getInterval(int ordinal) {
219+
if (columns[ordinal].isNullAt(rowId)) return null;
216220
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
217221
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
218222
return new CalendarInterval(months, microseconds);
219223
}
220224

221225
@Override
222226
public InternalRow getStruct(int ordinal, int numFields) {
227+
if (columns[ordinal].isNullAt(rowId)) return null;
223228
return columns[ordinal].getStruct(rowId);
224229
}
225230

226231
@Override
227232
public ArrayData getArray(int ordinal) {
233+
if (columns[ordinal].isNullAt(rowId)) return null;
228234
return columns[ordinal].getArray(rowId);
229235
}
230236

0 commit comments

Comments
 (0)