diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d637a1923e45..beed5a6e3651a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4806,6 +4806,17 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: } } + private lazy val positivePos = if (second.foldable) { + val pos = second.eval().asInstanceOf[Int] + if (pos > 0) { + Some(pos) + } else { + None + } + } else { + None + } + override def eval(input: InternalRow): Any = { val value1 = first.eval(input) if (value1 != null) { @@ -4819,21 +4830,9 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: } override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = { - var posInt = pos.asInstanceOf[Int] - if (posInt == 0) { - throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull()) - } val baseArr = arr.asInstanceOf[ArrayData] - val arrayElementType = dataType.asInstanceOf[ArrayType].elementType - - val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements()) - - if (newPosExtendsArrayLeft) { - // special case- if the new position is negative but larger than the current array size - // place the new item at start of array, place the current array contents at the end - // and fill the newly created array elements inbetween with a null - - val newArrayLength = -posInt + 1 + if (positivePos.isDefined) { + val newArrayLength = math.max(baseArr.numElements() + 1, positivePos.get) if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) @@ -4841,48 +4840,81 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: val newArray = new Array[Any](newArrayLength) - baseArr.foreach(arrayElementType, (i, v) => { - // current position, offset by new item + new null array elements - val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements()) - newArray(elementPosition) = v + val posInt = positivePos.get - 1 + baseArr.foreach(elementType, (i, v) => { + if (i >= posInt) { + newArray(i + 1) = v + } else { + newArray(i) = v + } }) - newArray(0) = item + newArray(posInt) = item - return new GenericArrayData(newArray) + new GenericArrayData(newArray) } else { - if (posInt < 0) { - posInt = posInt + baseArr.numElements() - } else if (posInt > 0) { - posInt = posInt - 1 + var posInt = pos.asInstanceOf[Int] + if (posInt == 0) { + throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull()) } - val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1) + val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements()) - if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) - } + if (newPosExtendsArrayLeft) { + // special case- if the new position is negative but larger than the current array size + // place the new item at start of array, place the current array contents at the end + // and fill the newly created array elements inbetween with a null - val newArray = new Array[Any](newArrayLength) + val newArrayLength = -posInt + 1 - baseArr.foreach(arrayElementType, (i, v) => { - if (i >= posInt) { - newArray(i + 1) = v - } else { - newArray(i) = v + if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) } - }) - newArray(posInt) = item + val newArray = new Array[Any](newArrayLength) + + baseArr.foreach(elementType, (i, v) => { + // current position, offset by new item + new null array elements + val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements()) + newArray(elementPosition) = v + }) + + newArray(0) = item + + new GenericArrayData(newArray) + } else { + if (posInt < 0) { + posInt = posInt + baseArr.numElements() + } else if (posInt > 0) { + posInt = posInt - 1 + } + + val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1) - return new GenericArrayData(newArray) + if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) + } + + val newArray = new Array[Any](newArrayLength) + + baseArr.foreach(elementType, (i, v) => { + if (i >= posInt) { + newArray(i + 1) = v + } else { + newArray(i) = v + } + }) + + newArray(posInt) = item + + new GenericArrayData(newArray) + } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val f = (arrExpr: ExprCode, posExpr: ExprCode, itemExpr: ExprCode) => { val arr = arrExpr.value - val pos = posExpr.value val item = itemExpr.value val itemInsertionIndex = ctx.freshName("itemInsertionIndex") @@ -4898,69 +4930,99 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr, adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull) val errorContext = getContextOrNullCode(ctx) - - s""" - |int $itemInsertionIndex = 0; - |int $resLength = 0; - |int $adjustedAllocIdx = 0; - |boolean $insertedItemIsNull = ${itemExpr.isNull}; - | - |if ($pos == 0) { - | throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext); - |} - | - |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) { - | - | $resLength = java.lang.Math.abs($pos) + 1; - | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); - | } - | - | $allocation - | for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements()); - | $assignment - | } - | ${CodeGenerator.setArrayElement( - values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))} - | - | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) { - | $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements())); - | } - | - | ${ev.value} = $values; - |} else { - | - | $itemInsertionIndex = 0; - | if ($pos < 0) { - | $itemInsertionIndex = $pos + $arr.numElements(); - | } else if ($pos > 0) { - | $itemInsertionIndex = $pos - 1; - | } - | - | $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1); - | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); - | } - | - | $allocation - | for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $adjustedAllocIdx = $i; - | if ($i >= $itemInsertionIndex) { - | $adjustedAllocIdx = $adjustedAllocIdx + 1; - | } - | $assignment - | } - | ${CodeGenerator.setArrayElement( + if (positivePos.isDefined) { + s""" + |int $itemInsertionIndex = ${positivePos.get - 1}; + |int $adjustedAllocIdx = 0; + |boolean $insertedItemIsNull = ${itemExpr.isNull}; + | + |final int $resLength = java.lang.Math.max($arr.numElements() + 1, ${positivePos.get}); + |if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + |} + | + |$allocation + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $adjustedAllocIdx = $i; + | if ($i >= $itemInsertionIndex) { + | $adjustedAllocIdx = $adjustedAllocIdx + 1; + | } + | $assignment + |} + |${CodeGenerator.setArrayElement( values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))} - | - | for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) { - | $values.setNullAt($j); - | } - | - | ${ev.value} = $values; - |} - """.stripMargin + | + |for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) { + | $values.setNullAt($j); + |} + | + |${ev.value} = $values; + |""".stripMargin + } else { + val pos = posExpr.value + s""" + |int $itemInsertionIndex = 0; + |int $resLength = 0; + |int $adjustedAllocIdx = 0; + |boolean $insertedItemIsNull = ${itemExpr.isNull}; + | + |if ($pos == 0) { + | throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext); + |} + | + |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) { + | + | $resLength = java.lang.Math.abs($pos) + 1; + | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + | } + | + | $allocation + | for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements()); + | $assignment + | } + | ${CodeGenerator.setArrayElement( + values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))} + | + | for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) { + | $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements())); + | } + | + | ${ev.value} = $values; + |} else { + | + | $itemInsertionIndex = 0; + | if ($pos < 0) { + | $itemInsertionIndex = $pos + $arr.numElements(); + | } else if ($pos > 0) { + | $itemInsertionIndex = $pos - 1; + | } + | + | $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1); + | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + | } + | + | $allocation + | for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $adjustedAllocIdx = $i; + | if ($i >= $itemInsertionIndex) { + | $adjustedAllocIdx = $adjustedAllocIdx + 1; + | } + | $assignment + | } + | ${CodeGenerator.setArrayElement( + values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))} + | + | for (int $j = $arr.numElements(); $j < $resLength - 1; $j ++) { + | $values.setNullAt($j); + | } + | + | ${ev.value} = $values; + |} + |""".stripMargin + } } val leftGen = first.genCode(ctx)