Skip to content

[SPARK-7562][SPARK-6444][SQL] Improve error reporting for expression data type mismatch #6405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
* Log the suite name and the test name before and after each test.
*
* Subclasses should never override this method. If they wish to run
* custom code before and after each test, they should should mix in
* the {{org.scalatest.BeforeAndAfter}} trait instead.
* custom code before and after each test, they should mix in the
* {{org.scalatest.BeforeAndAfter}} trait instead.
*/
final protected override def withFixture(test: NoArgTest): Outcome = {
val testName = test.text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ trait CheckAnalysis {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
}

case c: Cast if !c.resolved =>
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

case b: BinaryExpression if !b.resolved =>
failAnalysis(
s"invalid expression ${b.prettyString} " +
s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}")

case WindowExpression(UnresolvedWindowFunction(name, _), _) =>
failAnalysis(
s"Could not resolve window function '$name'. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object HiveTypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
*/
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
Expand All @@ -57,6 +57,17 @@ object HiveTypeCoercion {

case _ => None
}

/**
 * Finds the tightest common type of a sequence of types by folding
 * `findTightestCommonTypeOfTwo` over them from left to right.
 *
 * The fold is seeded with `NullType`, so an empty sequence yields `Some(NullType)`.
 * As soon as one step finds no common type the accumulator becomes `None` and
 * stays `None` for the remaining elements.
 *
 * @param types the data types to unify
 * @return the common type, or `None` if the types cannot be unified
 */
private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
  types.foldLeft[Option[DataType]](Some(NullType)) { (widest, next) =>
    // flatMap short-circuits on None, replacing the manual Some/None match.
    widest.flatMap(findTightestCommonTypeOfTwo(_, next))
  }
}
}

/**
Expand Down Expand Up @@ -180,7 +191,7 @@ trait HiveTypeCoercion {

case (l, r) if l.dataType != r.dataType =>
logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
val newLeft =
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
val newRight =
Expand Down Expand Up @@ -217,7 +228,7 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e

case b: BinaryExpression if b.left.dataType != b.right.dataType =>
findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
val newLeft =
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
val newRight =
Expand Down Expand Up @@ -441,21 +452,18 @@ trait HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// When we compare 2 decimal types with different precisions or scales, cast both sides to
// the narrowest decimal type that can hold either operand without loss: the scale is the
// larger of the two scales, and the precision reserves room for the larger of the two
// integral-digit counts (precision - scale) on top of that scale.
// Using max(p1, p2) as the precision would drop integral digits, e.g. comparing
// Decimal(5, 0) with Decimal(5, 4) needs Decimal(9, 4), not Decimal(5, 4).
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
    e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
  val resultScale = max(s1, s2)
  val resultType = DecimalType(max(p1 - s1, p2 - s2) + resultScale, resultScale)
  b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2)
if e2.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2))
case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _))
if e1.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited)))

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
Expand Down Expand Up @@ -570,7 +578,7 @@ trait HiveTypeCoercion {

case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonType(a, b).getOrElse(StringType))
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

Expand Down Expand Up @@ -599,14 +607,9 @@ trait HiveTypeCoercion {
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val dt: Option[DataType] = Some(NullType)
val types = es.map(_.dataType)
val rt = types.foldLeft(dt)((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
})
rt match {
case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
findTightestCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
}
Expand All @@ -619,17 +622,13 @@ trait HiveTypeCoercion {
*/
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// Skip nodes that have not been resolved yet,
// as this is an extra rule which should be applied last.
case e if !e.resolved => e

// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
Divide(Cast(l, DecimalType.Unlimited), r)
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
Expand All @@ -642,42 +641,33 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}

case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = ckw.branches.sliding(2, 2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case s => s
}.reduce(_ ++ _)
val transformedKey = if (ckw.key.dataType != commonType) {
Cast(ckw.key, commonType)
} else {
ckw.key
}
CaseKeyWhen(transformedKey, transformedBranches)
case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
val maybeCommonType = findTightestCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case other => other
}.reduce(_ ++ _)
c match {
case _: CaseWhen => CaseWhen(castedBranches)
case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
}
}.getOrElse(c)

case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case other => other
}.reduce(_ ++ _)
CaseKeyWhen(Cast(c.key, commonType), castedBranches)
}.getOrElse(c)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

/**
 * Outcome of `Expression.checkInputDataTypes`.
 * `CheckAnalysis` throws an `AnalysisException` when `isFailure` is true.
 */
trait TypeCheckResult {
  /** Whether the check passed; supplied by each concrete result. */
  def isSuccess: Boolean

  /** Whether the check failed; by default the negation of `isSuccess`. */
  def isFailure: Boolean = !isSuccess
}

object TypeCheckResult {

  /**
   * The result of `Expression.checkInputDataTypes` when the check passes.
   */
  object TypeCheckSuccess extends TypeCheckResult {
    override def isSuccess: Boolean = true
  }

  /**
   * The result of `Expression.checkInputDataTypes` when the check fails.
   *
   * @param message an error message giving the reason for the failure
   */
  case class TypeCheckFailure(message: String) extends TypeCheckResult {
    override def isSuccess: Boolean = false
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -45,11 +45,12 @@ abstract class Expression extends TreeNode[Expression] {

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and `false` if it still contains any unresolved placeholders. Implementations of expressions
* should override this if the resolution of this type of expression involves more than just
* the resolution of its children.
* and input data type checking has passed, and `false` if it still contains any unresolved
* placeholders or has a data type mismatch.
* Implementations of expressions should override this if the resolution of this type of
* expression involves more than just the resolution of its children and type checking.
*/
lazy val resolved: Boolean = childrenResolved
lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/**
* Returns the [[DataType]] of the result of evaluating this expression. It is
Expand Down Expand Up @@ -86,12 +87,21 @@ abstract class Expression extends TreeNode[Expression] {
case (i1, i2) => i1 == i2
}
}

/**
 * Checks the input data types of this expression.
 * Returns `TypeCheckResult.TypeCheckSuccess` if they are valid, or a
 * `TypeCheckResult.TypeCheckFailure` carrying an error message if they are not.
 * This default implementation always reports success.
 * Note: it's not valid to call this method until `childrenResolved == true`.
 * TODO: we should remove the default implementation and implement it for all
 * expressions with a proper error message.
 */
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

def symbol: String
def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

override def foldable: Boolean = left.foldable && right.foldable

Expand Down Expand Up @@ -125,7 +135,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
* so that the proper type conversions can be performed in the analyzer.
*/
trait ExpectsInputTypes {
  self: Expression =>

  def expectedChildTypes: Seq[DataType]

  // `HiveTypeCoercion` always inserts type casts for `ExpectsInputTypes` expressions, so a
  // type mismatch is never reported here; it is reported for the underlying `Cast`s instead.
  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
Loading