Skip to content

Commit

Permalink
update literal transformations according to catalyst's convention
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <yang.db.dev@gmail.com>
  • Loading branch information
YANG-DB committed Sep 7, 2023
1 parent af065f7 commit 5819dc7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class FlintSparkPPLITSuite
assert(expectedPlan === logicalPlan)
}

test("create ppl simple filter query with two fields result test") {
test("create ppl simple age literal equal filter query with two fields result test") {
val frame = sql(
s"""
| source = $testTable age=25 | fields name, age
Expand All @@ -102,4 +102,22 @@ class FlintSparkPPLITSuite
// Compare the two plans
assert(expectedPlan === logicalPlan)
}

test("create ppl simple name literal equal filter query with two fields result test") {
val frame = sql(
s"""
| source = $testTable name='George' | fields name, age
| """.stripMargin)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("default","flint_ppl_tst"))
val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("'George'"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("name"),UnresolvedAttribute("age"))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(expectedPlan === logicalPlan)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ public String analyze(UnresolvedExpression unresolved, CatalystPlanContext conte

@Override
public String visitLiteral(Literal node, CatalystPlanContext context) {
context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal(node.getValue(), translate(node.getType())));
context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal(
translate(node.getValue(),node.getType()), translate(node.getType())));
return node.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.unsafe.types.UTF8String;

/**
* translate the PPL ast expressions data-types into catalyst data-types
Expand All @@ -23,4 +25,13 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
return StringType$.MODULE$;
}
}

static Object translate(Object value, org.opensearch.sql.ast.expression.DataType source) {
switch (source.getCoreType()) {
case STRING:
return UTF8String.fromString(value.toString());
default:
return value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PPLLogicalPlanTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[table] | fields + *")

}
Expand All @@ -49,7 +49,7 @@ class PPLLogicalPlanTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[schema.table] | fields + *")

}
Expand All @@ -60,7 +60,7 @@ class PPLLogicalPlanTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[schema.table] | fields + A")
}

Expand All @@ -70,7 +70,7 @@ class PPLLogicalPlanTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[table] | fields + A")
}

Expand All @@ -83,11 +83,11 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 1 | fields + *")
}

test("test simple search with only one table with one field literal equality filtered and one field projected") {
test("test simple search with only one table with one field literal int equality filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a = 1 | fields a", false), context)

Expand All @@ -96,10 +96,24 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 1 | fields + a")
}

test("test simple search with only one table with one field literal string equality filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, """source=t a = 'hi' | fields a""", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal("'hi'"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)

assertEquals(expectedPlan,context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 'hi' | fields + a")
}

test("test simple search with only one table with one field greater than filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a > 1 | fields a", false), context)
Expand All @@ -109,7 +123,7 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a > 1 | fields + a")
}

Expand All @@ -122,7 +136,7 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a >= 1 | fields + a")
}

Expand All @@ -135,7 +149,7 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a < 1 | fields + a")
}

Expand All @@ -148,7 +162,7 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a <= 1 | fields + a")
}

Expand All @@ -161,7 +175,7 @@ class PPLLogicalPlanTranslatorTestSuite
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a != 1 | fields + a")
}

Expand All @@ -174,7 +188,7 @@ class PPLLogicalPlanTranslatorTestSuite
val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val expectedPlan = Project(projectList, table)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | fields + A,B")
}

Expand All @@ -196,7 +210,7 @@ class PPLLogicalPlanTranslatorTestSuite
val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(logPlan, "source=[table1, table2] | fields + A,B")
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
}


Expand All @@ -217,7 +231,7 @@ class PPLLogicalPlanTranslatorTestSuite
val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(logPlan, "source=[table1, table2] | fields + *")
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
}

test("Find What are the average prices for different types of properties") {
Expand All @@ -236,7 +250,7 @@ class PPLLogicalPlanTranslatorTestSuite
)
val expectedPlan = Project(projectList, grouped)

assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand All @@ -259,7 +273,7 @@ class PPLLogicalPlanTranslatorTestSuite
val expectedPlan = Project(finalProjectList, limited)

// Assert that the generated plan is as expected
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand All @@ -286,7 +300,7 @@ class PPLLogicalPlanTranslatorTestSuite
UnresolvedAttribute("avg_price_per_land_unit")
), groupBy)
// Continue with your test...
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand All @@ -311,7 +325,7 @@ class PPLLogicalPlanTranslatorTestSuite

val groupByAttributes = Seq(UnresolvedAttribute("property_status"))
val expectedPlan = Aggregate(groupByAttributes, aggregateExpressions, filter)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand All @@ -334,7 +348,7 @@ class PPLLogicalPlanTranslatorTestSuite
val sort = Sort(sortOrder, true, filter)

val expectedPlan = Project(projectList, sort)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand Down Expand Up @@ -365,7 +379,7 @@ class PPLLogicalPlanTranslatorTestSuite
)
)
// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand Down Expand Up @@ -400,7 +414,7 @@ class PPLLogicalPlanTranslatorTestSuite
)

// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand Down Expand Up @@ -434,7 +448,7 @@ class PPLLogicalPlanTranslatorTestSuite
)

// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand Down Expand Up @@ -464,7 +478,7 @@ class PPLLogicalPlanTranslatorTestSuite
)

// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand All @@ -483,7 +497,7 @@ class PPLLogicalPlanTranslatorTestSuite
UnresolvedRelation(TableIdentifier("access_logs"))
)

assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand All @@ -504,7 +518,7 @@ class PPLLogicalPlanTranslatorTestSuite
)

// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand All @@ -528,7 +542,7 @@ class PPLLogicalPlanTranslatorTestSuite
)

// Add to your unit test
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand All @@ -551,7 +565,7 @@ class PPLLogicalPlanTranslatorTestSuite
UnresolvedRelation(TableIdentifier("sso_logs-nginx-*"))
)
)
assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")
}

Expand Down Expand Up @@ -580,7 +594,7 @@ class PPLLogicalPlanTranslatorTestSuite
)
)

assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand All @@ -603,7 +617,7 @@ class PPLLogicalPlanTranslatorTestSuite
)
)

assertEquals(context.getPlan, expectedPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "???")

}
Expand Down

0 comments on commit 5819dc7

Please sign in to comment.