Skip to content

[SPARK-37019][SQL] Add codegen support to array transform #34294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,17 @@ class EquivalentExpressions {

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: its children will not be used to generate code (call eval() instead)
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
// 2. LambdaFunction: its children operate in the context of local lambdas and can't be split
// 3. If: common subexpressions will always be evaluated at the beginning, but the true and
// false expressions in `If` may not get accessed, according to the predicate
// expression. We should only recurse into the predicate expression.
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// 4. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
// condition. We should only recurse into the first condition expression as it
// will always get accessed.
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
// 5. Coalesce: it's also a conditional expression, we should only recurse into the first
// child, because the others may not get accessed.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case _: CodegenFallback | _: LambdaFunction => Nil
case i: If => i.predicate :: Nil
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
Expand All @@ -122,7 +123,7 @@ class EquivalentExpressions {
// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case _: CodegenFallback => Nil
case _: CodegenFallback | _: LambdaFunction => Nil
case i: If => Seq(Seq(i.trueValue, i.falseValue))
case c: CaseWhen =>
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.mutable

import org.apache.spark.sql.catalyst.CatalystTypeConverters.isPrimitive
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -76,8 +78,7 @@ case class NamedLambdaVariable(
exprId: ExprId = NamedExpression.newExprId,
value: AtomicReference[Any] = new AtomicReference())
extends LeafExpression
with NamedExpression
with CodegenFallback {
with NamedExpression {

override def qualifier: Seq[String] = Seq.empty

Expand All @@ -98,6 +99,31 @@ case class NamedLambdaVariable(
override def simpleString(maxFields: Int): String = {
s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}"
}

/**
 * Name under which this lambda variable's value holder is registered with the codegen
 * context: the variable name followed by `_` and the numeric expression ID.
 *
 * We need to include the Expr ID in the Codegen variable name since several tests bypass
 * `UnresolvedNamedLambdaVariable.freshVarName`.
 */
lazy val variableName: String = s"${name}_${exprId.id}"

/**
 * Generates Java code that reads the current value of this lambda variable.
 *
 * The `value` holder is registered with the codegen context under `variableName`, and the
 * generated code calls `get()` on that reference and casts the result to the boxed Java type
 * of `dataType`. When the variable is nullable the generated code also declares an
 * `isNull` flag that is true when the holder contains `null`.
 *
 * @param ctx codegen context used to register the holder and allocate fresh names
 * @param ev  names of the `value` / `isNull` variables the generated code must declare
 * @return `ev` with the generated code attached (and `isNull` fixed to false if non-nullable)
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val holder = ctx.addReferenceObj(variableName, value)
  val fetched = ctx.freshName("tmpAtomic")
  val boxed = CodeGenerator.boxedType(dataType)
  val javaType = CodeGenerator.javaType(dataType)

  if (!nullable) {
    // Non-nullable: read and unbox in a single statement, no null flag needed.
    val readCode = code"""
      $javaType ${ev.value} = ($boxed)$holder.get();
    """
    ev.copy(code = readCode, isNull = FalseLiteral)
  } else {
    // Nullable: fetch once into a temporary so the null check and the cast see the same object.
    val default = CodeGenerator.defaultValue(dataType)
    val readCode = code"""
      Object $fetched = $holder.get();
      $javaType ${ev.value} = $default;
      boolean ${ev.isNull} = $fetched == null;
      if (!${ev.isNull}) {
        ${ev.value} = ($boxed)$fetched;
      }
    """
    ev.copy(code = readCode)
  }
}
}

/**
Expand All @@ -109,7 +135,7 @@ case class LambdaFunction(
function: Expression,
arguments: Seq[NamedExpression],
hidden: Boolean = false)
extends Expression with CodegenFallback {
extends Expression {

override def children: Seq[Expression] = function +: arguments
override def dataType: DataType = function.dataType
Expand All @@ -127,6 +153,23 @@ case class LambdaFunction(

override def eval(input: InternalRow): Any = function.eval(input)

/**
 * Generates code for the lambda by delegating to the wrapped `function` expression.
 *
 * The function's generated code is emitted first; its result is then copied into the
 * variables named by `ev`. The null flag is only declared when this expression is nullable.
 *
 * @param ctx codegen context passed through to the function body
 * @param ev  names of the `value` / `isNull` variables the generated code must declare
 * @return `ev` with the generated code attached (and `isNull` fixed to false if non-nullable)
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val body = function.genCode(ctx)
  val javaType = CodeGenerator.javaType(dataType)

  if (!nullable) {
    val nonNullCode = code"""
      |${body.code}
      |$javaType ${ev.value} = ${body.value};
    """.stripMargin
    ev.copy(code = nonNullCode, isNull = FalseLiteral)
  } else {
    val nullableCode = code"""
      |${body.code}
      |boolean ${ev.isNull} = ${body.isNull};
      |$javaType ${ev.value} = ${body.value};
    """.stripMargin
    ev.copy(code = nullableCode)
  }
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): LambdaFunction =
copy(
Expand Down Expand Up @@ -224,6 +267,21 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
val canonicalizedChildren = cleaned.children.map(_.canonicalized)
Canonicalize.execute(withNewChildren(canonicalizedChildren))
}

/**
 * Builds the Java statement(s) that store a value into a lambda variable's holder.
 *
 * @param atomicRef Java expression naming the holder; the generated code calls `set` on it
 * @param value     Java expression producing the value to store
 * @param isNull    Java boolean expression that is true when `null` should be stored instead;
 *                  only consulted when `nullable` is true
 * @param nullable  whether a null check should be generated around the store
 * @return the Java source snippet performing the assignment
 */
protected def assignAtomic(atomicRef: String, value: String, isNull: String = FalseLiteral,
    nullable: Boolean = false): String = {
  if (!nullable) {
    // The value can never be null, so store it unconditionally.
    s"$atomicRef.set($value);"
  } else {
    s"""
      if ($isNull) {
        $atomicRef.set(null);
      } else {
        $atomicRef.set($value);
      }
    """
  }
}
}

/**
Expand Down Expand Up @@ -269,10 +327,49 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr
}
}

/**
 * Generates the common wrapper code for a higher-order function: evaluate `argument`,
 * declare the result variables named by `ev`, and run the function-specific code produced
 * by `f`.
 *
 * When this expression is nullable, the result's null flag is initialised from the
 * argument's null flag and the code from `f` is wrapped with `ctx.nullSafeExec` on the
 * argument's null flag. Otherwise the code from `f` is emitted as-is and `isNull` is fixed
 * to false.
 *
 * The templates are written with `|` margins, so `.stripMargin` is applied to keep the
 * margin characters out of the emitted Java regardless of how the enclosing code block is
 * rendered.
 *
 * @param ctx codegen context
 * @param ev  names of the `value` / `isNull` variables the generated code must declare
 * @param f   given the Java name of the evaluated argument, returns code that assigns the
 *            result to `ev.value`
 * @return `ev` with the generated code attached
 */
protected def nullSafeCodeGen(
    ctx: CodegenContext,
    ev: ExprCode,
    f: String => String): ExprCode = {
  val argumentGen = argument.genCode(ctx)
  val resultCode = f(argumentGen.value)
  val javaType = CodeGenerator.javaType(dataType)
  val defaultValue = CodeGenerator.defaultValue(dataType)

  if (nullable) {
    val nullSafeEval = ctx.nullSafeExec(argument.nullable, argumentGen.isNull)(resultCode)
    ev.copy(code = code"""
      |${argumentGen.code}
      |boolean ${ev.isNull} = ${argumentGen.isNull};
      |$javaType ${ev.value} = $defaultValue;
      |$nullSafeEval
    """.stripMargin)
  } else {
    ev.copy(code = code"""
      |${argumentGen.code}
      |$javaType ${ev.value} = $defaultValue;
      |$resultCode
    """.stripMargin, isNull = FalseLiteral)
  }
}
}

trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
override def argumentType: AbstractDataType = ArrayType

/**
 * Builds the Java code that loads one array element into the element lambda variable.
 *
 * The variable's holder is registered with the codegen context under the variable's
 * `variableName`; the element is read with `CodeGenerator.getValue` and stored through
 * `assignAtomic`, using `isNullAt` on the array as the null test when the variable is
 * nullable.
 *
 * @param ctx        codegen context used to register the holder
 * @param arrayName  Java name of the input array
 * @param elementVar lambda variable that receives the element
 * @param index      Java expression for the element position
 * @return the Java source snippet performing the assignment
 */
protected def assignElement(ctx: CodegenContext, arrayName: String,
    elementVar: NamedLambdaVariable, index: String): String = {
  val holder = ctx.addReferenceObj(elementVar.variableName, elementVar.value)
  val elementValue = CodeGenerator.getValue(arrayName, elementVar.dataType, index)
  val elementIsNull = s"$arrayName.isNullAt($index)"

  assignAtomic(holder, elementValue, isNull = elementIsNull, nullable = elementVar.nullable)
}

/**
 * Builds the Java code that stores the current loop position into the index lambda variable.
 *
 * The variable's holder is registered with the codegen context and the index expression is
 * stored without a null check.
 *
 * @param ctx      codegen context used to register the holder
 * @param indexVar lambda variable that receives the position
 * @param index    Java expression for the current position
 * @return the Java source snippet performing the assignment
 */
protected def assignIndex(ctx: CodegenContext, indexVar: NamedLambdaVariable,
    index: String): String =
  assignAtomic(ctx.addReferenceObj(indexVar.variableName, indexVar.value), index)
}

trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
Expand All @@ -297,7 +394,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
case class ArrayTransform(
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction {

override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

Expand Down Expand Up @@ -338,6 +435,43 @@ case class ArrayTransform(
result
}

/**
 * Generates Java code for `transform`: allocates an output array with as many slots as the
 * input array, then for every position assigns the lambda variables, runs the lambda body
 * and stores its result at the same position.
 *
 * @param ctx codegen context
 * @param ev  names of the `value` / `isNull` variables the generated code must declare
 * @return `ev` with the generated code attached (via `nullSafeCodeGen`)
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  // NOTE(review, from the PR author): some of this can probably be abstracted out into the
  // parent traits, but that will be easier to do when implementing a second function.
  nullSafeCodeGen(ctx, ev, arg => {
    val numElements = ctx.freshName("numElements")
    val arrayData = ctx.freshName("arrayData")
    val i = ctx.freshName("i")

    val initialization = CodeGenerator.createArrayData(
      arrayData, dataType.elementType, numElements, s" $prettyName failed.")

    val functionCode = function.genCode(ctx)

    // The element variable is always assigned; the index variable only when the lambda
    // declares one (`indexVar` is defined outside this block -- presumably an Option).
    val elementAssignment = assignElement(ctx, arg, elementVar, i)
    val indexAssignment = indexVar.map(c => assignIndex(ctx, c, i))
    val varAssignments = (Seq(elementAssignment) ++ indexAssignment).mkString("\n")

    // Some expressions return internal buffers that we have to copy. The decision is made
    // with CodeGenerator's own notion of "primitive" so it agrees with the setter that
    // `setArrayElement` emits: a type stored as a Java primitive must not be routed through
    // `InternalRow.copyValue`, which would hand an Object to a primitive setter.
    val copy = if (CodeGenerator.isPrimitiveType(function.dataType)) {
      s"${functionCode.value}"
    } else {
      s"InternalRow.copyValue(${functionCode.value})"
    }
    val resultAssignment = CodeGenerator.setArrayElement(arrayData, dataType.elementType,
      i, copy, isNull = Some(functionCode.isNull))

    s"""
       |final int $numElements = ${arg}.numElements();
       |$initialization
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $varAssignments
       |  ${functionCode.code}
       |  $resultAssignment
       |}
       |${ev.value} = $arrayData;
     """.stripMargin
  })
}

override def prettyName: String = "transform"

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ object QueryExecutionErrors {
new IllegalArgumentException(s"$funcName is not matched at addNewFunction")
}

/**
 * Creates the error raised when a lambda variable is defined a second time.
 *
 * @param name name of the lambda variable
 * @return an `IllegalArgumentException` describing the redefinition
 */
def lambdaVariableAlreadyDefinedError(name: String): Throwable = {
  val message = s"Lambda variable $name cannot be redefined"
  new IllegalArgumentException(message)
}

/**
 * Creates the error raised when a lambda variable is referenced but has no definition in
 * the current codegen scope.
 *
 * @param name name of the lambda variable
 * @return an `IllegalArgumentException` describing the missing definition
 */
def lambdaVariableNotDefinedError(name: String): Throwable = {
  val message = s"Lambda variable $name is not defined in the current codegen scope"
  new IllegalArgumentException(message)
}

def cannotGenerateCodeForUncomparableTypeError(
codeType: String, dataType: DataType): Throwable = {
new IllegalArgumentException(
Expand Down