Commit 625f76d
[SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
### What changes were proposed in this pull request?

In the PR, I propose to add new error sub-classes of the error class `DATATYPE_MISMATCH`, and to use them in the case of type check failures of some interval expressions.

### Why are the changes needed?

Migration onto error classes unifies Spark SQL error messages and improves the searchability of errors.

### Does this PR introduce _any_ user-facing change?

Yes. The PR changes user-facing error messages.

### How was this patch tested?

By running the affected test suites:

```
$ build/sbt "test:testOnly *AnalysisSuite"
$ build/sbt "test:testOnly *ExpressionTypeCheckingSuite"
$ build/sbt "test:testOnly *ApproxCountDistinctForIntervalsSuite"
```

Closes apache#38237 from MaxGekk/type-check-fails-interval-exprs.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent: f81c265
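For context, here is a minimal sketch (not part of this commit) of how a migrated failure surfaces to users, based on the test expectations in the diffs below. The object name `ErrorClassDemo` is illustrative, and the exact message rendering depends on the Spark build:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

// Illustrative only: sum() over a BOOLEAN column now fails with a structured
// error class instead of the old free-form message
// "function sum requires numeric or interval types".
object ErrorClassDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  try {
    spark.sql("SELECT sum(c) FROM VALUES (true), (false) AS t(c)").collect()
  } catch {
    case e: AnalysisException =>
      // Expected per the test suites in this commit:
      // DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE
      println(e.getErrorClass)
      println(e.getMessage)
  } finally {
    spark.stop()
  }
}
```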

File tree: 9 files changed (+123, -43 lines)


core/src/main/resources/error/error-classes.json (5 additions, 0 deletions)

```diff
@@ -263,6 +263,11 @@
         "The <exprName> must be between <valueRange> (current value = <currentValue>)"
       ]
     },
+    "WRONG_NUM_ENDPOINTS" : {
+      "message" : [
+        "The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>."
+      ]
+    },
     "WRONG_NUM_PARAMS" : {
       "message" : [
         "The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>."
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala (24 additions, 7 deletions)

```diff
@@ -21,10 +21,11 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 
@@ -49,7 +50,10 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {
+  extends TypedImperativeAggregate[Array[Long]]
+  with ExpectsInputTypes
+  with BinaryLike[Expression]
+  with QueryErrorsBase {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -77,19 +81,32 @@ case class ApproxCountDistinctForIntervals(
     if (defaultCheck.isFailure) {
       defaultCheck
     } else if (!endpointsExpression.foldable) {
-      TypeCheckFailure("The endpoints provided must be constant literals")
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> toSQLType(endpointsExpression.dataType)))
     } else {
       endpointsExpression.dataType match {
         case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
             _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
-            TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
+            DataTypeMismatch(
+              errorSubClass = "WRONG_NUM_ENDPOINTS",
+              messageParameters = Map("actualNumber" -> endpoints.length.toString))
           } else {
             TypeCheckSuccess
           }
-        case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-            "interval year to month or interval day to second) type")
+        case inputType =>
+          val requiredElemTypes = toSQLType(TypeCollection(
+            NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "2",
+              "requiredType" -> s"ARRAY OF $requiredElemTypes",
+              "inputSql" -> toSQLExpr(endpointsExpression),
+              "inputType" -> toSQLType(inputType)))
       }
     }
   }
```
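To see one of the new sub-classes in action, here is a hedged sketch that drives `checkInputDataTypes()` directly, mirroring the updated ApproxCountDistinctForIntervalsSuite at the end of this commit. This uses Catalyst-internal API, so it may change between releases:

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.types.DoubleType

object EndpointCheckDemo extends App {
  // A single endpoint cannot form an interval, so the check now returns the
  // structured WRONG_NUM_ENDPOINTS mismatch instead of a free-form string.
  val agg = ApproxCountDistinctForIntervals(
    AttributeReference("a", DoubleType)(),
    endpointsExpression = CreateArray(Seq(Literal(10L))))

  assert(agg.checkInputDataTypes() ==
    DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))
}
```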

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala (1 addition, 1 deletion)

```diff
@@ -54,7 +54,7 @@ case class Average(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   override def nullable: Boolean = true
 
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala (1 addition, 1 deletion)

```diff
@@ -67,7 +67,7 @@ case class Sum(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala (12 additions, 8 deletions)

```diff
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.types._
 
 /**
  * Functions to help with checking for valid data types and value comparison of various types.
  */
-object TypeUtils {
+object TypeUtils extends QueryErrorsBase {
 
   def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
     if (RowOrdering.isOrderable(dt)) {
@@ -70,13 +69,18 @@ object TypeUtils {
     }
   }
 
-  def checkForAnsiIntervalOrNumericType(
-      dt: DataType, funcName: String): TypeCheckResult = dt match {
+  def checkForAnsiIntervalOrNumericType(input: Expression): TypeCheckResult = input.dataType match {
     case _: AnsiIntervalType | NullType =>
       TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+    case other =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> Seq(NumericType, AnsiIntervalType).map(toSQLType).mkString(" or "),
+          "inputSql" -> toSQLExpr(input),
+          "inputType" -> toSQLType(other)))
   }
 
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
```
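The helper now takes the whole `Expression` rather than a bare `DataType` plus function name, because the structured error needs the offending expression's SQL text (`inputSql`) as well as its type; this is why the `Sum` and `Average` call sites above shrink to a single argument. A hedged sketch of calling it directly (again Catalyst-internal API, illustrative only):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.BooleanType

object IntervalOrNumericDemo extends App {
  // Numeric input still passes the check.
  assert(TypeUtils.checkForAnsiIntervalOrNumericType(Literal(1)) ==
    TypeCheckResult.TypeCheckSuccess)

  // A BOOLEAN attribute now produces DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE,
  // carrying the expression's SQL text and type as message parameters.
  val result = TypeUtils.checkForAnsiIntervalOrNumericType(
    AttributeReference("b", BooleanType)())
  assert(result.isFailure)
}
```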

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala (9 additions, 0 deletions)

```diff
@@ -233,3 +233,12 @@ private[sql] abstract class DatetimeType extends AtomicType
  * The interval type which conforms to the ANSI SQL standard.
  */
 private[sql] abstract class AnsiIntervalType extends AtomicType
+
+private[spark] object AnsiIntervalType extends AbstractDataType {
+  override private[sql] def simpleString: String = "ANSI interval"
+
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[AnsiIntervalType]
+
+  override private[sql] def defaultConcreteType: DataType = DayTimeIntervalType()
+}
```
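The new companion object makes `AnsiIntervalType` usable as an `AbstractDataType`, so checks such as the `TypeCollection(... AnsiIntervalType)` in ApproxCountDistinctForIntervals above can treat both ANSI interval types as a single required type and render it as `"ANSI INTERVAL"` in messages. A minimal sketch of which types it accepts, placed in a hypothetical file under the same package since the members are `private[sql]`:

```scala
package org.apache.spark.sql.types

// Hypothetical demo object; same package as AbstractDataType.scala so the
// private[sql] members are visible.
object AnsiIntervalTypeDemo extends App {
  // Both concrete ANSI interval types are accepted...
  assert(AnsiIntervalType.acceptsType(YearMonthIntervalType()))
  assert(AnsiIntervalType.acceptsType(DayTimeIntervalType()))
  // ...while non-interval types are not.
  assert(!AnsiIntervalType.acceptsType(IntegerType))
}
```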

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala (32 additions, 18 deletions)

```diff
@@ -1163,25 +1163,39 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") {
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c)
-         |SELECT t.c
-         |FROM t
-         |GROUP BY t.c
-         |HAVING mean(t.c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c)
+           |SELECT t.c
+           |FROM t
+           |GROUP BY t.c
+           |HAVING mean(t.c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c, false d)
-         |SELECT (t.c AND t.d) c
-         |FROM t
-         |GROUP BY t.c, t.d
-         |HAVING mean(c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c, false d)
+           |SELECT (t.c AND t.d) c
+           |FROM t
+           |GROUP BY t.c, t.d
+           |HAVING mean(c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
     assertAnalysisErrorClass(
       inputPlan = parsePlan(
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala (23 additions, 3 deletions)

```diff
@@ -396,9 +396,29 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
         "dataType" -> "\"MAP<STRING, BIGINT>\""
       )
     )
-    assertError(Sum($"booleanField"), "function sum requires numeric or interval types")
-    assertError(Average($"booleanField"),
-      "function average requires numeric or interval types")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Sum($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"sum(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Average($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"avg(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
   }
 
   test("check types for others") {
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala (16 additions, 5 deletions)

```diff
@@ -22,7 +22,7 @@ import java.time.LocalDateTime
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
@@ -48,20 +48,31 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The endpoints provided must be constant literals"))
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> "\"ARRAY<DOUBLE>\"")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
+      DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
+    // scalastyle:off line.size.limit
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-        "interval year to month or interval day to second) type"))
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2",
+          "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
+          "inputSql" -> "\"array(foobar)\"",
+          "inputType" -> "\"ARRAY<STRING>\"")))
+    // scalastyle:on line.size.limit
   }
 
   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
```
