Skip to content

Commit

Permalink
update scala fmt style
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <yang.db.dev@gmail.com>
  • Loading branch information
YANG-DB committed Oct 16, 2024
1 parent 63c6118 commit 049be03
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -36,81 +36,201 @@ class FlintSparkPPLFieldSummaryITSuite
}
}

ignore("test fieldsummary with single field includefields(status_code) & nulls=true ") {
test("test fieldsummary with single field includefields(status_code) & nulls=true ") {
  val frame = sql(s"""
                     | source = $testTable | fieldsummary includefields= status_code nulls=true
                     | """.stripMargin)

  // Part 1: the collected rows must match the expected summary row, compared after
  // sorting both sides by the first column (the field name).
  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] =
    Array(Row("status_code", 4, 3, 200, 403, 276.0, "int"))
  val byFieldName: Row => String = _.getAs[String](0)
  assert(actualRows.sortBy(byFieldName).sameElements(expectedRows.sortBy(byFieldName)))

  // Part 2: the unresolved logical plan must be Project(*) over an Aggregate that is
  // grouped by TYPEOF(status_code).
  val field = "status_code"

  // Builds `<function>(status_code) AS <alias>`; every call yields a fresh Alias.
  def summaryAlias(function: String, alias: String, distinct: Boolean = false): Alias =
    Alias(
      UnresolvedFunction(function, Seq(UnresolvedAttribute(field)), isDistinct = distinct),
      alias)()

  val summaryColumns: Seq[NamedExpression] = Seq(
    Alias(Literal(field), "Field")(),
    summaryAlias("COUNT", "COUNT"),
    summaryAlias("COUNT", "COUNT_DISTINCT", distinct = true),
    summaryAlias("MIN", "MIN"),
    summaryAlias("MAX", "MAX"),
    summaryAlias("AVG", "AVG"),
    summaryAlias("TYPEOF", "TYPEOF"))

  val relation =
    UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // The grouping key is the aliased TYPEOF(status_code) expression.
  val groupedSummary =
    Aggregate(Seq(summaryAlias("TYPEOF", "TYPEOF")), summaryColumns, relation)
  val expectedPlan = Project(seq(UnresolvedStar(None)), groupedSummary)

  // Compare against the plan produced for the PPL query, without running analysis.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

/*
val frame = sql(s"""
| SELECT
| 'status_code' AS Field,
| COUNT(status_code) AS Count,
| COUNT(DISTINCT status_code) AS Distinct,
| MIN(status_code) AS Min,
| MAX(status_code) AS Max,
| AVG(CAST(status_code AS DOUBLE)) AS Avg,
| typeof(status_code) AS Type,
| COUNT(*) - COUNT(status_code) AS Nulls
| FROM $testTable
| GROUP BY typeof(status_code)
| """.stripMargin)
*/

// val frame = sql(s"""
// | SELECT
// | 'status_code' AS Field,
// | COUNT(status_code) AS Count,
// | COUNT(DISTINCT status_code) AS Distinct,
// | MIN(status_code) AS Min,
// | MAX(status_code) AS Max,
// | AVG(CAST(status_code AS DOUBLE)) AS Avg,
// | typeof(status_code) AS Type,
// | (SELECT COLLECT_LIST(STRUCT(status_code, count_status))
// | FROM (
// | SELECT status_code, COUNT(*) AS count_status
// | FROM $testTable
// | GROUP BY status_code
// | ORDER BY count_status DESC
// | LIMIT 5
// | )) AS top_values,
// | COUNT(*) - COUNT(status_code) AS Nulls
// | FROM $testTable
// | GROUP BY typeof(status_code)
// |
// | UNION ALL
// |
// | SELECT
// | 'id' AS Field,
// | COUNT(id) AS Count,
// | COUNT(DISTINCT id) AS Distinct,
// | MIN(id) AS Min,
// | MAX(id) AS Max,
// | AVG(CAST(id AS DOUBLE)) AS Avg,
// | typeof(id) AS Type,
// | (SELECT COLLECT_LIST(STRUCT(id, count_id))
// | FROM (
// | SELECT id, COUNT(*) AS count_id
// | FROM $testTable
// | GROUP BY id
// | ORDER BY count_id DESC
// | LIMIT 5
// | )) AS top_values,
// | COUNT(*) - COUNT(id) AS Nulls
// | FROM $testTable
// | GROUP BY typeof(id)
// |""".stripMargin)
// Reference SQL sketch for a per-field summary (status_code and id), kept as a comment only;
// it is not executed by any test.
//
// val frame = sql(s"""
//    | SELECT
//    |   'status_code' AS Field,
//    |   COUNT(status_code) AS Count,
//    |   COUNT(DISTINCT status_code) AS Distinct,
//    |   MIN(status_code) AS Min,
//    |   MAX(status_code) AS Max,
//    |   AVG(CAST(status_code AS DOUBLE)) AS Avg,
//    |   typeof(status_code) AS Type,
//    |   (SELECT COLLECT_LIST(STRUCT(status_code, count_status))
//    |    FROM (
//    |      SELECT status_code, COUNT(*) AS count_status
//    |      FROM $testTable
//    |      GROUP BY status_code
//    |      ORDER BY count_status DESC
//    |      LIMIT 5
//    |    )) AS top_values,
//    |   COUNT(*) - COUNT(status_code) AS Nulls
//    | FROM $testTable
//    | GROUP BY typeof(status_code)
//    |
//    | UNION ALL
//    |
//    | SELECT
//    |   'id' AS Field,
//    |   COUNT(id) AS Count,
//    |   COUNT(DISTINCT id) AS Distinct,
//    |   MIN(id) AS Min,
//    |   MAX(id) AS Max,
//    |   AVG(CAST(id AS DOUBLE)) AS Avg,
//    |   typeof(id) AS Type,
//    |   (SELECT COLLECT_LIST(STRUCT(id, count_id))
//    |    FROM (
//    |      SELECT id, COUNT(*) AS count_id
//    |      FROM $testTable
//    |      GROUP BY id
//    |      ORDER BY count_id DESC
//    |      LIMIT 5
//    |    )) AS top_values,
//    |   COUNT(*) - COUNT(id) AS Nulls
//    | FROM $testTable
//    | GROUP BY typeof(id)
//    |""".stripMargin)
test(
  "test fieldsummary with single field includefields(id, status_code, request_path) & nulls=true") {
  val frame = sql(s"""
                     | source = $testTable | fieldsummary includefields= id, status_code, request_path nulls=true
                     | """.stripMargin)

  // Part 1: one summary row is expected per requested field. MIN/MAX are spelled as strings
  // in the expectation (presumably because the per-field results are unioned with the
  // string-typed request_path summary -- TODO confirm).
  val results: Array[Row] = frame.collect()
  val expectedResults: Array[Row] =
    Array(
      Row("id", 6L, 6L, "1", "6", 3.5, "int"),
      Row("status_code", 4L, 3L, "200", "403", 276.0, "int"),
      Row("request_path", 4L, 3L, "/about", "/home", null, "string"))

  // Row order is not asserted: both sides are sorted by the first column (the field name).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Part 2: the unresolved logical plan must be a Union of one Project(*) / Aggregate
  // sub-plan per field, in the order the fields were listed in the query.
  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  val table =
    UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // Builds `<function>(<field>) AS <alias>`; every call yields a fresh Alias instance.
  def summaryAlias(
      function: String,
      field: String,
      alias: String,
      distinct: Boolean = false): Alias =
    Alias(
      UnresolvedFunction(function, Seq(UnresolvedAttribute(field)), isDistinct = distinct),
      alias)()

  // Expected sub-plan for a single field: Project(*) over an Aggregate grouped by the
  // aliased TYPEOF(field) expression.
  def fieldSummaryPlan(field: String): LogicalPlan = {
    val aggregatePlan = Aggregate(
      Seq(summaryAlias("TYPEOF", field, "TYPEOF")),
      Seq(
        Alias(Literal(field), "Field")(),
        summaryAlias("COUNT", field, "COUNT"),
        summaryAlias("COUNT", field, "COUNT_DISTINCT", distinct = true),
        summaryAlias("MIN", field, "MIN"),
        summaryAlias("MAX", field, "MAX"),
        summaryAlias("AVG", field, "AVG"),
        summaryAlias("TYPEOF", field, "TYPEOF")),
      table)
    Project(seq(UnresolvedStar(None)), aggregatePlan)
  }

  val idProj = fieldSummaryPlan("id")
  val statusCodeProj = fieldSummaryPlan("status_code")
  val requestPathProj = fieldSummaryPlan("request_path")

  val expectedPlan =
    Union(seq(idProj, statusCodeProj, requestPathProj), byName = true, allowMissingCol = true)
  // Compare against the plan produced for the PPL query, without running analysis.
  comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
* @return
*/
public interface AggregatorTranslator {

/**
 * Builds the alias text for an aggregation in the form {@code FUNCTION(name)}.
 *
 * @param functionName function whose {@code name()} forms the prefix of the alias
 * @param name         qualified name whose {@code toString()} is placed between the parentheses
 * @return {@code functionName.name()} followed by the name wrapped in parentheses
 */
static String aggregationAlias(BuiltinFunctionName functionName, QualifiedName name) {
    final String function = functionName.name();
    final String target = name.toString();
    return function + "(" + target + ")";
}

static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) {
if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MAX;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.MIN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF;
import static org.opensearch.sql.ppl.utils.AggregatorTranslator.aggregationAlias;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;

Expand Down Expand Up @@ -106,7 +105,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the count(field) as Count
UnresolvedFunction count = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), false, empty(), false);
Alias countAlias = Alias$.MODULE$.apply(count,
aggregationAlias(COUNT, field.getField()),
COUNT.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand All @@ -115,7 +114,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the count(DISTINCT field) as CountDistinct
UnresolvedFunction countDistinct = new UnresolvedFunction(seq(COUNT.name()), seq(fieldLiteral), true, empty(), false);
Alias distinctCountAlias = Alias$.MODULE$.apply(countDistinct,
aggregationAlias(COUNT_DISTINCT, field.getField()),
COUNT_DISTINCT.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand All @@ -124,7 +123,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the MAX(field) as MAX
UnresolvedFunction max = new UnresolvedFunction(seq(MAX.name()), seq(fieldLiteral), false, empty(), false);
Alias maxAlias = Alias$.MODULE$.apply(max,
aggregationAlias(MAX, field.getField()),
MAX.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand All @@ -133,7 +132,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the MIN(field) as Min
UnresolvedFunction min = new UnresolvedFunction(seq(MIN.name()), seq(fieldLiteral), false, empty(), false);
Alias minAlias = Alias$.MODULE$.apply(min,
aggregationAlias(MIN, field.getField()),
MIN.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand All @@ -142,7 +141,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the AVG(field) as Avg
UnresolvedFunction avg = new UnresolvedFunction(seq(AVG.name()), seq(fieldLiteral), false, empty(), false);
Alias avgAlias = Alias$.MODULE$.apply(avg,
aggregationAlias(AVG, field.getField()),
AVG.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand Down Expand Up @@ -196,7 +195,7 @@ static LogicalPlan translate(FieldSummary fieldSummary, CatalystPlanContext cont
//Alias for the typeOf(field) as Type
UnresolvedFunction typeOf = new UnresolvedFunction(seq(TYPEOF.name()), seq(fieldLiteral), false, empty(), false);
Alias typeOfAlias = Alias$.MODULE$.apply(typeOf,
aggregationAlias(TYPEOF, field.getField()),
TYPEOF.name(),
NamedExpression.newExprId(),
seq(),
empty(),
Expand Down
Loading

0 comments on commit 049be03

Please sign in to comment.