From 9d3909babb46784a47f337b3005531b166e69e88 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 18 Oct 2024 06:00:03 +0800 Subject: [PATCH] Support `RelationSubquery` PPL (#775) * Support RelationSubquery PPL Signed-off-by: Lantao Jin * fix doc Signed-off-by: Lantao Jin * revert the FROM alias Signed-off-by: Lantao Jin * add the case for subquery in search filter Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- docs/ppl-lang/PPL-Example-Commands.md | 30 ++- docs/ppl-lang/ppl-search-command.md | 2 +- docs/ppl-lang/ppl-subquery-command.md | 112 +++++++-- .../FlintSparkPPLExistsSubqueryITSuite.scala | 73 ++++++ .../ppl/FlintSparkPPLInSubqueryITSuite.scala | 74 ++++++ .../spark/ppl/FlintSparkPPLJoinITSuite.scala | 190 ++++++++++++++- .../FlintSparkPPLScalarSubqueryITSuite.scala | 82 ++++++- .../src/main/antlr4/OpenSearchPPLParser.g4 | 28 ++- .../org/opensearch/sql/ast/tree/Relation.java | 13 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 24 +- ...PLLogicalPlanJoinTranslatorTestSuite.scala | 228 +++++++++++++++++- 11 files changed, 798 insertions(+), 58 deletions(-) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 8e6cbaae9..96eeef726 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -240,8 +240,7 @@ source = table | where ispresent(a) | - `source = table1 | cross join left = l right = r table2` - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - -_- **Limitation: sub-searches is unsupported in join right side now**_ +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` #### **Lookup** @@ -268,6 +267,8 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ - `source = outer | where a not in [ source = inner | fields b ]` - `source = outer | where (a) not in [ source = inner | fields b ]` - `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer a in [ source = inner | fields b ]` (search filtering with subquery) +- `source = outer a not in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) - `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) @@ -317,6 +318,9 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where not exists [ source = inner | where a = c ]` - `source = outer | where exists [ source = inner | where a = c and b = d ]` - `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = outer not exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = table as t1 exists [ source = table as t2 | where t1.a = t2.a ]` (table alias is useful in exists subquery) - `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) @@ -332,8 +336,13 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` - `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` -**Uncorrelated scalar subquery in Select and Where** -- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` +**Uncorrelated scalar subquery in Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | fields a` +- `source = outer | where [ source = inner | stats min(c) ] > 0 | fields a` + +**Uncorrelated scalar subquery in Search filter** +- `source = outer a > [ source = inner | stats min(c) ] | fields a` +- `source = outer [ source = inner | stats min(c) ] > 0 | fields a` **Correlated scalar subquery in Select** - `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` @@ -345,10 +354,23 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` - `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` +**Correlated scalar subquery in Search filter** +- `source = outer a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + **Nested scalar subquery** - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` +#### **(Relation) Subquery** +[See additional command details](ppl-subquery-command.md) + +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or Search clause. + +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` + +_- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ --- #### Experimental Commands: diff --git a/docs/ppl-lang/ppl-search-command.md b/docs/ppl-lang/ppl-search-command.md index f81d9d907..bccfd04f0 100644 --- a/docs/ppl-lang/ppl-search-command.md +++ b/docs/ppl-lang/ppl-search-command.md @@ -32,7 +32,7 @@ The example show fetch all the document from accounts index with . PPL query: - os> source=accounts account_number=1 or gender="F"; + os> SEARCH source=accounts account_number=1 or gender="F"; +------------------+-------------+--------------------+-----------+----------+--------+------------+---------+-------+----------------------+------------+ | account_number | firstname | address | balance | gender | city | employer | state | age | email | lastname | |------------------+-------------+--------------------+-----------+----------+--------+------------+---------+-------+----------------------+------------| diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index ac0f98fe8..c4a0c337c 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -1,6 +1,6 @@ ## PPL SubQuery Commands: -**Syntax** +### Syntax The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. ```sql @@ -21,13 +21,15 @@ For additional info See [Issue](https://github.com/opensearch-project/opensearch --- -**InSubquery usage** +### InSubquery usage - `source = outer | where a in [ source = inner | fields b ]` - `source = outer | where (a) in [ source = inner | fields b ]` - `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` - `source = outer | where a not in [ source = inner | fields b ]` - `source = outer | where (a) not in [ source = inner | fields b ]` - `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer a in [ source = inner | fields b ]` (search filtering with subquery) +- `source = outer a not in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) - `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) @@ -111,8 +113,9 @@ source = supplier nation | sort s_name ``` +--- -**ExistsSubquery usage** +### ExistsSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 @@ -120,6 +123,9 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where not exists [ source = inner | where a = c ]` - `source = outer | where exists [ source = inner | where a = c and b = d ]` - `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = outer not exists [ source = inner | where a = c ]` (search filtering with subquery) +- `source = table as t1 exists [ source = table as t2 | where t1.a = t2.a ]` (table alias is useful in exists subquery) - `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) - `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) @@ -163,8 +169,9 @@ source = orders | sort o_orderpriority | fields o_orderpriority, order_count ``` +--- -**ScalarSubquery usage** +### ScalarSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested @@ -172,8 +179,11 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` - `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` -**Uncorrelated scalar subquery in Select and Where** -- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` +**Uncorrelated scalar subquery in Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | fields a` + +**Uncorrelated scalar subquery in Search filter** +- `source = outer a > [ source = inner | stats min(c) ] | fields a` **Correlated scalar subquery in Select** - `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` @@ -185,6 +195,10 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` - `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` +**Correlated scalar subquery in Search filter** +- `source = outer a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + **Nested scalar subquery** - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` @@ -240,27 +254,77 @@ source = spark_catalog.default.outer source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d ] ``` +--- + +### (Relation) Subquery +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or From clause. + +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` -### **Additional Context** +**_SQL Migration examples with Subquery PPL:_** + +tpch q13 +```sql +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) as c_count + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders +group by + c_count +order by + custdist desc, + c_count desc +``` +Rewritten by PPL (Relation) Subquery: +```sql +SEARCH source = [ + SEARCH source = customer + | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey + [ + SEARCH source = orders + | WHERE not like(o_comment, '%special%requests%') + ] + | STATS COUNT(o_orderkey) AS c_count BY c_custkey +] AS c_orders +| STATS COUNT(o_orderkey) AS c_count BY c_custkey +| STATS COUNT(1) AS custdist BY c_count +| SORT - custdist, - c_count +``` +--- -`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expression. The common usage of subquery expression is in `where` clause: +### Additional Context -The `where` command syntax is: +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` as subquery expressions, their common usage is in `where` clause and `search filter`. +Where command: +``` +| where | ... ``` -| where +Search filter: ``` -So the subquery is part of boolean expression, such as +search source=* | ... +``` +A subquery expression could be used in boolean expression, for example ```sql -| where orders.order_id in (subquery source=returns | where return_reason="damaged" | return order_id) +| where orders.order_id in [ source=returns | where return_reason="damaged" | field order_id ] ``` -The `orders.order_id in (subquery source=...)` is a ``. - -In general, we name this kind of subquery clause the `InSubquery` expression, it is a ``, one kind of `subquery expressions`. +The `orders.order_id in [ source=... ]` is a ``. -PS: there are many kinds of `subquery expressions`, another commonly used one is `ScalarSubquery` expression: +In general, we name this kind of subquery clause the `InSubquery` expression, it is a ``. **Subquery with Different Join Types** @@ -326,4 +390,18 @@ source = outer | eval l = "nonEmpty" | fields l ``` -This query just print "nonEmpty" if the inner table is not empty. \ No newline at end of file +This query just print "nonEmpty" if the inner table is not empty. + +**Table alias in subquery** + +Table alias is useful in query which contains a subquery, for example + +```sql +select a, ( + select sum(b) + from catalog.schema.table1 as t1 + where t1.a = t2.a + ) sum_b + from catalog.schema.table2 as t2 +``` +`t1` and `t2` are table aliases which are used in correlated subquery, `sum_b` are subquery alias. diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala index 81bdd99df..8009015b1 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala @@ -84,6 +84,44 @@ class FlintSparkPPLExistsSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test simple exists subquery in search filter") { + val frame = sql(s""" + | source = $outerTable exists [ source = $innerTable | where id = uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test not exists subquery") { val frame = sql(s""" | source = $outerTable @@ -122,6 +160,41 @@ class FlintSparkPPLExistsSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test not exists subquery in search filter") { + val frame = sql(s""" + | source = $outerTable not exists [ source = $innerTable | where id = uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Not( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test empty exists subquery") { var frame = sql(s""" | source = $outerTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala index 9d8c2c12d..107390dff 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -87,6 +87,45 @@ class FlintSparkPPLInSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test filter id in (select uid from inner)") { + val frame = sql(s""" + source = $outerTable id in [ source = $innerTable | fields uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test where (id) in (select uid from inner)") { // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) // InSubquery: (0, 2, 3, 5, 6) @@ -214,6 +253,41 @@ class FlintSparkPPLInSubqueryITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + test("test filter id not in (select uid from inner)") { + val frame = sql(s""" + source = $outerTable id not in [ source = $innerTable | fields uid ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + test("test where (id, name) not in (select uid, name from inner)") { // Not InSubquery: (1, 4, 6) val frame = sql(s""" diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index b276149a0..00e55d50a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -7,9 +7,9 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, LessThan, Literal, Multiply, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, LogicalPlan, Project, Sort, SubqueryAlias} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLJoinITSuite @@ -738,4 +738,190 @@ class FlintSparkPPLJoinITSuite case j @ Join(_, _, Inner, _, JoinHint.NONE) => j }.size == 1) } + + test("test inner join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test left outer join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | left join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70), Row(null, null, 40)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, LeftOuter, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'Canada' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name AND a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + | [ + | source = $testTable2 + | ] + | | eval a_name = a.name + | | eval a_country = a.country + | | eval b_country = b.country + | | fields a_name, age, state, a_country, occupation, b_country, salary + | | left join left=a, right=b + | ON a.a_name = b.name + | [ + | source = $testTable3 + | ] + | | eval aa_country = a.a_country + | | eval ab_country = a.b_country + | | eval bb_country = b.country + | | fields a_name, age, state, aa_country, occupation, ab_country, salary, bb_country, hobby, language + | | cross join left=a, right=b + | [ + | source = $testTable2 + | ] + | | eval new_country = a.aa_country + | | eval new_salary = b.salary + | | stats avg(new_salary) as avg_salary by span(age, 5) as age_span, state + | | left semi join left=a, right=b + | ON a.state = b.state + | [ + | source = $testTable1 + | ] + | | eval new_avg_salary = floor(avg_salary) + | | fields state, age_span, new_avg_salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Quebec", 20, 83333), Row("Ontario", 25, 83333)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Cross, None, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, LeftOuter, _, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _, JoinHint.NONE) => j + }.size == 1) + assert(frame.queryExecution.analyzed.collect { case s: SubqueryAlias => + s + }.size == 13) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala index 654add8d8..24b4d77e6 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala @@ -132,12 +132,12 @@ class FlintSparkPPLScalarSubqueryITSuite test("test uncorrelated scalar subquery in select and where") { val frame = sql(s""" | source = $outerTable - | | eval count_dept = [ - | source = $innerTable | stats count(department) - | ] | | where id > [ | source = $innerTable | stats count(department) | ] + 999 + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] | | fields name, count_dept | """.stripMargin) val results: Array[Row] = frame.collect() @@ -160,13 +160,50 @@ class FlintSparkPPLScalarSubqueryITSuite val countScalarSubqueryExpr = ScalarSubquery(countAggPlan) val plusScalarSubquery = UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), outer) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, filter) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in select and from with filter") { + val frame = sql(s""" + | source = $outerTable id > [ source = $innerTable | stats count(department) ] + 999 + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] + | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Jane", 5), Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val countAgg = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val countAggPlan = Aggregate(Seq(), countAgg, inner) + val countScalarSubqueryExpr = ScalarSubquery(countAggPlan) + val plusScalarSubquery = + UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), outer) val evalProjectList = Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) - val evalProject = Project(evalProjectList, outer) - val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), evalProject) + val evalProject = Project(evalProjectList, filter) val expectedPlan = - Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), filter) + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), evalProject) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -302,6 +339,39 @@ class FlintSparkPPLScalarSubqueryITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test("test correlated scalar subquery in from with filter") { + val frame = sql(s""" + | source = $outerTable id = [ source = $innerTable | where id = uid | stats max(uid) ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false), + "max(uid)")()) + val innerFilter = + Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(EqualTo(UnresolvedAttribute("id"), scalarSubqueryExpr), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } test("test disjunctive correlated scalar subquery") { val frame = sql(s""" diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 7a6f14839..c205fc236 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -247,17 +247,27 @@ mlArg // clauses fromClause - : SOURCE EQUAL tableSourceClause - | INDEX EQUAL tableSourceClause + : SOURCE EQUAL tableOrSubqueryClause + | INDEX EQUAL tableOrSubqueryClause ; +tableOrSubqueryClause + : LT_SQR_PRTHS subSearch RT_SQR_PRTHS (AS alias = qualifiedName)? + | tableSourceClause + ; + +// One tableSourceClause will generate one Relation node with/without one alias +// even if the relation contains more than one table sources. +// These table sources in one relation will be readed one by one in OpenSearch. +// But it may have different behaivours in different execution backends. +// For example, a Spark UnresovledRelation node only accepts one data source. tableSourceClause - : tableSource (COMMA tableSource)* + : tableSource (COMMA tableSource)* (AS alias = qualifiedName)? ; // join joinCommand - : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableSource + : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableOrSubqueryClause ; joinType @@ -279,13 +289,13 @@ joinCriteria ; joinHintList - : hintPair (COMMA? hintPair)* - ; + : hintPair (COMMA? hintPair)* + ; hintPair - : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint - | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint - ; + : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint + | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint + ; renameClasue : orignalField = wcFieldExpression AS renamedField = wcFieldExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index e1732f75f..1b30a7998 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableList; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; +import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; @@ -38,7 +40,7 @@ public Relation(UnresolvedExpression tableName, String alias) { } /** Optional alias name for the relation. */ - private String alias; + @Setter @Getter private String alias; /** * Return table name. @@ -53,15 +55,6 @@ public List getQualifiedNames() { return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } - /** - * Return alias. - * - * @return alias. - */ - public String getAlias() { - return alias; - } - /** * Get Qualified name preservs parts of the user given identifiers. This can later be utilized to * determine DataSource,Schema and Table Name during Analyzer stage. So Passing QualifiedName diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8673b1582..1c0fe919f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -156,8 +156,12 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); String leftAlias = ctx.sideAlias().leftAlias.getText(); String rightAlias = ctx.sideAlias().rightAlias.getText(); - // TODO when sub-search is supported, this part need to change. Now relation is the only supported plan for right side - UnresolvedPlan right = new SubqueryAlias(rightAlias, new Relation(this.internalVisitExpression(ctx.tableSource()), rightAlias)); + if (ctx.tableOrSubqueryClause().alias != null) { + // left and right aliases are required in join syntax. Setting by 'AS' causes ambiguous + throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); + UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -451,16 +455,22 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct return aggregation; } - /** From clause. */ @Override - public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { - return visitTableSourceClause(ctx.tableSourceClause()); + public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) { + if (ctx.subSearch() != null) { + return ctx.alias != null + ? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch())) + : visitSubSearch(ctx.subSearch()); + } else { + return visitTableSourceClause(ctx.tableSourceClause()); + } } @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return new Relation( - ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias == null + ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) + : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 58c1a8d12..3ceff7735 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -11,9 +11,9 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, EqualTo, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias} class PPLLogicalPlanJoinTranslatorTestSuite extends SparkFunSuite @@ -341,4 +341,228 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test inner join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test left outer join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| LEFT JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | head 10 + | | inner JOIN left = l,right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 + | ] + | | left JOIN left = l,right = r ON l.name = r.name + | [ + | source = $testTable3 + | | fields id + | ] + | | cross JOIN left = l,right = r + | [ + | source = $testTable4 + | | sort id + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val table4 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test4")) + var leftPlan = SubqueryAlias("l", GlobalLimit(Literal(10), LocalLimit(Literal(10), table1))) + var rightPlan = + SubqueryAlias("r", Filter(GreaterThan(UnresolvedAttribute("id"), Literal(10)), table2)) + val joinCondition1 = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan1 = Join(leftPlan, rightPlan, Inner, Some(joinCondition1), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan1) + rightPlan = SubqueryAlias("r", Project(Seq(UnresolvedAttribute("id")), table3)) + val joinCondition2 = EqualTo(UnresolvedAttribute("l.name"), UnresolvedAttribute("r.name")) + val joinPlan2 = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition2), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan2) + rightPlan = SubqueryAlias( + "r", + Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, table4)) + val joinPlan3 = Join(leftPlan, rightPlan, Cross, None, JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test complex join: TPC-H Q13 with relation subquery") { + // select + // c_count, + // count(*) as custdist + // from + // ( + // select + // c_custkey, + // count(o_orderkey) as c_count + // from + // customer left outer join orders on + // c_custkey = o_custkey + // and o_comment not like '%special%requests%' + // group by + // c_custkey + // ) as c_orders + // group by + // c_count + // order by + // custdist desc, + // c_count desc + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | SEARCH source = [ + | SEARCH source = customer + | | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey + | [ + | SEARCH source = orders + | | WHERE not like(o_comment, '%special%requests%') + | ] + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | ] AS c_orders + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | | STATS COUNT(1) AS custdist BY c_count + | | SORT - custdist, - c_count + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val tableC = UnresolvedRelation(Seq("customer")) + val tableO = UnresolvedRelation(Seq("orders")) + val left = SubqueryAlias("c", tableC) + val filterNot = Filter( + Not( + UnresolvedFunction( + Seq("like"), + Seq(UnresolvedAttribute("o_comment"), Literal("%special%requests%")), + isDistinct = false)), + tableO) + val right = SubqueryAlias("o", filterNot) + val joinCondition = + EqualTo(UnresolvedAttribute("o_custkey"), UnresolvedAttribute("c_custkey")) + val join = Join(left, right, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression1 = Alias(UnresolvedAttribute("c_custkey"), "c_custkey")() + val aggregateExpressions1 = + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("o_orderkey")), + isDistinct = false), + "c_count")() + val agg3 = + Aggregate(Seq(groupingExpression1), Seq(aggregateExpressions1, groupingExpression1), join) + val subqueryAlias = SubqueryAlias("c_orders", agg3) + val agg2 = + Aggregate( + Seq(groupingExpression1), + Seq(aggregateExpressions1, groupingExpression1), + subqueryAlias) + val groupingExpression2 = Alias(UnresolvedAttribute("c_count"), "c_count")() + val aggregateExpressions2 = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), "custdist")() + val agg1 = + Aggregate(Seq(groupingExpression2), Seq(aggregateExpressions2, groupingExpression2), agg2) + val sort = Sort( + Seq( + SortOrder(UnresolvedAttribute("custdist"), Descending), + SortOrder(UnresolvedAttribute("c_count"), Descending)), + global = true, + agg1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } }