Skip to content

Commit

Permalink
[SPARK-49245][SQL] Refactor some analyzer rules
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Refactor some rules in `Analyzer.scala` for better transparency.

### Why are the changes needed?
For some analyzer rules it is hard to track its bodies we moved them to seperated functions. This PR should introduce more transparency and more purity to the code.

### Does this PR introduce _any_ user-facing change?
No.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47773 from mihailoale-db/mihailoale-db/refactor_analyzer0.

Authored-by: Mihailo Aleksic <mihailo.aleksic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
mihailoale-db authored and MaxGekk committed Aug 16, 2024
1 parent e3a4430 commit 5166cb3
Showing 1 changed file with 115 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -413,75 +413,80 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveExpressionsUpWithPruning(_.containsPattern(BINARY_ARITHMETIC), ruleId) {
case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r)
case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l))
case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l)
case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r)
case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
TimestampAddYMInterval(l, r)
case (_: YearMonthIntervalType, TimestampType | TimestampNTZType) =>
TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => a
case (_: NullType, _: AnsiIntervalType) =>
a.copy(left = Cast(a.left, a.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
a.copy(right = Cast(a.right, a.left.dataType))
case (DateType, CalendarIntervalType) =>
DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) =>
DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) =>
DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
case (DateType, _: DayTimeIntervalType) =>
DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
case (DateType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => s
case (_: NullType, _: AnsiIntervalType) =>
s.copy(left = Cast(s.left, s.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
s.copy(right = Cast(s.right, s.left.dataType))
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l,
UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI))
case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _ if AnyTimestampTypeExpression.unapply(l) ||
AnyTimestampTypeExpression.unapply(r) => SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
case _ => m
}
case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
case _ => d
}
case expr @ (_: Add | _: Subtract | _: Multiply | _: Divide)
if expr.childrenResolved => resolve(expr)
}

def resolve(expr: Expression): Expression = expr match {
case a @ Add(l, r, mode) => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r)
case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l))
case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l)
case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r)
case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
TimestampAddYMInterval(l, r)
case (_: YearMonthIntervalType, TimestampType | TimestampNTZType) =>
TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => a
case (_: NullType, _: AnsiIntervalType) =>
a.copy(left = Cast(a.left, a.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
a.copy(right = Cast(a.right, a.left.dataType))
case (DateType, CalendarIntervalType) =>
DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) =>
DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r, mode) => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) =>
DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
case (DateType, _: DayTimeIntervalType) =>
DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
case (DateType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => s
case (_: NullType, _: AnsiIntervalType) =>
s.copy(left = Cast(s.left, s.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
s.copy(right = Cast(s.right, s.left.dataType))
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l,
UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI))
case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _ if AnyTimestampTypeExpression.unapply(l) ||
AnyTimestampTypeExpression.unapply(r) => SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r, mode) => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
case _ => m
}
case d @ Divide(l, r, mode) => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
case _ => d
}
}
}

/**
Expand All @@ -505,34 +510,39 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
def extractOnly(e: Expression): Boolean = e match {
case _: ExtractValue => e.children.forall(extractOnly)
case _: Literal => true
case _: Attribute => true
case _ => false
}
exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) {
case u @ UnresolvedAlias(child, optGenAliasFunc) =>
child match {
case ne: NamedExpression => ne
case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)()
case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))()
case e if optGenAliasFunc.isDefined =>
Alias(child, optGenAliasFunc.get.apply(e))()
case l: Literal => Alias(l, toPrettySQL(l))()
case e =>
val metaForAutoGeneratedAlias = new MetadataBuilder()
.putString(AUTO_GENERATED_ALIAS, "true")
.build()
Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias))
}
case u: UnresolvedAlias => resolve(u)
}
).asInstanceOf[Seq[NamedExpression]]
}

private[analysis] def resolve(u: UnresolvedAlias): Expression = {
val UnresolvedAlias(child, optGenAliasFunc) = u
child match {
case ne: NamedExpression => ne
case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)()
case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))()
case e if optGenAliasFunc.isDefined =>
Alias(child, optGenAliasFunc.get.apply(e))()
case l: Literal => Alias(l, toPrettySQL(l))()
case e =>
val metaForAutoGeneratedAlias = new MetadataBuilder()
.putString(AUTO_GENERATED_ALIAS, "true")
.build()
Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias))
}
}

private def extractOnly(e: Expression): Boolean = e match {
case _: ExtractValue => e.children.forall(extractOnly)
case _: Literal => true
case _: Attribute => true
case _ => false
}

private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias]))

Expand Down Expand Up @@ -2299,16 +2309,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _) => withPosition(u) {
resolveBuiltinOrTempFunction(nameParts, arguments, u).getOrElse {
val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts)
if (CatalogV2Util.isSessionCatalog(catalog)) {
resolveV1Function(ident.asFunctionIdentifier, arguments, u)
} else {
resolveV2Function(catalog.asFunctionCatalog, ident, arguments, u)
}
}
}
case u: UnresolvedFunction => resolveFunction(u)

case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
// Check if this is a call to a Python user-defined table function whose polymorphic
Expand All @@ -2332,6 +2333,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

private[analysis] def resolveFunction(u: UnresolvedFunction): Expression = {
withPosition(u) {
resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.nameParts)
if (CatalogV2Util.isSessionCatalog(catalog)) {
resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
} else {
resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
}
}
}
}

/**
* Check if the arguments of a function are either resolved or a lambda function.
*/
Expand Down

0 comments on commit 5166cb3

Please sign in to comment.