
Commit 07514a1

[SPARK-29503][SQL] Copy nested unsafe data in return value of lambda function of MapObjects

Parent: ab92e17

10 files changed: +205 -7 lines
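Editorial note (not part of the commit message): MapObjects evaluates its lambda once per collection element, and the lambda may return data backed by a reused unsafe buffer. Before this change, the result was copied only when it was itself an UnsafeRow, UnsafeArrayData, or UnsafeMapData; unsafe data nested inside a safe container (e.g. an UnsafeRow element of a GenericArrayData) escaped the copy, so every collected element ended up aliasing the last write. A minimal Scala sketch of that aliasing hazard, using UnsafeProjection purely for illustration:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

  val schema = StructType(Seq(StructField("i", IntegerType)))
  val proj = UnsafeProjection.create(schema)

  // UnsafeProjection reuses a single result row, so uncopied results alias it.
  val aliased = (1 to 3).map(i => proj(InternalRow(i)))
  assert(aliased.map(_.getInt(0)) == Seq(3, 3, 3))  // every slot sees the last write

  // Copying each result detaches it from the shared buffer.
  val copied = (1 to 3).map(i => proj(InternalRow(i)).copy())
  assert(copied.map(_.getInt(0)) == Seq(1, 2, 3))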

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
Lines changed: 5 additions & 0 deletions

@@ -365,6 +365,11 @@ public UnsafeArrayData copy() {
     return arrayCopy;
   }
 
+  @Override
+  public ArrayData copyUnsafeData(ArrayType dataType) {
+    return copy();
+  }
+
   @Override
   public boolean[] toBooleanArray() {
     boolean[] values = new boolean[numElements];

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
Lines changed: 6 additions & 0 deletions

@@ -29,6 +29,7 @@
 import com.esotericsoftware.kryo.io.Output;
 
 import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.MapType;
 import org.apache.spark.unsafe.Platform;
 
 import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -135,6 +136,11 @@ public UnsafeMapData copy() {
     return mapCopy;
   }
 
+  @Override
+  public MapData copyUnsafeData(MapType dataType) {
+    return copy();
+  }
+
   @Override
   public void writeExternal(ObjectOutput out) throws IOException {
     byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
Lines changed: 5 additions & 0 deletions

@@ -468,6 +468,11 @@ public UnsafeRow copy() {
     return rowCopy;
   }
 
+  @Override
+  public InternalRow copyUnsafeData(StructType dataType) {
+    return copy();
+  }
+
   /**
    * Creates an empty UnsafeRow from a byte array with specified numBytes and numFields.
    * The returned row is invalid until we call copyFrom on it.

sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java
Lines changed: 6 additions & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.vectorized;
 
 import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.types.MapType;
 
 /**
  * Map abstraction in {@link ColumnVector}.
@@ -50,4 +51,9 @@ public ColumnarArray valueArray() {
   public ColumnarMap copy() {
     throw new UnsupportedOperationException();
   }
+
+  @Override
+  public MapData copyUnsafeData(MapType dataType) {
+    throw new UnsupportedOperationException();
+  }
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
Lines changed: 47 additions & 0 deletions

@@ -63,6 +63,53 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
    */
   def copy(): InternalRow
 
+  def copyUnsafeData(dataTypeJson: String): InternalRow = {
+    val dataType = DataType.fromJson(dataTypeJson)
+    assert(dataType.isInstanceOf[StructType])
+    copyUnsafeData(dataType.asInstanceOf[StructType])
+  }
+
+  def copyUnsafeData(dataType: StructType): InternalRow = {
+    def updateRetIfNecessary(
+        ret: InternalRow,
+        field: AnyRef,
+        newField: AnyRef,
+        idx: Int): InternalRow = {
+      var newRet: InternalRow = ret
+      if (field.ne(newField)) {
+        if (newRet == null) newRet = this.copy()
+        if (newField != null) {
+          newRet.update(idx, newField)
+        } else {
+          newRet.setNullAt(idx)
+        }
+      }
+      newRet
+    }
+
+    var ret: InternalRow = null
+    dataType.map(_.dataType).zipWithIndex.foreach {
+      case (ty: StructType, idx) =>
+        val field = getStruct(idx, ty.size)
+        val newField = field.copyUnsafeData(ty)
+        ret = updateRetIfNecessary(ret, field, newField, idx)
+
+      case (ty: ArrayType, idx) =>
+        val field = getArray(idx)
+        val newField = field.copyUnsafeData(ty)
+        ret = updateRetIfNecessary(ret, field, newField, idx)
+
+      case (ty: MapType, idx) =>
+        val field = getMap(idx)
+        val newField = field.copyUnsafeData(ty)
+        ret = updateRetIfNecessary(ret, field, newField, idx)
+
+      case _ =>
+    }
+
+    if (ret != null) ret else this
+  }
+
   /** Returns true if there are any NULL values in this row. */
   def anyNull: Boolean = {
     val len = numFields
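Editorial note on the contract: copyUnsafeData is copy-on-write. It walks only struct, array, and map fields, uses reference inequality (field.ne(newField)) to detect that a child had to be detached from its buffer, and clones the enclosing row at most once; a row with no nested unsafe data comes back as `this`. A small sketch against this patch's API:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

  val inner = StructType(Seq(StructField("i", IntegerType)))
  val outer = StructType(Seq(StructField("s", inner)))

  // No unsafe data anywhere: the receiver itself is returned.
  val safeRow = InternalRow(InternalRow(1))
  assert(safeRow.copyUnsafeData(outer) eq safeRow)

  // A safe row holding an UnsafeRow field: a detached copy is returned.
  val unsafeField = UnsafeProjection.create(inner)(InternalRow(1))
  val mixedRow = InternalRow(unsafeField)
  assert(mixedRow.copyUnsafeData(outer) ne mixedRow)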

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
Lines changed: 11 additions & 6 deletions

@@ -24,6 +24,8 @@ import scala.collection.mutable.Builder
 import scala.reflect.ClassTag
 import scala.util.Try
 
+import org.apache.commons.text.StringEscapeUtils
+
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
@@ -885,13 +887,16 @@ case class MapObjects private(
       )
     }
 
-    // Make a copy of the data if it's unsafe-backed
-    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
-      s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
+    // Make a copy of the unsafe data if the result contains any
+    def makeCopyUnsafeData(dataType: DataType, value: String) = {
+      val typeToJson = StringEscapeUtils.escapeJava(StringEscapeUtils.escapeJson(dataType.json))
+      s"""${value}.copyUnsafeData("${typeToJson}")"""
+    }
+
     val genFunctionValue: String = lambdaFunction.dataType match {
-      case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
-      case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
-      case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+      case ty: StructType => makeCopyUnsafeData(ty, genFunction.value)
+      case ty: ArrayType => makeCopyUnsafeData(ty, genFunction.value)
+      case ty: MapType => makeCopyUnsafeData(ty, genFunction.value)
       case _ => genFunction.value
     }
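Editorial note: the old template emitted a runtime instanceof check that caught only a top-level unsafe result, while the new one dispatches unconditionally to copyUnsafeData, threading the result type through as escaped JSON so the generated Java can rebuild it. A sketch of the fragment being built ("value_1" is a placeholder for a generated variable name):

  import org.apache.commons.text.StringEscapeUtils
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

  val dataType = StructType(Seq(StructField("i", IntegerType)))
  val typeToJson = StringEscapeUtils.escapeJava(StringEscapeUtils.escapeJson(dataType.json))

  // Old generated fragment: value_1 instanceof UnsafeRow? value_1.copy() : value_1
  // New generated fragment:
  println(s"""value_1.copyUnsafeData("$typeToJson")""")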

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
Lines changed: 18 additions & 0 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
 
 import java.util.{Map => JavaMap}
 
+import org.apache.spark.sql.types.{ArrayType, MapType}
+
 /**
  * A simple `MapData` implementation which is backed by 2 arrays.
  *
@@ -35,6 +37,22 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
   override def toString: String = {
     s"keys: $keyArray, values: $valueArray"
   }
+
+  override def copyUnsafeData(dataType: MapType): MapData = {
+    val keyType = dataType.keyType
+    val valueType = dataType.valueType
+
+    val keyArr = keyArray
+    val valueArr = valueArray
+    val newKeyArray = keyArr.copyUnsafeData(ArrayType(keyType))
+    val newValueArray = valueArr.copyUnsafeData(ArrayType(valueType))
+
+    if (keyArr.ne(newKeyArray) || valueArr.ne(newValueArray)) {
+      new ArrayBasedMapData(newKeyArray, newValueArray)
+    } else {
+      this
+    }
+  }
 }
 
 object ArrayBasedMapData {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
Lines changed: 63 additions & 0 deletions

@@ -68,6 +68,69 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def copy(): ArrayData
 
+  def copyUnsafeData(dataTypeJson: String): ArrayData = {
+    val dataType = DataType.fromJson(dataTypeJson)
+    assert(dataType.isInstanceOf[ArrayType])
+    copyUnsafeData(dataType.asInstanceOf[ArrayType])
+  }
+
+  def copyUnsafeData(dataType: ArrayType): ArrayData = {
+    def updateRetIfNecessary(
+        ret: ArrayData,
+        field: AnyRef,
+        newField: AnyRef,
+        idx: Int): ArrayData = {
+      var newRet: ArrayData = ret
+      if (field.ne(newField)) {
+        if (newRet == null) newRet = this.copy()
+        if (newField != null) {
+          newRet.update(idx, newField)
+        } else {
+          newRet.setNullAt(idx)
+        }
+      }
+      newRet
+    }
+
+    var ret: ArrayData = null
+    dataType.elementType match {
+      case ty: StructType =>
+        val len = numElements()
+        var i = 0
+        while (i < len) {
+          val field = getStruct(i, ty.size)
+          val newField = field.copyUnsafeData(ty)
+          ret = updateRetIfNecessary(ret, field, newField, i)
+          i += 1
+        }
+        if (ret != null) ret else this
+
+      case ty: ArrayType =>
+        val len = numElements()
+        var i = 0
+        while (i < len) {
+          val field = getArray(i)
+          val newField = field.copyUnsafeData(ty)
+          ret = updateRetIfNecessary(ret, field, newField, i)
+          i += 1
+        }
+        if (ret != null) ret else this
+
+      case ty: MapType =>
+        val len = numElements()
+        var i = 0
+        while (i < len) {
+          val field = getMap(i)
+          val newField = field.copyUnsafeData(ty)
+          ret = updateRetIfNecessary(ret, field, newField, i)
+          i += 1
+        }
+        if (ret != null) ret else this
+
+      case _ => this
+    }
+  }
+
   def array: Array[Any]
 
   def toSeq[T](dataType: DataType): IndexedSeq[T] =
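Editorial note: the array walk mirrors the struct walk, with the same copy-on-write bookkeeping per element, so an array of plain values is returned untouched while an array holding unsafe-backed elements is cloned once and patched. A sketch against this patch's API, with GenericArrayData standing in for the safe container a lambda might produce:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.catalyst.util.GenericArrayData
  import org.apache.spark.sql.types.{ArrayType, IntegerType, StructField, StructType}

  val structType = StructType(Seq(StructField("i", IntegerType)))
  val arrayType = ArrayType(structType)

  // Safe elements only: the receiver itself is returned.
  val safeArr = new GenericArrayData(Array[Any](InternalRow(1)))
  assert(safeArr.copyUnsafeData(arrayType) eq safeArr)

  // An unsafe element inside a safe array: the array is cloned and the
  // element replaced by a buffer-detached copy.
  val row = UnsafeProjection.create(structType)(InternalRow(2))
  val mixedArr = new GenericArrayData(Array[Any](row))
  assert(mixedArr.copyUnsafeData(arrayType) ne mixedArr)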

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala
Lines changed: 10 additions & 1 deletion

@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
 /**
  * This is an internal data representation for map type in Spark SQL. This should not implement
@@ -34,6 +35,14 @@ abstract class MapData extends Serializable {
 
   def copy(): MapData
 
+  def copyUnsafeData(dataTypeJson: String): MapData = {
+    val dataType = DataType.fromJson(dataTypeJson)
+    assert(dataType.isInstanceOf[MapType])
+    copyUnsafeData(dataType.asInstanceOf[MapType])
+  }
+
+  def copyUnsafeData(dataType: MapType): MapData
+
   def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
     val length = numElements()
     val keys = keyArray()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
Lines changed: 34 additions & 0 deletions

@@ -17,9 +17,16 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.ArrayType
+
 
 /**
  * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
@@ -64,6 +71,33 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
     val ds100_5 = Seq(S100_5()).toDS()
     ds100_5.rdd.count
   }
+
+  test("SPARK-29503 nest unsafe struct inside safe array") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+      val exampleDS = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
+
+      // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
+      val result = exampleDS.select(
+        new Column(MapObjects(
+          (item: Expression) => array(struct(new Column(item))).expr,
+          $"items".expr,
+          exampleDS.schema("items").dataType.asInstanceOf[ArrayType].elementType
+        )) as "items"
+      ).collect()
+
+      def getValueInsideDepth(result: Row, index: Int): Int = {
+        // expected output:
+        // WrappedArray([WrappedArray(WrappedArray([1]), WrappedArray([2]), WrappedArray([3]))])
+        result.getSeq[mutable.WrappedArray[_]](0)(index)(0)
+          .asInstanceOf[GenericRowWithSchema].getInt(0)
+      }
+
+      assert(result.size === 1)
+      assert(getValueInsideDepth(result.head, 0) === 1)
+      assert(getValueInsideDepth(result.head, 1) === 2)
+      assert(getValueInsideDepth(result.head, 2) === 3)
+    }
+  }
 }
 
 class S100(
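Editorial note on the test: whole-stage codegen is pinned off so the query exercises the MapObjects code path changed above. Per the SPARK-29503 report, without the copy the three nested structs alias one reused unsafe buffer, so each index would plausibly read back the last value (3) instead of 1, 2, 3.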
