Skip to content

Commit

Permalink
separate unit-tests into a dedicated file per each test category
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <yang.db.dev@gmail.com>
  • Loading branch information
YANG-DB committed Sep 7, 2023
1 parent 5819dc7 commit 7db7213
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Union;

Expand All @@ -26,7 +25,7 @@ public class CatalystPlanContext {
/**
* NamedExpression contextual parameters
**/
private final Stack<Expression> namedParseExpressions = new Stack<>();
private final Stack<org.apache.spark.sql.catalyst.expressions.Expression> namedParseExpressions = new Stack<>();

public LogicalPlan getPlan() {
if (this.planBranches.size() == 1) {
Expand All @@ -36,7 +35,7 @@ public LogicalPlan getPlan() {
return new Union(asScalaBuffer(this.planBranches).toSeq(), true, true);
}

public Stack<Expression> getNamedParseExpressions() {
public Stack<org.apache.spark.sql.catalyst.expressions.Expression> getNamedParseExpressions() {
return namedParseExpressions;
}

Expand All @@ -48,7 +47,7 @@ public Stack<Expression> getNamedParseExpressions() {
public void with(LogicalPlan plan) {
this.planBranches.push(plan);
}

public void plan(Function<LogicalPlan, LogicalPlan> transformFunction) {
this.planBranches.replaceAll(transformFunction::apply);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedTable;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
Expand Down Expand Up @@ -50,17 +50,15 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import scala.Option;
import scala.collection.JavaConverters;
import scala.collection.Seq;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.util.Collections.singletonList;
import static java.util.List.of;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate;
import static scala.Option.empty;
import static scala.collection.JavaConverters.asScalaBuffer;

/**
Expand Down Expand Up @@ -170,7 +168,7 @@ public String visitProject(Project node, CatalystPlanContext context) {
String fields = visitExpressionList(node.getProjectList(), context);

// Create an UnresolvedStar for all-fields projection
Seq<?> projectList = JavaConverters.asScalaBuffer(context.getNamedParseExpressions()).toSeq();
Seq<?> projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq();
// Create a Project node with the UnresolvedStar
context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq<NamedExpression>) projectList, p));

Expand Down Expand Up @@ -317,13 +315,14 @@ public String visitFunction(Function node, CatalystPlanContext context) {
public String visitCompare(Compare node, CatalystPlanContext context) {
String left = analyze(node.getLeft(), context);
String right = analyze(node.getRight(), context);
context.getNamedParseExpressions().add(ComparatorTransformer.comparator(node, context));
Predicate comparator = ComparatorTransformer.comparator(node, context);
context.getNamedParseExpressions().add((org.apache.spark.sql.catalyst.expressions.Expression)comparator);
return format("%s %s %s", left, node.getOperator(), right);
}

@Override
public String visitField(Field node, CatalystPlanContext context) {
context.getNamedParseExpressions().add(UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList(node.getField().toString()))));
context.getNamedParseExpressions().add(UnresolvedAttribute$.MODULE$.apply(asScalaBuffer(singletonList(node.getField().toString()))));
return node.getField().toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@
import org.apache.spark.sql.catalyst.expressions.BinaryComparison;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GreaterThan;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.Not;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.ppl.CatalystPlanContext;

import static com.amazonaws.services.mturk.model.Comparator.NotEqualTo;

/**
* Transform the PPL Logical comparator into catalyst comparator
*/
Expand All @@ -15,7 +23,7 @@ public interface ComparatorTransformer {
* comparator expression builder building a catalyst binary comparator from PPL's compare logical step
* @return
*/
static BinaryComparison comparator(Compare expression, CatalystPlanContext context) {
static Predicate comparator(Compare expression, CatalystPlanContext context) {
if (BuiltinFunctionName.of(expression.getOperator()).isEmpty())
throw new IllegalStateException("Unexpected value: " + BuiltinFunctionName.of(expression.getOperator()));

Expand Down Expand Up @@ -274,15 +282,15 @@ static BinaryComparison comparator(Compare expression, CatalystPlanContext conte
case EQUAL:
return new EqualTo(left,right);
case NOTEQUAL:
break;
return new Not(new EqualTo(left,right));
case LESS:
break;
return new LessThan(left,right);
case LTE:
break;
return new LessThanOrEqual(left,right);
case GREATER:
break;
return new GreaterThan(left,right);
case GTE:
break;
return new GreaterThanOrEqual(left,right);
case LIKE:
break;
case NOT_LIKE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

class PPLLogicalPlanComplexQueriesTranslatorTestSuite
extends SparkFunSuite
with Matchers {

private val planTrnasformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test simple search with only one table and no explicit fields (defaults to all fields)") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=table", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[table] | fields + *")

}

test("test simple search with schema.table and no explicit fields (defaults to all fields)") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=schema.table", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[schema.table] | fields + *")

}

test("test simple search with schema.table and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=schema.table | fields A", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[schema.table] | fields + A")
}

test("test simple search with only one table with one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=table | fields A", false), context)

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[table] | fields + A")
}

test("test simple search with only one table with two fields projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t | fields A, B", false), context)


val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val expectedPlan = Project(projectList, table)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | fields + A,B")
}

test("Search multiple tables - translated into union call - fields expected to exist in both tables ") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "search source = table1, table2 | fields A, B", false), context)


val table1 = UnresolvedRelation(Seq("table1"))
val table2 = UnresolvedRelation(Seq("table2"))

val allFields1 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val allFields2 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))

val projectedTable1 = Project(allFields1, table1)
val projectedTable2 = Project(allFields2, table2)

val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(logPlan, "source=[table1, table2] | fields + A,B")
assertEquals(expectedPlan, context.getPlan)
}

test("Search multiple tables - translated into union call with fields") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source = table1, table2 ", false), context)


val table1 = UnresolvedRelation(Seq("table1"))
val table2 = UnresolvedRelation(Seq("table2"))

val allFields1 = UnresolvedStar(None)
val allFields2 = UnresolvedStar(None)

val projectedTable1 = Project(Seq(allFields1), table1)
val projectedTable2 = Project(Seq(allFields2), table2)

val expectedPlan = Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(logPlan, "source=[table1, table2] | fields + *")
assertEquals(expectedPlan, context.getPlan)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, SortOrder, UnixTimestamp}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.Assert.assertEquals
import org.mockito.Mockito.when
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar.mock

class PPLLogicalPlanFiltersTranslatorTestSuite
extends SparkFunSuite
with Matchers {

private val planTrnasformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test simple search with only one table with one field literal filtered ") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a = 1 ", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 1 | fields + *")
}

test("test simple search with only one table with one field literal int equality filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a = 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 1 | fields + a")
}

test("test simple search with only one table with one field literal string equality filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, """source=t a = 'hi' | fields a""", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal("'hi'"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)

assertEquals(expectedPlan,context.getPlan)
assertEquals(logPlan, "source=[t] | where a = 'hi' | fields + a")
}

test("test simple search with only one table with one field greater than filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a > 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = GreaterThan(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a > 1 | fields + a")
}

test("test simple search with only one table with one field greater than equal filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a >= 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = GreaterThanOrEqual(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a >= 1 | fields + a")
}

test("test simple search with only one table with one field lower than filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a < 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = LessThan(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a < 1 | fields + a")
}

test("test simple search with only one table with one field lower than equal filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a <= 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = LessThanOrEqual(UnresolvedAttribute("a"), Literal(1))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a <= 1 | fields + a")
}

test("test simple search with only one table with one field not equal filtered and one field projected") {
val context = new CatalystPlanContext
val logPlan = planTrnasformer.visit(plan(pplParser, "source=t a != 1 | fields a", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal(1)))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedAttribute("a"))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, context.getPlan)
assertEquals(logPlan, "source=[t] | where a != 1 | fields + a")
}
}
Loading

0 comments on commit 7db7213

Please sign in to comment.