From 5819dc7dc4035238af8fc0888d6b8db4797d2853 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 7 Sep 2023 13:54:40 -0700 Subject: [PATCH] update literal transformations according to catalyst's convention Signed-off-by: YANGDB --- .../flint/spark/FlintSparkPPLITSuite.scala | 20 ++++- .../sql/ppl/CatalystQueryPlanVisitor.java | 3 +- .../sql/ppl/utils/DataTypeTransformer.java | 11 +++ .../PPLLogicalPlanTranslatorTestSuite.scala | 74 +++++++++++-------- 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index f61751305..0efd24f67 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -85,7 +85,7 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple filter query with two fields result test") { + test("create ppl simple age literal equal filter query with two fields result test") { val frame = sql( s""" | source = $testTable age=25 | fields name, age @@ -102,4 +102,22 @@ class FlintSparkPPLITSuite // Compare the two plans assert(expectedPlan === logicalPlan) } + + test("create ppl simple name literal equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable name='George' | fields name, age + | """.stripMargin) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default","flint_ppl_tst")) + val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("'George'")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 36d94424e..2e1ffe474 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -259,7 +259,8 @@ public String analyze(UnresolvedExpression unresolved, CatalystPlanContext conte @Override public String visitLiteral(Literal node, CatalystPlanContext context) { - context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal(node.getValue(), translate(node.getType()))); + context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(),node.getType()), translate(node.getType()))); return node.toString(); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index bedbfb8c1..e1e48fc93 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -3,9 +3,11 @@ import org.apache.spark.sql.types.ByteType$; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DateType$; import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.StringType$; +import org.apache.spark.unsafe.types.UTF8String; /** * translate the PPL ast expressions data-types into catalyst data-types @@ -23,4 +25,13 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) { return StringType$.MODULE$; } } + + static Object translate(Object value, org.opensearch.sql.ast.expression.DataType source) { + switch (source.getCoreType()) { + case STRING: + return UTF8String.fromString(value.toString()); + default: + return value; + } + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTranslatorTestSuite.scala index 3784448b0..a82c7a24b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTranslatorTestSuite.scala @@ -37,7 +37,7 @@ class PPLLogicalPlanTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[table] | fields + *") } @@ -49,7 +49,7 @@ class PPLLogicalPlanTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[schema.table] | fields + *") } @@ -60,7 +60,7 @@ class PPLLogicalPlanTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table"))) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[schema.table] | fields + A") } @@ -70,7 +70,7 @@ class PPLLogicalPlanTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A")) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[table] | fields + A") } @@ -83,11 +83,11 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a = 1 | fields + *") } - test("test simple search with only one table with one field literal equality filtered and one field projected") { + test("test simple search with only one table with one field literal int equality filtered and one field projected") { val context = new CatalystPlanContext val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a = 1 | fields a", false), context) @@ -96,10 +96,24 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a = 1 | fields + a") } + test("test simple search with only one table with one field literal string equality filtered and one field projected") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit(plan(pplParser, """source=t a = 'hi' | fields a""", false), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal("'hi'")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("a")) + val expectedPlan = Project(projectList, filterPlan) + + assertEquals(expectedPlan,context.getPlan) + assertEquals(logPlan, "source=[t] | where a = 'hi' | fields + a") + } + test("test simple search with only one table with one field greater than filtered and one field projected") { val context = new CatalystPlanContext val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a > 1 | fields a", false), context) @@ -109,7 +123,7 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a > 1 | fields + a") } @@ -122,7 +136,7 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a >= 1 | fields + a") } @@ -135,7 +149,7 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a < 1 | fields + a") } @@ -148,7 +162,7 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a <= 1 | fields + a") } @@ -161,7 +175,7 @@ class PPLLogicalPlanTranslatorTestSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("a")) val expectedPlan = Project(projectList, filterPlan) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | where a != 1 | fields + a") } @@ -174,7 +188,7 @@ class PPLLogicalPlanTranslatorTestSuite val table = UnresolvedRelation(Seq("t")) val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) val expectedPlan = Project(projectList, table) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "source=[t] | fields + A,B") } @@ -196,7 +210,7 @@ class PPLLogicalPlanTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) assertEquals(logPlan, "source=[table1, table2] | fields + A,B") - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) } @@ -217,7 +231,7 @@ class PPLLogicalPlanTranslatorTestSuite val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) assertEquals(logPlan, "source=[table1, table2] | fields + *") - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) } test("Find What are the average prices for different types of properties") { @@ -236,7 +250,7 @@ class PPLLogicalPlanTranslatorTestSuite ) val expectedPlan = Project(projectList, grouped) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -259,7 +273,7 @@ class PPLLogicalPlanTranslatorTestSuite val expectedPlan = Project(finalProjectList, limited) // Assert that the generated plan is as expected - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -286,7 +300,7 @@ class PPLLogicalPlanTranslatorTestSuite UnresolvedAttribute("avg_price_per_land_unit") ), groupBy) // Continue with your test... - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -311,7 +325,7 @@ class PPLLogicalPlanTranslatorTestSuite val groupByAttributes = Seq(UnresolvedAttribute("property_status")) val expectedPlan = Aggregate(groupByAttributes, aggregateExpressions, filter) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -334,7 +348,7 @@ class PPLLogicalPlanTranslatorTestSuite val sort = Sort(sortOrder, true, filter) val expectedPlan = Project(projectList, sort) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -365,7 +379,7 @@ class PPLLogicalPlanTranslatorTestSuite ) ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -400,7 +414,7 @@ class PPLLogicalPlanTranslatorTestSuite ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -434,7 +448,7 @@ class PPLLogicalPlanTranslatorTestSuite ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -464,7 +478,7 @@ class PPLLogicalPlanTranslatorTestSuite ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -483,7 +497,7 @@ class PPLLogicalPlanTranslatorTestSuite UnresolvedRelation(TableIdentifier("access_logs")) ) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -504,7 +518,7 @@ class PPLLogicalPlanTranslatorTestSuite ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -528,7 +542,7 @@ class PPLLogicalPlanTranslatorTestSuite ) // Add to your unit test - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -551,7 +565,7 @@ class PPLLogicalPlanTranslatorTestSuite UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")) ) ) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -580,7 +594,7 @@ class PPLLogicalPlanTranslatorTestSuite ) ) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") } @@ -603,7 +617,7 @@ class PPLLogicalPlanTranslatorTestSuite ) ) - assertEquals(context.getPlan, expectedPlan) + assertEquals(expectedPlan, context.getPlan) assertEquals(logPlan, "???") }