@@ -1637,326 +1637,3 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
1637
1637
a.copy(groupingExpressions = newGrouping)
1638
1638
}
1639
1639
}
1640
-
1641
- /**
1642
- * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates
1643
- * are supported:
1644
- * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter
1645
- * will be pulled out as the join conditions.
1646
- * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will
1647
- * be pulled out as join conditions, value = selected column will also be used as join
1648
- * condition.
1649
- */
1650
- object RewritePredicateSubquery extends Rule [LogicalPlan ] with PredicateHelper {
1651
- def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
1652
- case Filter (condition, child) =>
1653
- val (withSubquery, withoutSubquery) =
1654
- splitConjunctivePredicates(condition).partition(PredicateSubquery .hasPredicateSubquery)
1655
-
1656
- // Construct the pruned filter condition.
1657
- val newFilter : LogicalPlan = withoutSubquery match {
1658
- case Nil => child
1659
- case conditions => Filter (conditions.reduce(And ), child)
1660
- }
1661
-
1662
- // Filter the plan by applying left semi and left anti joins.
1663
- withSubquery.foldLeft(newFilter) {
1664
- case (p, PredicateSubquery (sub, conditions, _, _)) =>
1665
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1666
- Join (outerPlan, sub, LeftSemi , joinCond)
1667
- case (p, Not (PredicateSubquery (sub, conditions, false , _))) =>
1668
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1669
- Join (outerPlan, sub, LeftAnti , joinCond)
1670
- case (p, Not (PredicateSubquery (sub, conditions, true , _))) =>
1671
- // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
1672
- // Construct the condition. A NULL in one of the conditions is regarded as a positive
1673
- // result; such a row will be filtered out by the Anti-Join operator.
1674
-
1675
- // Note that will almost certainly be planned as a Broadcast Nested Loop join.
1676
- // Use EXISTS if performance matters to you.
1677
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
1678
- val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull ).reduceLeft(Or )
1679
- Join (outerPlan, sub, LeftAnti , Option (Or (anyNull, joinCond.get)))
1680
- case (p, predicate) =>
1681
- val (newCond, inputPlan) = rewriteExistentialExpr(Seq (predicate), p)
1682
- Project (p.output, Filter (newCond.get, inputPlan))
1683
- }
1684
- }
1685
-
1686
- /**
1687
- * Given a predicate expression and an input plan, it rewrites
1688
- * any embedded existential sub-query into an existential join.
1689
- * It returns the rewritten expression together with the updated plan.
1690
- * Currently, it does not support null-aware joins. Embedded NOT IN predicates
1691
- * are blocked in the Analyzer.
1692
- */
1693
- private def rewriteExistentialExpr (
1694
- exprs : Seq [Expression ],
1695
- plan : LogicalPlan ): (Option [Expression ], LogicalPlan ) = {
1696
- var newPlan = plan
1697
- val newExprs = exprs.map { e =>
1698
- e transformUp {
1699
- case PredicateSubquery (sub, conditions, nullAware, _) =>
1700
- // TODO: support null-aware join
1701
- val exists = AttributeReference (" exists" , BooleanType , nullable = false )()
1702
- newPlan = Join (newPlan, sub, ExistenceJoin (exists), conditions.reduceLeftOption(And ))
1703
- exists
1704
- }
1705
- }
1706
- (newExprs.reduceOption(And ), newPlan)
1707
- }
1708
- }
1709
-
1710
- /**
1711
- * This rule rewrites correlated [[ScalarSubquery ]] expressions into LEFT OUTER joins.
1712
- */
1713
- object RewriteCorrelatedScalarSubquery extends Rule [LogicalPlan ] {
1714
- /**
1715
- * Extract all correlated scalar subqueries from an expression. The subqueries are collected using
1716
- * the given collector. The expression is rewritten and returned.
1717
- */
1718
- private def extractCorrelatedScalarSubqueries [E <: Expression ](
1719
- expression : E ,
1720
- subqueries : ArrayBuffer [ScalarSubquery ]): E = {
1721
- val newExpression = expression transform {
1722
- case s : ScalarSubquery if s.children.nonEmpty =>
1723
- subqueries += s
1724
- s.plan.output.head
1725
- }
1726
- newExpression.asInstanceOf [E ]
1727
- }
1728
-
1729
- /**
1730
- * Statically evaluate an expression containing zero or more placeholders, given a set
1731
- * of bindings for placeholder values.
1732
- */
1733
- private def evalExpr (expr : Expression , bindings : Map [ExprId , Option [Any ]]) : Option [Any ] = {
1734
- val rewrittenExpr = expr transform {
1735
- case r : AttributeReference =>
1736
- bindings(r.exprId) match {
1737
- case Some (v) => Literal .create(v, r.dataType)
1738
- case None => Literal .default(NullType )
1739
- }
1740
- }
1741
- Option (rewrittenExpr.eval())
1742
- }
1743
-
1744
- /**
1745
- * Statically evaluate an expression containing one or more aggregates on an empty input.
1746
- */
1747
- private def evalAggOnZeroTups (expr : Expression ) : Option [Any ] = {
1748
- // AggregateExpressions are Unevaluable, so we need to replace all aggregates
1749
- // in the expression with the value they would return for zero input tuples.
1750
- // Also replace attribute refs (for example, for grouping columns) with NULL.
1751
- val rewrittenExpr = expr transform {
1752
- case a @ AggregateExpression (aggFunc, _, _, resultId) =>
1753
- aggFunc.defaultResult.getOrElse(Literal .default(NullType ))
1754
-
1755
- case _ : AttributeReference => Literal .default(NullType )
1756
- }
1757
- Option (rewrittenExpr.eval())
1758
- }
1759
-
1760
- /**
1761
- * Statically evaluate a scalar subquery on an empty input.
1762
- *
1763
- * <b>WARNING:</b> This method only covers subqueries that pass the checks under
1764
- * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis ]]. If the checks in
1765
- * CheckAnalysis become less restrictive, this method will need to change.
1766
- */
1767
- private def evalSubqueryOnZeroTups (plan : LogicalPlan ) : Option [Any ] = {
1768
- // Inputs to this method will start with a chain of zero or more SubqueryAlias
1769
- // and Project operators, followed by an optional Filter, followed by an
1770
- // Aggregate. Traverse the operators recursively.
1771
- def evalPlan (lp : LogicalPlan ) : Map [ExprId , Option [Any ]] = lp match {
1772
- case SubqueryAlias (_, child, _) => evalPlan(child)
1773
- case Filter (condition, child) =>
1774
- val bindings = evalPlan(child)
1775
- if (bindings.isEmpty) bindings
1776
- else {
1777
- val exprResult = evalExpr(condition, bindings).getOrElse(false )
1778
- .asInstanceOf [Boolean ]
1779
- if (exprResult) bindings else Map .empty
1780
- }
1781
-
1782
- case Project (projectList, child) =>
1783
- val bindings = evalPlan(child)
1784
- if (bindings.isEmpty) {
1785
- bindings
1786
- } else {
1787
- projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
1788
- }
1789
-
1790
- case Aggregate (_, aggExprs, _) =>
1791
- // Some of the expressions under the Aggregate node are the join columns
1792
- // for joining with the outer query block. Fill those expressions in with
1793
- // nulls and statically evaluate the remainder.
1794
- aggExprs.map {
1795
- case ref : AttributeReference => (ref.exprId, None )
1796
- case alias @ Alias (_ : AttributeReference , _) => (alias.exprId, None )
1797
- case ne => (ne.exprId, evalAggOnZeroTups(ne))
1798
- }.toMap
1799
-
1800
- case _ => sys.error(s " Unexpected operator in scalar subquery: $lp" )
1801
- }
1802
-
1803
- val resultMap = evalPlan(plan)
1804
-
1805
- // By convention, the scalar subquery result is the leftmost field.
1806
- resultMap(plan.output.head.exprId)
1807
- }
1808
-
1809
- /**
1810
- * Split the plan for a scalar subquery into the parts above the innermost query block
1811
- * (first part of returned value), the HAVING clause of the innermost query block
1812
- * (optional second part) and the parts below the HAVING CLAUSE (third part).
1813
- */
1814
- private def splitSubquery (plan : LogicalPlan ) : (Seq [LogicalPlan ], Option [Filter ], Aggregate ) = {
1815
- val topPart = ArrayBuffer .empty[LogicalPlan ]
1816
- var bottomPart : LogicalPlan = plan
1817
- while (true ) {
1818
- bottomPart match {
1819
- case havingPart @ Filter (_, aggPart : Aggregate ) =>
1820
- return (topPart, Option (havingPart), aggPart)
1821
-
1822
- case aggPart : Aggregate =>
1823
- // No HAVING clause
1824
- return (topPart, None , aggPart)
1825
-
1826
- case p @ Project (_, child) =>
1827
- topPart += p
1828
- bottomPart = child
1829
-
1830
- case s @ SubqueryAlias (_, child, _) =>
1831
- topPart += s
1832
- bottomPart = child
1833
-
1834
- case Filter (_, op) =>
1835
- sys.error(s " Correlated subquery has unexpected operator $op below filter " )
1836
-
1837
- case op @ _ => sys.error(s " Unexpected operator $op in correlated subquery " )
1838
- }
1839
- }
1840
-
1841
- sys.error(" This line should be unreachable" )
1842
- }
1843
-
1844
- // Name of generated column used in rewrite below
1845
- val ALWAYS_TRUE_COLNAME = " alwaysTrue"
1846
-
1847
- /**
1848
- * Construct a new child plan by left joining the given subqueries to a base plan.
1849
- */
1850
- private def constructLeftJoins (
1851
- child : LogicalPlan ,
1852
- subqueries : ArrayBuffer [ScalarSubquery ]): LogicalPlan = {
1853
- subqueries.foldLeft(child) {
1854
- case (currentChild, ScalarSubquery (query, conditions, _)) =>
1855
- val origOutput = query.output.head
1856
-
1857
- val resultWithZeroTups = evalSubqueryOnZeroTups(query)
1858
- if (resultWithZeroTups.isEmpty) {
1859
- // CASE 1: Subquery guaranteed not to have the COUNT bug
1860
- Project (
1861
- currentChild.output :+ origOutput,
1862
- Join (currentChild, query, LeftOuter , conditions.reduceOption(And )))
1863
- } else {
1864
- // Subquery might have the COUNT bug. Add appropriate corrections.
1865
- val (topPart, havingNode, aggNode) = splitSubquery(query)
1866
-
1867
- // The next two cases add a leading column to the outer join input to make it
1868
- // possible to distinguish between the case when no tuples join and the case
1869
- // when the tuple that joins contains null values.
1870
- // The leading column always has the value TRUE.
1871
- val alwaysTrueExprId = NamedExpression .newExprId
1872
- val alwaysTrueExpr = Alias (Literal .TrueLiteral ,
1873
- ALWAYS_TRUE_COLNAME )(exprId = alwaysTrueExprId)
1874
- val alwaysTrueRef = AttributeReference (ALWAYS_TRUE_COLNAME ,
1875
- BooleanType )(exprId = alwaysTrueExprId)
1876
-
1877
- val aggValRef = query.output.head
1878
-
1879
- if (havingNode.isEmpty) {
1880
- // CASE 2: Subquery with no HAVING clause
1881
- Project (
1882
- currentChild.output :+
1883
- Alias (
1884
- If (IsNull (alwaysTrueRef),
1885
- Literal .create(resultWithZeroTups.get, origOutput.dataType),
1886
- aggValRef), origOutput.name)(exprId = origOutput.exprId),
1887
- Join (currentChild,
1888
- Project (query.output :+ alwaysTrueExpr, query),
1889
- LeftOuter , conditions.reduceOption(And )))
1890
-
1891
- } else {
1892
- // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
1893
- // Need to modify any operators below the join to pass through all columns
1894
- // referenced in the HAVING clause.
1895
- var subqueryRoot : UnaryNode = aggNode
1896
- val havingInputs : Seq [NamedExpression ] = aggNode.output
1897
-
1898
- topPart.reverse.foreach {
1899
- case Project (projList, _) =>
1900
- subqueryRoot = Project (projList ++ havingInputs, subqueryRoot)
1901
- case s @ SubqueryAlias (alias, _, None ) =>
1902
- subqueryRoot = SubqueryAlias (alias, subqueryRoot, None )
1903
- case op => sys.error(s " Unexpected operator $op in corelated subquery " )
1904
- }
1905
-
1906
- // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
1907
- // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
1908
- // ELSE (aggregate value) END AS (original column name)
1909
- val caseExpr = Alias (CaseWhen (Seq (
1910
- (IsNull (alwaysTrueRef), Literal .create(resultWithZeroTups.get, origOutput.dataType)),
1911
- (Not (havingNode.get.condition), Literal .create(null , aggValRef.dataType))),
1912
- aggValRef),
1913
- origOutput.name)(exprId = origOutput.exprId)
1914
-
1915
- Project (
1916
- currentChild.output :+ caseExpr,
1917
- Join (currentChild,
1918
- Project (subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
1919
- LeftOuter , conditions.reduceOption(And )))
1920
-
1921
- }
1922
- }
1923
- }
1924
- }
1925
-
1926
- /**
1927
- * Rewrite [[Filter ]], [[Project ]] and [[Aggregate ]] plans containing correlated scalar
1928
- * subqueries.
1929
- */
1930
- def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
1931
- case a @ Aggregate (grouping, expressions, child) =>
1932
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
1933
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
1934
- if (subqueries.nonEmpty) {
1935
- // We currently only allow correlated subqueries in an aggregate if they are part of the
1936
- // grouping expressions. As a result we need to replace all the scalar subqueries in the
1937
- // grouping expressions by their result.
1938
- val newGrouping = grouping.map { e =>
1939
- subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
1940
- }
1941
- Aggregate (newGrouping, newExpressions, constructLeftJoins(child, subqueries))
1942
- } else {
1943
- a
1944
- }
1945
- case p @ Project (expressions, child) =>
1946
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
1947
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
1948
- if (subqueries.nonEmpty) {
1949
- Project (newExpressions, constructLeftJoins(child, subqueries))
1950
- } else {
1951
- p
1952
- }
1953
- case f @ Filter (condition, child) =>
1954
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
1955
- val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
1956
- if (subqueries.nonEmpty) {
1957
- Project (f.output, Filter (newCondition, constructLeftJoins(child, subqueries)))
1958
- } else {
1959
- f
1960
- }
1961
- }
1962
- }
0 commit comments