[SPARK-37019][SQL] Add codegen support to array higher-order functions #34558

Open. Wants to merge 8 commits into base: master
@@ -146,9 +146,13 @@ class EquivalentExpressions(
  // There are some special expressions that we should not recurse into all of its children.
  //   1. CodegenFallback: its children will not be used to generate code (call eval() instead)
  //   2. ConditionalExpression: use its children that will always be evaluated.
  //   3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't
  //      be called outside of that scope, only the arguments can be evaluated ahead of
  //      time.
  private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
    case _: CodegenFallback => Nil
    case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
    case h: HigherOrderFunction => h.arguments

Contributor:

Do we need to do the same for commonChildrenToRecurse?

Contributor Author:

I don't think so. commonChildrenToRecurse only cares whether the current expression is a ConditionalExpression and defaults to Nil otherwise, so I don't think it needs any special handling for HOFs.

    case other => skipForShortcut(other).children
  }
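
To illustrate the new HigherOrderFunction case, here is a minimal sketch on a toy expression tree. None of these types are Spark's real classes: `Expr`, `Attr`, `LambdaVar`, `Add`, `Lambda`, `Transform`, and `Recurse` are hypothetical names that only imitate the shape of the match above. The point is that recursing through `arguments` alone never reaches the lambda body, which refers to a variable that is only bound per element.

```scala
// Toy model only: these types imitate the shape of the match above and are
// not Spark's real Expression classes.
sealed trait Expr { def children: Seq[Expr] }

case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }

// A variable that only has a value while the lambda runs on one element.
case class LambdaVar(name: String) extends Expr { val children: Seq[Expr] = Nil }

case class Add(left: Expr, right: Expr) extends Expr {
  val children: Seq[Expr] = Seq(left, right)
}

case class Lambda(body: Expr, params: Seq[LambdaVar]) extends Expr {
  val children: Seq[Expr] = body +: params
}

// Stand-in for a higher-order function such as transform(argument, function).
case class Transform(argument: Expr, function: Lambda) extends Expr {
  val children: Seq[Expr] = Seq(argument, function)
  def arguments: Seq[Expr] = Seq(argument)
}

object Recurse {
  // Descend only into the arguments of a higher-order function, never into
  // its lambda; every other expression exposes all of its children.
  def childrenToRecurse(expr: Expr): Seq[Expr] = expr match {
    case t: Transform => t.arguments
    case other => other.children
  }

  // Collect every sub-expression reachable through childrenToRecurse.
  def reachable(expr: Expr): Seq[Expr] =
    expr +: childrenToRecurse(expr).flatMap(reachable)
}
```

In this sketch, calling `Recurse.reachable` on a `Transform` visits the array argument and its sub-expressions but never the `Lambda` or its `LambdaVar`, so nothing that depends on a per-element binding would be picked up as a candidate for evaluation ahead of time.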

@@ -174,6 +174,41 @@ class CodegenContext extends Logging {
   */
  var currentVars: Seq[ExprCode] = null

  /**
   * Holding a map of current lambda variables.
   */
  var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty

  def withLambdaVars(
      namedLambdas: Seq[NamedLambdaVariable],
      f: Seq[ExprCode] => ExprCode): ExprCode = {
    val lambdaVars = namedLambdas.map { lambda =>
      val id = lambda.exprId.id
      if (currentLambdaVars.get(id).nonEmpty) {
        throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id)
      }
      val isNull = if (lambda.nullable) {
        JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull"))
      } else {
        FalseLiteral
      }
      val value = addMutableState(javaType(lambda.dataType), "lambdaValue")
      val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType))
      currentLambdaVars.put(id, lambdaVar)
      lambdaVar
    }

    val result = f(lambdaVars)
    namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove)
    result
  }

  def getLambdaVar(id: Long): ExprCode = {
    currentLambdaVars.getOrElse(
      id,
      throw QueryExecutionErrors.lambdaVariableNotDefinedError(id))
  }

  /**
   * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
   * 2-tuple: java type, variable name.
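
The scoping behaviour of `withLambdaVars` and `getLambdaVar` in the hunk above can be sketched without Spark. This is a toy model under stated assumptions: `VarRef` and `LambdaScope` are hypothetical names, generated variable names are plain strings, and `IllegalStateException` stands in for the `QueryExecutionErrors` helpers. It follows the same pattern as the diff: register each id on entry, run the body, then remove the ids.

```scala
import scala.collection.mutable

// Hypothetical stand-in for ExprCode: just the names of the generated fields.
final case class VarRef(isNull: String, value: String)

// Hypothetical stand-in for the lambda-variable part of the codegen context.
class LambdaScope {
  private val current = mutable.HashMap.empty[Long, VarRef]
  private var counter = 0

  // Produce a unique name, loosely imitating what addMutableState returns.
  private def freshName(prefix: String): String = {
    counter += 1
    s"${prefix}_$counter"
  }

  // Register the ids, run the body with their references, then unregister.
  def withLambdaVars[T](ids: Seq[Long])(f: Seq[VarRef] => T): T = {
    val refs = ids.map { id =>
      if (current.contains(id)) {
        throw new IllegalStateException(s"Lambda variable $id is already defined")
      }
      val ref = VarRef(freshName("lambdaIsNull"), freshName("lambdaValue"))
      current.put(id, ref)
      ref
    }
    val result = f(refs)
    ids.foreach(id => current.remove(id))
    result
  }

  // Look up a variable; only valid while its withLambdaVars body is running.
  def getLambdaVar(id: Long): VarRef =
    current.getOrElse(
      id,
      throw new IllegalStateException(s"Lambda variable $id is not defined"))
}
```

Inside the body passed to `withLambdaVars`, `getLambdaVar` returns the same reference that was handed to the body; once the call returns, the same lookup throws. That mirrors the restriction described in the `childrenToRecurse` comment: a lambda's variables only exist within the scope of the lambda.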