Skip to content

Commit 0243b32

Browse files
committed
[SPARK-17272][SQL] Move subquery optimizer rules into its own file
## What changes were proposed in this pull request? As part of breaking Optimizer.scala apart, this patch moves various subquery rules into a single file. ## How was this patch tested? This should be covered by existing tests. Author: Reynold Xin <rxin@databricks.com> Closes apache#14844 from rxin/SPARK-17272.
1 parent dcefac4 commit 0243b32

File tree

2 files changed

+356
-323
lines changed

2 files changed

+356
-323
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 0 additions & 323 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,326 +1637,3 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
16371637
a.copy(groupingExpressions = newGrouping)
16381638
}
16391639
}
1640-
1641-
/**
1642-
* This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
1643-
* are supported:
1644-
* a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
1645-
* will be pulled out as the join conditions.
1646-
* b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will
1647-
* be pulled out as join conditions, value = selected column will also be used as join
1648-
* condition.
1649-
*/
1650-
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
1651-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1652-
case Filter(condition, child) =>
1653-
val (withSubquery, withoutSubquery) =
1654-
splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery)
1655-
1656-
// Construct the pruned filter condition.
1657-
val newFilter: LogicalPlan = withoutSubquery match {
1658-
case Nil => child
1659-
case conditions => Filter(conditions.reduce(And), child)
1660-
}
1661-
1662-
// Filter the plan by applying left semi and left anti joins.
1663-
withSubquery.foldLeft(newFilter) {
1664-
case (p, PredicateSubquery(sub, conditions, _, _)) =>
1665-
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1666-
Join(outerPlan, sub, LeftSemi, joinCond)
1667-
case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
1668-
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1669-
Join(outerPlan, sub, LeftAnti, joinCond)
1670-
case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
1671-
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
1672-
// Construct the condition. A NULL in one of the conditions is regarded as a positive
1673-
// result; such a row will be filtered out by the Anti-Join operator.
1674-
1675-
// Note that will almost certainly be planned as a Broadcast Nested Loop join.
1676-
// Use EXISTS if performance matters to you.
1677-
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1678-
val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or)
1679-
Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get)))
1680-
case (p, predicate) =>
1681-
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
1682-
Project(p.output, Filter(newCond.get, inputPlan))
1683-
}
1684-
}
1685-
1686-
/**
1687-
* Given a predicate expression and an input plan, it rewrites
1688-
* any embedded existential sub-query into an existential join.
1689-
* It returns the rewritten expression together with the updated plan.
1690-
* Currently, it does not support null-aware joins. Embedded NOT IN predicates
1691-
* are blocked in the Analyzer.
1692-
*/
1693-
private def rewriteExistentialExpr(
1694-
exprs: Seq[Expression],
1695-
plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
1696-
var newPlan = plan
1697-
val newExprs = exprs.map { e =>
1698-
e transformUp {
1699-
case PredicateSubquery(sub, conditions, nullAware, _) =>
1700-
// TODO: support null-aware join
1701-
val exists = AttributeReference("exists", BooleanType, nullable = false)()
1702-
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
1703-
exists
1704-
}
1705-
}
1706-
(newExprs.reduceOption(And), newPlan)
1707-
}
1708-
}
1709-
1710-
/**
1711-
* This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins.
1712-
*/
1713-
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
1714-
/**
1715-
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
1716-
* the given collector. The expression is rewritten and returned.
1717-
*/
1718-
private def extractCorrelatedScalarSubqueries[E <: Expression](
1719-
expression: E,
1720-
subqueries: ArrayBuffer[ScalarSubquery]): E = {
1721-
val newExpression = expression transform {
1722-
case s: ScalarSubquery if s.children.nonEmpty =>
1723-
subqueries += s
1724-
s.plan.output.head
1725-
}
1726-
newExpression.asInstanceOf[E]
1727-
}
1728-
1729-
/**
1730-
* Statically evaluate an expression containing zero or more placeholders, given a set
1731-
* of bindings for placeholder values.
1732-
*/
1733-
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
1734-
val rewrittenExpr = expr transform {
1735-
case r: AttributeReference =>
1736-
bindings(r.exprId) match {
1737-
case Some(v) => Literal.create(v, r.dataType)
1738-
case None => Literal.default(NullType)
1739-
}
1740-
}
1741-
Option(rewrittenExpr.eval())
1742-
}
1743-
1744-
/**
1745-
* Statically evaluate an expression containing one or more aggregates on an empty input.
1746-
*/
1747-
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
1748-
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
1749-
// in the expression with the value they would return for zero input tuples.
1750-
// Also replace attribute refs (for example, for grouping columns) with NULL.
1751-
val rewrittenExpr = expr transform {
1752-
case a @ AggregateExpression(aggFunc, _, _, resultId) =>
1753-
aggFunc.defaultResult.getOrElse(Literal.default(NullType))
1754-
1755-
case _: AttributeReference => Literal.default(NullType)
1756-
}
1757-
Option(rewrittenExpr.eval())
1758-
}
1759-
1760-
/**
1761-
* Statically evaluate a scalar subquery on an empty input.
1762-
*
1763-
* <b>WARNING:</b> This method only covers subqueries that pass the checks under
1764-
* [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
1765-
* CheckAnalysis become less restrictive, this method will need to change.
1766-
*/
1767-
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
1768-
// Inputs to this method will start with a chain of zero or more SubqueryAlias
1769-
// and Project operators, followed by an optional Filter, followed by an
1770-
// Aggregate. Traverse the operators recursively.
1771-
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
1772-
case SubqueryAlias(_, child, _) => evalPlan(child)
1773-
case Filter(condition, child) =>
1774-
val bindings = evalPlan(child)
1775-
if (bindings.isEmpty) bindings
1776-
else {
1777-
val exprResult = evalExpr(condition, bindings).getOrElse(false)
1778-
.asInstanceOf[Boolean]
1779-
if (exprResult) bindings else Map.empty
1780-
}
1781-
1782-
case Project(projectList, child) =>
1783-
val bindings = evalPlan(child)
1784-
if (bindings.isEmpty) {
1785-
bindings
1786-
} else {
1787-
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
1788-
}
1789-
1790-
case Aggregate(_, aggExprs, _) =>
1791-
// Some of the expressions under the Aggregate node are the join columns
1792-
// for joining with the outer query block. Fill those expressions in with
1793-
// nulls and statically evaluate the remainder.
1794-
aggExprs.map {
1795-
case ref: AttributeReference => (ref.exprId, None)
1796-
case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
1797-
case ne => (ne.exprId, evalAggOnZeroTups(ne))
1798-
}.toMap
1799-
1800-
case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
1801-
}
1802-
1803-
val resultMap = evalPlan(plan)
1804-
1805-
// By convention, the scalar subquery result is the leftmost field.
1806-
resultMap(plan.output.head.exprId)
1807-
}
1808-
1809-
/**
1810-
* Split the plan for a scalar subquery into the parts above the innermost query block
1811-
* (first part of returned value), the HAVING clause of the innermost query block
1812-
* (optional second part) and the parts below the HAVING CLAUSE (third part).
1813-
*/
1814-
private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
1815-
val topPart = ArrayBuffer.empty[LogicalPlan]
1816-
var bottomPart: LogicalPlan = plan
1817-
while (true) {
1818-
bottomPart match {
1819-
case havingPart @ Filter(_, aggPart: Aggregate) =>
1820-
return (topPart, Option(havingPart), aggPart)
1821-
1822-
case aggPart: Aggregate =>
1823-
// No HAVING clause
1824-
return (topPart, None, aggPart)
1825-
1826-
case p @ Project(_, child) =>
1827-
topPart += p
1828-
bottomPart = child
1829-
1830-
case s @ SubqueryAlias(_, child, _) =>
1831-
topPart += s
1832-
bottomPart = child
1833-
1834-
case Filter(_, op) =>
1835-
sys.error(s"Correlated subquery has unexpected operator $op below filter")
1836-
1837-
case op @ _ => sys.error(s"Unexpected operator $op in correlated subquery")
1838-
}
1839-
}
1840-
1841-
sys.error("This line should be unreachable")
1842-
}
1843-
1844-
// Name of generated column used in rewrite below
1845-
val ALWAYS_TRUE_COLNAME = "alwaysTrue"
1846-
1847-
/**
1848-
* Construct a new child plan by left joining the given subqueries to a base plan.
1849-
*/
1850-
private def constructLeftJoins(
1851-
child: LogicalPlan,
1852-
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
1853-
subqueries.foldLeft(child) {
1854-
case (currentChild, ScalarSubquery(query, conditions, _)) =>
1855-
val origOutput = query.output.head
1856-
1857-
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
1858-
if (resultWithZeroTups.isEmpty) {
1859-
// CASE 1: Subquery guaranteed not to have the COUNT bug
1860-
Project(
1861-
currentChild.output :+ origOutput,
1862-
Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
1863-
} else {
1864-
// Subquery might have the COUNT bug. Add appropriate corrections.
1865-
val (topPart, havingNode, aggNode) = splitSubquery(query)
1866-
1867-
// The next two cases add a leading column to the outer join input to make it
1868-
// possible to distinguish between the case when no tuples join and the case
1869-
// when the tuple that joins contains null values.
1870-
// The leading column always has the value TRUE.
1871-
val alwaysTrueExprId = NamedExpression.newExprId
1872-
val alwaysTrueExpr = Alias(Literal.TrueLiteral,
1873-
ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
1874-
val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
1875-
BooleanType)(exprId = alwaysTrueExprId)
1876-
1877-
val aggValRef = query.output.head
1878-
1879-
if (havingNode.isEmpty) {
1880-
// CASE 2: Subquery with no HAVING clause
1881-
Project(
1882-
currentChild.output :+
1883-
Alias(
1884-
If(IsNull(alwaysTrueRef),
1885-
Literal.create(resultWithZeroTups.get, origOutput.dataType),
1886-
aggValRef), origOutput.name)(exprId = origOutput.exprId),
1887-
Join(currentChild,
1888-
Project(query.output :+ alwaysTrueExpr, query),
1889-
LeftOuter, conditions.reduceOption(And)))
1890-
1891-
} else {
1892-
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
1893-
// Need to modify any operators below the join to pass through all columns
1894-
// referenced in the HAVING clause.
1895-
var subqueryRoot: UnaryNode = aggNode
1896-
val havingInputs: Seq[NamedExpression] = aggNode.output
1897-
1898-
topPart.reverse.foreach {
1899-
case Project(projList, _) =>
1900-
subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
1901-
case s @ SubqueryAlias(alias, _, None) =>
1902-
subqueryRoot = SubqueryAlias(alias, subqueryRoot, None)
1903-
case op => sys.error(s"Unexpected operator $op in corelated subquery")
1904-
}
1905-
1906-
// CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
1907-
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
1908-
// ELSE (aggregate value) END AS (original column name)
1909-
val caseExpr = Alias(CaseWhen(Seq(
1910-
(IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
1911-
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
1912-
aggValRef),
1913-
origOutput.name)(exprId = origOutput.exprId)
1914-
1915-
Project(
1916-
currentChild.output :+ caseExpr,
1917-
Join(currentChild,
1918-
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
1919-
LeftOuter, conditions.reduceOption(And)))
1920-
1921-
}
1922-
}
1923-
}
1924-
}
1925-
1926-
/**
1927-
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
1928-
* subqueries.
1929-
*/
1930-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1931-
case a @ Aggregate(grouping, expressions, child) =>
1932-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
1933-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
1934-
if (subqueries.nonEmpty) {
1935-
// We currently only allow correlated subqueries in an aggregate if they are part of the
1936-
// grouping expressions. As a result we need to replace all the scalar subqueries in the
1937-
// grouping expressions by their result.
1938-
val newGrouping = grouping.map { e =>
1939-
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
1940-
}
1941-
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
1942-
} else {
1943-
a
1944-
}
1945-
case p @ Project(expressions, child) =>
1946-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
1947-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
1948-
if (subqueries.nonEmpty) {
1949-
Project(newExpressions, constructLeftJoins(child, subqueries))
1950-
} else {
1951-
p
1952-
}
1953-
case f @ Filter(condition, child) =>
1954-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
1955-
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
1956-
if (subqueries.nonEmpty) {
1957-
Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
1958-
} else {
1959-
f
1960-
}
1961-
}
1962-
}

0 commit comments

Comments
 (0)