
Commit 7143e9d

mgaido91 authored and cloud-fan committed
[SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to check properly the limit size
## What changes were proposed in this pull request?

This PR follows up on a [comment](#23124 (comment)) in the main PR and aims at:
- simplifying the code for `MapConcat`;
- being more precise in checking the limit size.

## How was this patch tested?

Existing tests.

Closes #23217 from mgaido91/SPARK-25829_followup.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 180f969 commit 7143e9d
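
The net effect of the change is easier to see outside the diff: instead of each map-producing expression pre-computing an element count and comparing it to ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, the limit is now enforced once inside ArrayBasedMapBuilder when a new key is inserted. Below is a minimal standalone sketch of that idea, not Spark's actual ArrayBasedMapBuilder; the class name MapBuilderSketch and the sizeLimit parameter are illustrative only.

import scala.collection.mutable

// Minimal sketch only -- not Spark's ArrayBasedMapBuilder. The size limit is checked
// in put() whenever a genuinely new key arrives, so callers no longer need their own
// pre-check over the summed element counts of the input maps.
class MapBuilderSketch(sizeLimit: Int) {
  private val keyToIndex = mutable.HashMap.empty[Any, Int]
  private val keys = mutable.ArrayBuffer.empty[Any]
  private val values = mutable.ArrayBuffer.empty[Any]

  // Current number of distinct entries accumulated so far.
  def size: Int = keys.size

  def put(key: Any, value: Any): Unit = keyToIndex.get(key) match {
    case Some(i) =>
      values(i) = value  // duplicated key: last value wins, size does not grow
    case None =>
      if (size >= sizeLimit) {
        throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
          s"due to exceeding the map size limit $sizeLimit.")
      }
      keyToIndex.put(key, keys.size)
      keys += key
      values += value
  }

  def build(): Map[Any, Any] = keys.zip(values).toMap
}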

File tree

2 files changed: +12 -75 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 2 additions & 75 deletions
@@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       return null
     }
 
-    val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-    if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
-        s"elements due to exceeding the map size limit " +
-        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
-    }
-
     for (map <- maps) {
       mapBuilder.putAll(map.keyArray(), map.valueArray())
     }
@@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val mapCodes = children.map(_.genCode(ctx))
-    val keyType = dataType.keyType
-    val valueType = dataType.valueType
     val argsName = ctx.freshName("args")
     val hasNullName = ctx.freshName("hasNull")
     val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
@@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
     )
 
     val idxName = ctx.freshName("idx")
-    val numElementsName = ctx.freshName("numElems")
-    val finKeysName = ctx.freshName("finalKeys")
-    val finValsName = ctx.freshName("finalValues")
-
-    val keyConcat = genCodeForArrays(ctx, keyType, false)
-
-    val valueConcat =
-      if (valueType.sameType(keyType) &&
-          !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
-        keyConcat
-      } else {
-        genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
-      }
-
-    val keyArgsName = ctx.freshName("keyArgs")
-    val valArgsName = ctx.freshName("valArgs")
-
     val mapMerge =
       s"""
-        |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
-        |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
-        |long $numElementsName = 0;
         |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
-        |  $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
-        |  $valArgsName[$idxName] = $argsName[$idxName].valueArray();
-        |  $numElementsName += $argsName[$idxName].numElements();
+        |  $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
         |}
-        |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-        |  throw new RuntimeException("Unsuccessful attempt to concat maps with " +
-        |    $numElementsName + " elements due to exceeding the map size limit " +
-        |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
-        |}
-        |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
-        |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
-        |${ev.value} = $builderTerm.from($finKeysName, $finValsName);
+        |${ev.value} = $builderTerm.build();
       """.stripMargin
 
     ev.copy(
@@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       """.stripMargin)
   }
 
-  private def genCodeForArrays(
-      ctx: CodegenContext,
-      elementType: DataType,
-      checkForNull: Boolean): String = {
-    val counter = ctx.freshName("counter")
-    val arrayData = ctx.freshName("arrayData")
-    val argsName = ctx.freshName("args")
-    val numElemName = ctx.freshName("numElements")
-    val y = ctx.freshName("y")
-    val z = ctx.freshName("z")
-
-    val allocation = CodeGenerator.createArrayData(
-      arrayData, elementType, numElemName, s" $prettyName failed.")
-    val assignment = CodeGenerator.createArrayAssignment(
-      arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
-
-    val concat = ctx.freshName("concat")
-    val concatDef =
-      s"""
-         |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
-         |  $allocation
-         |  int $counter = 0;
-         |  for (int $y = 0; $y < ${children.length}; $y++) {
-         |    for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
-         |      $assignment
-         |      $counter++;
-         |    }
-         |  }
-         |  return $arrayData;
-         |}
-      """.stripMargin
-
-    ctx.addNewFunction(concat, concatDef)
-  }
-
   override def prettyName: String = "map_concat"
 }
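
For context, nothing changes in what map_concat returns; only the point at which the size limit is enforced moves into the builder. A hedged usage sketch against the public DataFrame API follows (it assumes a SparkSession named spark is already in scope; column values are illustrative):

import org.apache.spark.sql.functions.{lit, map, map_concat}

// Build two map columns and concatenate them; the merged map is assembled by
// ArrayBasedMapBuilder, which now performs the size-limit check internally.
val df = spark.range(1).select(
  map_concat(
    map(lit("a"), lit(1), lit("b"), lit(2)),
    map(lit("c"), lit(3))
  ).as("merged"))

val merged = df.head.getMap[String, Int](0)  // Map(a -> 1, b -> 2, c -> 3)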

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala

Lines changed: 10 additions & 0 deletions
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
 
 /**
  * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
@@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
 
     val index = keyToIndex.getOrDefault(key, -1)
     if (index == -1) {
+      if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
+          s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+      }
       keyToIndex.put(key, values.length)
       keys.append(key)
       values.append(value)
@@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       build()
     }
   }
+
+  /**
+   * Returns the current size of the map which is going to be produced by the current builder.
+   */
+  def size: Int = keys.size
 }
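
The "more precise" part of the PR description comes from where the check now runs: summing numElements over the inputs (the removed pre-check) over-counts keys that appear in more than one input map, while the builder's size only counts entries that actually end up in the result, assuming, per the builder's doc comment, that duplicated keys are collapsed. A tiny plain-Scala illustration of the difference:

// Two input maps sharing the key "b": the removed pre-check would have compared
// 3 (the summed element count) against the limit, while the builder-side check
// compares 2 (the number of distinct keys that the concatenated map really holds).
val m1 = Map("a" -> 1, "b" -> 2)
val m2 = Map("b" -> 3)

val summedElements = m1.size + m2.size               // 3 -- what the old check used
val distinctEntries = (m1.keySet ++ m2.keySet).size  // 2 -- what the builder's size reflects
assert(summedElements == 3 && distinctEntries == 2)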
