@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -68,13 +69,8 @@ case class SortOrder(

override def children: Seq[Expression] = child +: sameOrderExpressions

override def checkInputDataTypes(): TypeCheckResult = {
if (RowOrdering.isOrderable(dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.catalogString}")
}
}
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(dataType, prettyName)

override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
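
A minimal REPL-style sketch of the new behavior, mirroring the SortOrderExpressionsSuite case added below (assumes the catalyst classes are on the classpath): sorting by a map-typed column now surfaces as a structured DATATYPE_MISMATCH.INVALID_ORDERING_TYPE instead of the old free-form message.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
    import org.apache.spark.sql.types.{MapType, StringType}

    // Map types are not orderable, so the check now delegated to
    // TypeUtils.checkForOrderingExpr reports INVALID_ORDERING_TYPE.
    val m = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
    assert(SortOrder(m, Ascending).checkInputDataTypes() ==
      DataTypeMismatch(
        errorSubClass = "INVALID_ORDERING_TYPE",
        messageParameters = Map(
          "functionName" -> "`sortorder`",
          "dataType" -> "\"MAP<STRING, STRING>\"")))
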
@@ -18,10 +18,11 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, UserDefinedType}


/**
* Unwrap UDT data type column into its underlying type.
*/
@@ -33,8 +34,13 @@ case class UnwrapUDT(child: Expression) extends UnaryExpression with NonSQLExpre
if (child.dataType.isInstanceOf[UserDefinedType[_]]) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"Input type should be UserDefinedType but got ${child.dataType.catalogString}")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "1",
"requiredType" -> toSQLType("UserDefinedType"),
"inputSql" -> toSQLExpr(child),
"inputType" -> toSQLType(child.dataType)))
}
}
override def dataType: DataType = child.dataType.asInstanceOf[UserDefinedType[_]].sqlType
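
A minimal sketch of the migrated check, mirroring the new UnwrapUDTExpressionSuite below (toSQLType comes from Cast, as imported in this diff): a non-UDT input is now reported under the UNEXPECTED_INPUT_TYPE sub-class.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Literal, UnwrapUDT}
    import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
    import org.apache.spark.sql.types.BooleanType

    // A boolean literal is not a UserDefinedType, so the check fails
    // with structured parameters instead of a plain-text message.
    val unwrapped = UnwrapUDT(Literal.create(false, BooleanType))
    assert(unwrapped.checkInputDataTypes() ==
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> "1",
          "requiredType" -> toSQLType("UserDefinedType"),
          "inputSql" -> "\"false\"",
          "inputType" -> "\"BOOLEAN\"")))
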
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._


/**
* An expression that is evaluated to the first non-null input.
*
@@ -57,8 +58,14 @@ case class Coalesce(children: Seq[Expression])

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
DataTypeMismatch(
errorSubClass = "WRONG_NUM_ARGS",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"expectedNum" -> "> 0",
"actualNum" -> children.length.toString
)
)
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), prettyName)
}
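
A minimal sketch of the new arity check, consistent with the ExpressionTypeCheckingSuite change below: zero arguments now yields WRONG_NUM_ARGS, with "> 0" encoding the old "at least one argument" requirement.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.Coalesce

    // Coalesce with no children is rejected with structured parameters;
    // functionName is the backquoted SQL identifier from toSQLId.
    assert(Coalesce(Nil).checkInputDataTypes() ==
      DataTypeMismatch(
        errorSubClass = "WRONG_NUM_ARGS",
        messageParameters = Map(
          "functionName" -> "`coalesce`",
          "expectedNum" -> "> 0",
          "actualNum" -> "0")))
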
@@ -22,6 +22,8 @@ import java.util.regex.Pattern

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.UnaryLike
@@ -182,7 +184,12 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
DataTypeMismatch(
errorSubClass = "WRONG_NUM_ARGS",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"expectedNum" -> "[2, 3]",
"actualNum" -> children.length.toString))
} else {
super[ExpectsInputTypes].checkInputDataTypes()
}
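
A minimal sketch mirroring the StringExpressionsSuite assertions below: calling parse_url with an argument count outside the accepted range now produces WRONG_NUM_ARGS with the arity range spelled out.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Literal, ParseUrl}

    // One argument is outside the accepted arity range [2, 3].
    assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes() ==
      DataTypeMismatch(
        errorSubClass = "WRONG_NUM_ARGS",
        messageParameters = Map(
          "functionName" -> "`parse_url`",
          "expectedNum" -> "[2, 3]",
          "actualNum" -> "1")))
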
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.xml

import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -42,7 +43,14 @@ abstract class XPathExtract

override def checkInputDataTypes(): TypeCheckResult = {
if (!path.foldable) {
TypeCheckFailure("path should be a string literal")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "path",
"inputType" -> toSQLType(StringType),
"inputExpr" -> toSQLExpr(path)
)
)
} else {
super.checkInputDataTypes()
}
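
A minimal sketch mirroring the XPathExpressionSuite change below (NonFoldableLiteral is a catalyst test-only helper, assumed on the test classpath): a path argument that is not foldable to a literal now maps to the NON_FOLDABLE_INPUT sub-class.

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{Literal, NonFoldableLiteral}
    import org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean

    // The path must be a string literal; a non-foldable expression
    // is rejected with structured parameters naming the input.
    val nonLitPath = XPathBoolean(Literal("abcd"), NonFoldableLiteral("/"))
    assert(nonLitPath.checkInputDataTypes() ==
      DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> "path",
          "inputType" -> "\"STRING\"",
          "inputExpr" -> "\"nonfoldableliteral()\"")))
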
@@ -327,10 +327,14 @@ class AnalysisErrorSuite extends AnalysisTest {
testRelation2.groupBy($"a")(sum(UnresolvedStar(None))),
"Invalid usage of '*' in expression 'sum'." :: Nil)

errorTest(
errorClassTest(
"sorting by unsupported column types",
mapRelation.orderBy($"map".asc),
"sort" :: "type" :: "map<int,int>" :: Nil)
errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
messageParameters = Map(
"sqlExpr" -> "\"map ASC NULLS FIRST\"",
"functionName" -> "`sortorder`",
"dataType" -> "\"MAP<INT, INT>\""))

errorClassTest(
"sorting by attributes are not from grouping expressions",
@@ -440,7 +440,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
)
)

assertError(Coalesce(Nil), "function coalesce requires at least one argument")
val coalesce = Coalesce(Nil)
checkError(
exception = intercept[AnalysisException] {
assertSuccess(coalesce)
},
errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
parameters = Map(
"sqlExpr" -> "\"coalesce()\"",
"functionName" -> toSQLId(coalesce.prettyName),
"expectedNum" -> "> 0",
"actualNum" -> "0"))

val murmur3Hash = new Murmur3Hash(Nil)
checkError(
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -83,4 +84,15 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
}

test("Cannot sort map type") {
val m = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
val sortOrderExpression = SortOrder(m, Ascending)
assert(sortOrderExpression.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "INVALID_ORDERING_TYPE",
messageParameters = Map(
"functionName" -> "`sortorder`",
"dataType" -> "\"MAP<STRING, STRING>\"")))
}
}
@@ -1497,12 +1497,43 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

// arguments checking
assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
.checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "WRONG_NUM_ARGS",
messageParameters = Map(
"functionName" -> "`parse_url`",
"expectedNum" -> "[2, 3]",
"actualNum" -> "1")
))
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"),
Literal("4"))).checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "WRONG_NUM_ARGS",
messageParameters = Map(
"functionName" -> "`parse_url`",
"expectedNum" -> "[2, 3]",
"actualNum" -> "4")
))
assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "2",
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"2\"",
"inputType" -> "\"INT\"")))
assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "1",
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"1\"",
"inputType" -> "\"INT\"")))
assert(ParseUrl(Seq(Literal("1"), Literal("2"),
Literal(3))).checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "3",
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"3\"",
"inputType" -> "\"INT\"")))

// Test escaping of arguments
GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil)
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.types.BooleanType

class UnwrapUDTExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

test("Input type should be UserDefinedType") {
val b1 = Literal.create(false, BooleanType)
val unwrapUDTExpression = UnwrapUDT(b1)
assert(unwrapUDTExpression.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> "1",
"requiredType" -> toSQLType("UserDefinedType"),
"inputSql" -> "\"false\"",
"inputType" -> "\"BOOLEAN\"")))
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.xml

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType

@@ -195,7 +196,13 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

// Validate that non-foldable paths are not supported.
val nonLitPath = exprCtor(Literal("abcd"), NonFoldableLiteral("/"))
assert(nonLitPath.checkInputDataTypes().isFailure)
assert(nonLitPath.checkInputDataTypes() == DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "path",
"inputType" -> "\"STRING\"",
"inputExpr" -> "\"nonfoldableliteral()\"")
))
}

testExpr(XPathBoolean)