sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

@@ -336,7 +336,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     TypeCoercion.typeCoercionRules
   }
 
-  override def batches: Seq[Batch] = Seq(
+  private def earlyBatches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
       // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule.
       // However, when manipulating deeply nested schema, `UpdateFields` expression tree could be
@@ -357,7 +357,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
     Batch("Keep Legacy Outputs", Once,
-      KeepLegacyOutputs),
+      KeepLegacyOutputs)
+  )
+
+  override def batches: Seq[Batch] = earlyBatches ++ Seq(
     Batch("Resolution", fixedPoint,
       new ResolveCatalogs(catalogManager) ::
       ResolveInsertInto ::
@@ -409,7 +412,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveTimeZone ::
       ResolveRandomSeed ::
       ResolveBinaryArithmetic ::
-      ResolveIdentifierClause ::
+      new ResolveIdentifierClause(earlyBatches) ::
       ResolveUnion ::
       ResolveRowLevelCommandAssignments ::
       MoveParameterizedQueriesDown ::
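The Analyzer change above splits the rule pipeline in two: the pre-resolution batches move into a private `earlyBatches`, which both stays at the head of `batches` and gets handed to `new ResolveIdentifierClause(earlyBatches)` so that rule can re-run them later. A minimal sketch of the pattern, with hypothetical names (not Spark's actual classes):

```scala
// Hypothetical, simplified executor standing in for Spark's RuleExecutor.
abstract class Executor[T] {
  case class Batch(name: String, rules: (T => T)*)
  def batches: Seq[Batch]
  // Run every rule of every batch once, in order.
  def execute(t: T): T =
    batches.foldLeft(t) { (cur, batch) =>
      batch.rules.foldLeft(cur)((c, rule) => rule(c))
    }
}

class MiniAnalyzer extends Executor[String] {
  // Factored out, mirroring earlyBatches above, so another component can
  // re-run just these batches on a freshly built sub-plan.
  def earlyBatches: Seq[Batch] = Seq(Batch("Substitution", _.trim))

  override def batches: Seq[Batch] =
    earlyBatches ++ Seq(Batch("Resolution", _.toUpperCase))
}
```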
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala

@@ -20,19 +20,24 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
 import org.apache.spark.sql.types.StringType
 
 /**
  * Resolves the identifier expressions and builds the original plans/expressions.
  */
-object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
+  extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
 
+  private val executor = new RuleExecutor[LogicalPlan] {
+    override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsPattern(UNRESOLVED_IDENTIFIER)) {
     case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
-      p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
+      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children))
     case other =>
       other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
         case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
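The core of the fix is the last hunk: the plan returned by `p.planBuilder` (for example, the relation named by an IDENTIFIER clause) only comes into existence here, after the early batches have already passed over the tree, so it is now pushed through the nested `executor` that re-runs exactly those batches. A hedged illustration of the query shape this affects (assumes an active `SparkSession` named `spark`; the table name is illustrative):

```scala
// The relation behind IDENTIFIER(...) is built only once the identifier
// expression ('my_' || 'table') has been evaluated during analysis; before
// this change, rules in the early batches never saw that relation.
spark.sql("CREATE TABLE IF NOT EXISTS my_table (col1 INT)")
spark.sql("SELECT * FROM IDENTIFIER('my_' || 'table')").show()
```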
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala

@@ -153,7 +153,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
       override val maxIterationsSetting: String = null) extends Strategy
 
   /** A batch of rules. */
-  protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
+  protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
 
   /** Defines a sequence of rule batches, to be overridden by the implementation. */
   protected def batches: Seq[Batch]
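Widening `Batch` from `protected` to `protected[catalyst]` is what allows `ResolveIdentifierClause` to accept a `Seq[RuleExecutor[LogicalPlan]#Batch]` and cast it onto its nested executor: `Batch` is an inner case class, so every `RuleExecutor` instance carries its own path-dependent `Batch` type, and the `#Batch` projection is the only type that spans them all. A minimal sketch of that typing situation, with hypothetical names:

```scala
object BatchSharingDemo {
  abstract class Exec {
    case class Batch(name: String) // inner class: each instance has its own Batch type
    def batches: Seq[Batch]
  }

  val donor: Exec = new Exec {
    def batches = Seq(Batch("Substitution"))
  }

  // Another executor can only receive donor's batches through the type
  // projection Exec#Batch; the cast back to its own Batch is safe at runtime
  // because both are the same erased class.
  val borrower: Exec = new Exec {
    private val early: Seq[Exec#Batch] = donor.batches
    def batches: Seq[Batch] = early.asInstanceOf[Seq[Batch]]
  }
}
```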
sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

@@ -26,6 +26,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
@@ -76,6 +77,14 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     }
   }
 
+  private def withTable(spark: SparkSession, tableNames: String*)(f: => Unit): Unit = {
+    try f finally {
+      tableNames.foreach { name =>
+        spark.sql(s"DROP TABLE IF EXISTS $name")
+      }
+    }
+  }
+
   test("inject analyzer rule") {
     withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
@@ -571,6 +580,28 @@
       assert(res.collect {case s: Sort => s}.isEmpty)
     }
   }
+
+  test("early batch rule is applied on resolved IDENTIFIER") {
+    var ruleApplied = false
+
+    case class UnresolvedRelationRule(spark: SparkSession) extends Rule[LogicalPlan] {
+      override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+        case r: UnresolvedRelation =>
+          ruleApplied = true
+          r
+      }
+    }
+
+    withSession(Seq(_.injectHintResolutionRule(UnresolvedRelationRule))) { session =>
+      withTable(session, "my_table") {
+        session.sql("CREATE TABLE IF NOT EXISTS my_table (col1 INT)")
+        ruleApplied = false
+
+        session.sql("SELECT * FROM IDENTIFIER('my_' || 'table')").collect()
+        assert(ruleApplied)
+      }
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
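For completeness, the hook the new test relies on can also be used outside the suite. A sketch, assuming a Spark build that exposes `injectHintResolutionRule` (the same extension point the test uses) and substituting a hypothetical no-op rule for the test's `UnresolvedRelationRule`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object ExtensionWiringSketch {
  // Hypothetical stand-in for the test's rule: returns the plan unchanged.
  case class NoopRule(spark: SparkSession) extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions(_.injectHintResolutionRule(NoopRule))
      .getOrCreate()

    // With the fix above, the injected rule also fires on the relation that
    // IDENTIFIER(...) resolves to, not just on relations named directly.
    spark.sql("CREATE TABLE IF NOT EXISTS my_table (col1 INT)")
    spark.sql("SELECT * FROM IDENTIFIER('my_' || 'table')").show()
    spark.stop()
  }
}
```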