Skip to content

Commit

Permalink
[SPARK-43137][SQL] Improve ArrayInsert if the position is foldable an…
Browse files Browse the repository at this point in the history
…d positive

### What changes were proposed in this pull request?
Currently, Spark supports the `array_insert` and `array_prepend`. Users insert an element into the head of array is common operation. Considered, we want make array_prepend reuse the implementation of array_insert, but it seems a bit performance worse if the position is foldable and positive.
The reason is that always do the check for position is negative or positive, and the code is too long. Too long code will lead to JIT failed.

### Why are the changes needed?
Improve ArrayInsert if the position is foldable and positive.

### Does this PR introduce _any_ user-facing change?
'No'.
Just change the inner implementation.

### How was this patch tested?
Exists test cases.

Closes #40833 from beliefer/SPARK-43137_new.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
beliefer authored and cloud-fan committed Apr 19, 2023
1 parent 74ce620 commit 8db31aa
Showing 1 changed file with 162 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4806,6 +4806,17 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
}
}

private lazy val positivePos = if (second.foldable) {
val pos = second.eval().asInstanceOf[Int]
if (pos > 0) {
Some(pos)
} else {
None
}
} else {
None
}

override def eval(input: InternalRow): Any = {
val value1 = first.eval(input)
if (value1 != null) {
Expand All @@ -4819,70 +4830,91 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
}

override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
var posInt = pos.asInstanceOf[Int]
if (posInt == 0) {
throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
}
val baseArr = arr.asInstanceOf[ArrayData]
val arrayElementType = dataType.asInstanceOf[ArrayType].elementType

val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())

if (newPosExtendsArrayLeft) {
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null

val newArrayLength = -posInt + 1
if (positivePos.isDefined) {
val newArrayLength = math.max(baseArr.numElements() + 1, positivePos.get)

if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}

val newArray = new Array[Any](newArrayLength)

baseArr.foreach(arrayElementType, (i, v) => {
// current position, offset by new item + new null array elements
val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
newArray(elementPosition) = v
val posInt = positivePos.get - 1
baseArr.foreach(elementType, (i, v) => {
if (i >= posInt) {
newArray(i + 1) = v
} else {
newArray(i) = v
}
})

newArray(0) = item
newArray(posInt) = item

return new GenericArrayData(newArray)
new GenericArrayData(newArray)
} else {
if (posInt < 0) {
posInt = posInt + baseArr.numElements()
} else if (posInt > 0) {
posInt = posInt - 1
var posInt = pos.asInstanceOf[Int]
if (posInt == 0) {
throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
}

val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)
val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())

if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}
if (newPosExtendsArrayLeft) {
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null

val newArray = new Array[Any](newArrayLength)
val newArrayLength = -posInt + 1

baseArr.foreach(arrayElementType, (i, v) => {
if (i >= posInt) {
newArray(i + 1) = v
} else {
newArray(i) = v
if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}
})

newArray(posInt) = item
val newArray = new Array[Any](newArrayLength)

baseArr.foreach(elementType, (i, v) => {
// current position, offset by new item + new null array elements
val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
newArray(elementPosition) = v
})

newArray(0) = item

new GenericArrayData(newArray)
} else {
if (posInt < 0) {
posInt = posInt + baseArr.numElements()
} else if (posInt > 0) {
posInt = posInt - 1
}

val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1)

return new GenericArrayData(newArray)
if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
}

val newArray = new Array[Any](newArrayLength)

baseArr.foreach(elementType, (i, v) => {
if (i >= posInt) {
newArray(i + 1) = v
} else {
newArray(i) = v
}
})

newArray(posInt) = item

new GenericArrayData(newArray)
}
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => {
val arr = arrExpr.value
val pos = posExpr.value
val item = itemExpr.value

val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
Expand All @@ -4898,69 +4930,99 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr,
adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull)
val errorContext = getContextOrNullCode(ctx)

s"""
|int $itemInsertionIndex = 0;
|int $resLength = 0;
|int $adjustedAllocIdx = 0;
|boolean $insertedItemIsNull = ${itemExpr.isNull};
|
|if ($pos == 0) {
| throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
|}
|
|if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
|
| $resLength = java.lang.Math.abs($pos) + 1;
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
| $assignment
| }
| ${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
| }
|
| ${ev.value} = $values;
|} else {
|
| $itemInsertionIndex = 0;
| if ($pos < 0) {
| $itemInsertionIndex = $pos + $arr.numElements();
| } else if ($pos > 0) {
| $itemInsertionIndex = $pos - 1;
| }
|
| $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1);
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i;
| if ($i >= $itemInsertionIndex) {
| $adjustedAllocIdx = $adjustedAllocIdx + 1;
| }
| $assignment
| }
| ${CodeGenerator.setArrayElement(
if (positivePos.isDefined) {
s"""
|int $itemInsertionIndex = ${positivePos.get - 1};
|int $adjustedAllocIdx = 0;
|boolean $insertedItemIsNull = ${itemExpr.isNull};
|
|final int $resLength = java.lang.Math.max($arr.numElements() + 1, ${positivePos.get});
|if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
|}
|
|$allocation
|for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i;
| if ($i >= $itemInsertionIndex) {
| $adjustedAllocIdx = $adjustedAllocIdx + 1;
| }
| $assignment
|}
|${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
| $values.setNullAt($j);
| }
|
| ${ev.value} = $values;
|}
""".stripMargin
|
|for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
| $values.setNullAt($j);
|}
|
|${ev.value} = $values;
|""".stripMargin
} else {
val pos = posExpr.value
s"""
|int $itemInsertionIndex = 0;
|int $resLength = 0;
|int $adjustedAllocIdx = 0;
|boolean $insertedItemIsNull = ${itemExpr.isNull};
|
|if ($pos == 0) {
| throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
|}
|
|if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
|
| $resLength = java.lang.Math.abs($pos) + 1;
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
| $assignment
| }
| ${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
| }
|
| ${ev.value} = $values;
|} else {
|
| $itemInsertionIndex = 0;
| if ($pos < 0) {
| $itemInsertionIndex = $pos + $arr.numElements();
| } else if ($pos > 0) {
| $itemInsertionIndex = $pos - 1;
| }
|
| $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1);
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i;
| if ($i >= $itemInsertionIndex) {
| $adjustedAllocIdx = $adjustedAllocIdx + 1;
| }
| $assignment
| }
| ${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) {
| $values.setNullAt($j);
| }
|
| ${ev.value} = $values;
|}
|""".stripMargin
}
}

val leftGen = first.genCode(ctx)
Expand Down

0 comments on commit 8db31aa

Please sign in to comment.