10 changes: 10 additions & 0 deletions core/src/main/resources/error/error-classes.json
@@ -200,6 +200,11 @@
"The <functionName> accepts only arrays of pair structs, but <childExpr> is of <childType>."
]
},
"MAP_ZIP_WITH_DIFF_TYPES" : {
"message" : [
"Input to the <functionName> should have been two maps with compatible key types, but it's [<leftType>, <rightType>]."
]
},
"NON_FOLDABLE_INPUT" : {
"message" : [
"the input <inputName> should be a foldable <inputType> expression; however, got <inputExpr>."
@@ -270,6 +275,11 @@
"The <exprName> must not be null"
]
},
"UNEXPECTED_RETURN_TYPE" : {
"message" : [
"The <functionName> requires return <expectedType> type, but the actual is <actualType> type."
]
},
"UNEXPECTED_STATIC_METHOD" : {
"message" : [
"cannot find a static method <methodName> that matches the argument types in <className>"
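For context, the `<name>` placeholders in these templates are substituted from the `messageParameters` map supplied by the Scala changes below. A minimal sketch of the mechanism (the `render` helper is hypothetical; Spark's real rendering happens inside its error framework):

```scala
// Hypothetical sketch: substitute each <name> placeholder in an error-class
// template with the corresponding messageParameters value.
def render(template: String, parameters: Map[String, String]): String =
  parameters.foldLeft(template) { case (msg, (name, value)) =>
    msg.replace(s"<$name>", value)
  }

val template = "Input to the <functionName> should have been two maps with " +
  "compatible key types, but it's [<leftType>, <rightType>]."

// Prints: Input to the `map_zip_with` should have been two maps with
// compatible key types, but it's ["DECIMAL(36,0)", "DECIMAL(36,35)"].
println(render(template, Map(
  "functionName" -> "`map_zip_with`",
  "leftType" -> "\"DECIMAL(36,0)\"",
  "rightType" -> "\"DECIMAL(36,35)\"")))
```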
6 changes: 6 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -370,6 +370,12 @@ abstract class SparkFunSuite
checkError(exception, errorClass, sqlState, parameters,
false, Array(context))

protected def checkErrorMatchPVals(
exception: SparkThrowable,
errorClass: String,
parameters: Map[String, String]): Unit =
checkError(exception, errorClass, None, parameters, matchPVals = true)

protected def checkErrorMatchPVals(
exception: SparkThrowable,
errorClass: String,
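The new overload just forwards to `checkError` with `matchPVals = true`, which makes each expected parameter value a regular expression rather than a literal string. An illustration of what that implies (not Spark's actual comparison code):

```scala
// Illustration only: regex vs. literal comparison of one error-message
// parameter, as selected by matchPVals.
def parameterMatches(expected: String, actual: String, matchPVals: Boolean): Boolean =
  if (matchPVals) actual.matches(expected) else actual == expected

// Generated lambda-variable ids (x_0, x_1, ...) differ across runs, so
// tests match them with a pattern instead of a fixed string.
assert(parameterMatches("""x_\d+""", "x_42", matchPVals = true))
assert(!parameterMatches("""x_\d+""", "x_42", matchPVals = false))
```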
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -24,6 +24,8 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -400,11 +402,25 @@ case class ArraySort(
if (function.dataType == IntegerType) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
"IntegerType")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_RETURN_TYPE",
messageParameters = Map(
"functionName" -> toSQLId(function.prettyName),
"expectedType" -> toSQLType(IntegerType),
"actualType" -> toSQLType(function.dataType)
)
)
}
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "1",
"requiredType" -> toSQLType(ArrayType),
"inputSql" -> toSQLExpr(argument),
"inputType" -> toSQLType(argument.dataType)
)
)
}
case failure => failure
}
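To see the first branch in action: a comparator lambda that does not return an integer should now surface `DATATYPE_MISMATCH.UNEXPECTED_RETURN_TYPE`. A sketch, assuming an active `SparkSession` named `spark` (the column name `a` and the string literal are arbitrary):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(Seq(3, 1, 2)).toDF("a")

// The comparator returns a STRING where array_sort requires an INT, so
// analysis (triggered eagerly by select) should fail with
// DATATYPE_MISMATCH.UNEXPECTED_RETURN_TYPE.
df.select(array_sort(col("a"), (x, y) => lit("hello")))
```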
@@ -804,9 +820,13 @@ case class ArrayAggregate(
case TypeCheckResult.TypeCheckSuccess =>
if (!DataType.equalsStructurally(
zero.dataType, merge.dataType, ignoreNullability = true)) {
TypeCheckResult.TypeCheckFailure(
s"argument 3 requires ${zero.dataType.simpleString} type, " +
s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "3",
"requiredType" -> toSQLType(zero.dataType),
"inputSql" -> toSQLExpr(merge),
"inputType" -> toSQLType(merge.dataType)))
} else {
TypeCheckResult.TypeCheckSuccess
}
@@ -1025,9 +1045,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
if (leftKeyType.sameType(rightKeyType)) {
TypeUtils.checkForOrderingExpr(leftKeyType, prettyName)
} else {
TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
DataTypeMismatch(
errorSubClass = "MAP_ZIP_WITH_DIFF_TYPES",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"leftType" -> toSQLType(leftKeyType),
"rightType" -> toSQLType(rightKeyType)
)
)
}
case failure => failure
}
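Similarly for `MapZipWith`: two maps whose key types have no common type should now raise `DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES`. A sketch mirroring the `mis`/`mmi` pair used in the suite changes further below (again assuming `spark` and its implicits):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Key types INT and MAP<INT, INT> cannot be reconciled, so analysis should
// fail with DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES.
val df = Seq((Map(1 -> "a"), Map(Map(1 -> 2) -> 2))).toDF("mis", "mmi")
df.select(map_zip_with(col("mis"), col("mmi"), (k, v1, v2) => v1))
```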
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -859,4 +861,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(1, 1, 2, 3))
}
}

test("Return type of the given function has to be IntegerType") {
val comparator =
(left: Expression, right: Expression) => Literal.create("hello", StringType)

val result = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator).checkInputDataTypes()
assert(result == DataTypeMismatch(
errorSubClass = "UNEXPECTED_RETURN_TYPE",
messageParameters = Map(
"functionName" -> toSQLId("lambdafunction"),
"expectedType" -> toSQLType(IntegerType),
"actualType" -> toSQLType(StringType)
)))
}
}
sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -82,8 +82,22 @@ FROM various_maps
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'map_zip_with(various_maps.decimal_map1, various_maps.decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7

{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,0)\"",
"rightType" : "\"DECIMAL(36,35)\"",
"sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 81,
"fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}

-- !query
SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
@@ -110,7 +124,22 @@ FROM various_maps
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve 'map_zip_with(various_maps.decimal_map2, various_maps.int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,35)\"",
"rightType" : "\"INT\"",
"sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 76,
"fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}


-- !query
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -533,6 +533,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
)
}

test("The given function only supports array input") {
val df = Seq(1, 2, 3).toDF("a")
checkErrorMatchPVals(
exception = intercept[AnalysisException] {
df.select(array_sort(col("a"), (x, y) => x - y))
},
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""",
"paramIndex" -> "1",
"requiredType" -> "\"ARRAY\"",
"inputSql" -> "\"a\"",
"inputType" -> "\"INT\""
))
}

test("sort_array/array_sort functions") {
val df = Seq(
(Array[Int](2, 1, 3), Array("b", "c", "a")),
@@ -3492,15 +3508,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"requiredType" -> "\"ARRAY\""))
// scalastyle:on line.size.limit

val ex4 = intercept[AnalysisException] {
df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
}
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
// scalastyle:off line.size.limit
checkError(
exception = intercept[AnalysisException] {
df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
},
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
"paramIndex" -> "3",
"inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
"inputType" -> "\"STRING\"",
"requiredType" -> "\"INT\""
))
// scalastyle:on line.size.limit

val ex4a = intercept[AnalysisException] {
df.select(aggregate(col("s"), lit(0), (acc, x) => x))
}
assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type"))
// scalastyle:off line.size.limit
checkError(
exception = intercept[AnalysisException] {
df.select(aggregate(col("s"), lit(0), (acc, x) => x))
},
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
"paramIndex" -> "3",
"inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
"inputType" -> "\"STRING\"",
"requiredType" -> "\"INT\""
))
// scalastyle:on line.size.limit

checkError(
exception =
@@ -3570,17 +3606,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}
assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
}
assert(ex2.getMessage.contains("The input to function map_zip_with should have " +
"been two maps with compatible key types"))
checkError(
exception = intercept[AnalysisException] {
df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
},
errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
parameters = Map(
"sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), x, y, z))\"",
"functionName" -> "`map_zip_with`",
"leftType" -> "\"INT\"",
"rightType" -> "\"MAP<INT, INT>\""),
context = ExpectedContext(
fragment = "map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))",
start = 0,
stop = 51))

val ex2a = intercept[AnalysisException] {
df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z)))
}
assert(ex2a.getMessage.contains("The input to function map_zip_with should have " +
"been two maps with compatible key types"))
// scalastyle:off line.size.limit
checkError(
exception = intercept[AnalysisException] {
df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z)))
},
errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
matchPVals = true,
parameters = Map(
"sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""",
"functionName" -> "`map_zip_with`",
"leftType" -> "\"INT\"",
"rightType" -> "\"MAP<INT, INT>\""))
// scalastyle:on line.size.limit

checkError(
exception = intercept[AnalysisException] {