Skip to content

Commit

Permalink
[SPARK-40761][SQL] Migrate type check failures of percentile expressi…
Browse files Browse the repository at this point in the history
…ons onto error classes

### What changes were proposed in this pull request?
This PR replaces `TypeCheckFailure` with `DataTypeMismatch` in the type checks of the percentile expressions, covering `ApproximatePercentile.scala` and `percentiles.scala`.

### Why are the changes needed?
Migration onto error classes unifies Spark SQL error messages.

### Does this PR introduce _any_ user-facing change?
Yes. The PR changes user-facing error messages.

### How was this patch tested?
- Pass GitHub Actions

Closes #38234 from LuciferYang/SPARK-40761-2.

Authored-by: yangjie01 <yangjie01@baidu.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
LuciferYang authored and MaxGekk committed Oct 15, 2022
1 parent b0fdecb commit 5c963e3
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 44 deletions.
12 changes: 11 additions & 1 deletion core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
},
"NON_FOLDABLE_INPUT" : {
"message" : [
"the input should be a foldable string expression and not null; however, got <inputExpr>."
"the input <inputName> should be a foldable <inputType> expression; however, got <inputExpr>."
]
},
"NON_STRING_TYPE" : {
Expand Down Expand Up @@ -228,11 +228,21 @@
"parameter <paramIndex> requires <requiredType> type, however, <inputSql> is of <inputType> type."
]
},
"UNEXPECTED_NULL" : {
"message" : [
"The <exprName> must not be null"
]
},
"UNSPECIFIED_FRAME" : {
"message" : [
"Cannot use an UnspecifiedFrame. This should have been converted during analysis."
]
},
"VALUE_OUT_OF_RANGE" : {
"message" : [
"The <exprName> must be between <valueRange> (current value = <currentValue>)"
]
},
"WRONG_NUM_PARAMS" : {
"message" : [
"wrong number of parameters: <actualNum>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import com.google.common.primitives.{Doubles, Ints, Longs}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
Expand Down Expand Up @@ -118,17 +119,46 @@ case class ApproximatePercentile(
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else if (!percentageExpression.foldable || !accuracyExpression.foldable) {
TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal")
} else if (!percentageExpression.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "percentage",
"inputType" -> toSQLType(percentageExpression.dataType),
"inputExpr" -> toSQLExpr(percentageExpression)
)
)
} else if (!accuracyExpression.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "accuracy",
"inputType" -> toSQLType(accuracyExpression.dataType),
"inputExpr" -> toSQLExpr(accuracyExpression)
)
)
} else if (accuracy <= 0 || accuracy > Int.MaxValue) {
TypeCheckFailure(s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" +
s" (current value = $accuracy)")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "accuracy",
"valueRange" -> s"(0, ${Int.MaxValue}]",
"currentValue" -> toSQLValue(accuracy, LongType)
)
)
} else if (percentages == null) {
TypeCheckFailure("Percentage value must not be null")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "percentage"))
} else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) {
TypeCheckFailure(
s"All percentage values must be between 0.0 and 1.0 " +
s"(current = ${percentages.mkString(", ")})")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" -> percentages.map(toSQLValue(_, DoubleType)).mkString(",")
)
)
} else {
TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
Expand Down Expand Up @@ -84,14 +85,28 @@ abstract class PercentileBase
defaultCheck
} else if (!percentageExpression.foldable) {
// percentageExpression must be foldable
TypeCheckFailure("The percentage(s) must be a constant literal, " +
s"but got $percentageExpression")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "percentage",
"inputType" -> toSQLType(percentageExpression.dataType),
"inputExpr" -> toSQLExpr(percentageExpression))
)
} else if (percentages == null) {
TypeCheckFailure("Percentage value must not be null")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "percentage")
)
} else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) {
// percentages(s) must be in the range [0.0, 1.0]
TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " +
s"but got $percentageExpression")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" -> percentages.map(toSQLValue(_, DoubleType)).mkString(",")
)
)
} else {
TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,17 @@ case class SchemaOfCsv(
override def checkInputDataTypes(): TypeCheckResult = {
if (child.foldable && csv != null) {
super.checkInputDataTypes()
} else {
} else if (!child.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map("inputExpr" -> toSQLExpr(child)))
messageParameters = Map(
"inputName" -> "csv",
"inputType" -> toSQLType(child.dataType),
"inputExpr" -> toSQLExpr(child)))
} else {
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "csv"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,10 +794,17 @@ case class SchemaOfJson(
override def checkInputDataTypes(): TypeCheckResult = {
if (child.foldable && json != null) {
super.checkInputDataTypes()
} else {
} else if (!child.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map("inputExpr" -> toSQLExpr(child)))
messageParameters = Map(
"inputName" -> "json",
"inputType" -> toSQLType(child.dataType),
"inputExpr" -> toSQLExpr(child)))
} else {
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> "json"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ import java.sql.Date

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.{ArrayData, QuantileSummaries}
import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, IntegralType}
import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, IntegralType, LongType}
import org.apache.spark.util.SizeEstimator

class ApproximatePercentileSuite extends SparkFunSuite {
Expand Down Expand Up @@ -212,14 +213,22 @@ class ApproximatePercentileSuite extends SparkFunSuite {

test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") {
val attribute = AttributeReference("a", DoubleType)()
val accuracyExpression = AttributeReference("b", IntegerType)()
val wrongAccuracy = new ApproximatePercentile(
attribute,
percentageExpression = Literal(0.5D),
accuracyExpression = AttributeReference("b", IntegerType)())
accuracyExpression = accuracyExpression)

assertEqual(
wrongAccuracy.checkInputDataTypes(),
TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "accuracy",
"inputType" -> toSQLType(accuracyExpression.dataType),
"inputExpr" -> toSQLExpr(accuracyExpression)
)
)
)

val wrongPercentage = new ApproximatePercentile(
Expand All @@ -229,19 +238,34 @@ class ApproximatePercentileSuite extends SparkFunSuite {

assertEqual(
wrongPercentage.checkInputDataTypes(),
TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "percentage",
"inputType" -> toSQLType(attribute.dataType),
"inputExpr" -> toSQLExpr(attribute)
)
)
)
}

test("class ApproximatePercentile, fails analysis if parameters are invalid") {
val wrongAccuracyExpression = Literal(-1)
val wrongAccuracy = new ApproximatePercentile(
AttributeReference("a", DoubleType)(),
percentageExpression = Literal(0.5D),
accuracyExpression = Literal(-1))
accuracyExpression = wrongAccuracyExpression)
assertEqual(
wrongAccuracy.checkInputDataTypes(),
TypeCheckFailure(s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" +
" (current value = -1)"))
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "accuracy",
"valueRange" -> s"(0, ${Int.MaxValue}]",
"currentValue" ->
toSQLValue(wrongAccuracyExpression.eval().asInstanceOf[Number].longValue, LongType))
)
)

val correctPercentageExpressions = Seq(
Literal(0.1f, FloatType),
Expand Down Expand Up @@ -273,11 +297,32 @@ class ApproximatePercentileSuite extends SparkFunSuite {
percentageExpression = percentageExpression,
accuracyExpression = Literal(100))

assert(
wrongPercentage.checkInputDataTypes() match {
case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
case _ => false
})
percentageExpression.eval() match {
case array: ArrayData =>
assertEqual(wrongPercentage.checkInputDataTypes(),
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" ->
array.toDoubleArray().map(toSQLValue(_, DoubleType)).mkString(",")
)
)
)
case other =>
assertEqual(wrongPercentage.checkInputDataTypes(),
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" ->
Array(other).map(toSQLValue(_, DoubleType)).mkString(",")
)
)
)
}
}
}

Expand Down Expand Up @@ -320,7 +365,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
assert(new ApproximatePercentile(
AttributeReference("a", DoubleType)(),
percentageExpression = Literal(null, DoubleType)).checkInputDataTypes() ===
TypeCheckFailure("Percentage value must not be null"))
DataTypeMismatch(errorSubClass = "UNEXPECTED_NULL", Map("exprName" -> "percentage")))

val nullPercentageExprs =
Seq(CreateArray(Seq(null).map(Literal(_))), CreateArray(Seq(0.1D, null).map(Literal(_))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -208,9 +209,32 @@ class PercentileSuite extends SparkFunSuite {

invalidPercentages.foreach { percentage =>
val percentile2 = new Percentile(child, percentage)
assertEqual(percentile2.checkInputDataTypes(),
TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " +
s"but got ${percentage.simpleString(100)}"))
percentage.eval() match {
case array: ArrayData =>
assertEqual(percentile2.checkInputDataTypes(),
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" ->
array.toDoubleArray().map(toSQLValue(_, DoubleType)).mkString(",")
)
)
)
case other =>
assertEqual(percentile2.checkInputDataTypes(),
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "percentage",
"valueRange" -> "[0.0, 1.0]",
"currentValue" ->
Array(other).map(toSQLValue(_, DoubleType)).mkString(",")
)
)
)
}
}

val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5),
Expand All @@ -219,8 +243,14 @@ class PercentileSuite extends SparkFunSuite {
nonFoldablePercentage.foreach { percentage =>
val percentile3 = new Percentile(child, percentage)
assertEqual(percentile3.checkInputDataTypes(),
TypeCheckFailure(s"The percentage(s) must be a constant literal, " +
s"but got ${percentage}"))
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "percentage",
"inputType" -> toSQLType(percentage.dataType),
"inputExpr" -> toSQLExpr(percentage))
)
)
}

val invalidDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType,
Expand Down Expand Up @@ -261,7 +291,7 @@ class PercentileSuite extends SparkFunSuite {
assert(new Percentile(
AttributeReference("a", DoubleType)(),
percentageExpression = Literal(null, DoubleType)).checkInputDataTypes() ===
TypeCheckFailure("Percentage value must not be null"))
DataTypeMismatch(errorSubClass = "UNEXPECTED_NULL", Map("exprName" -> "percentage")))

val nullPercentageExprs =
Seq(CreateArray(Seq(null).map(Literal(_))), CreateArray(Seq(0.1D, null).map(Literal(_))))
Expand Down
Loading

0 comments on commit 5c963e3

Please sign in to comment.