Skip to content

Commit b71d325

Browse files
cloud-fan authored and marmbrus committed
[SPARK-8075] [SQL] apply type check interface to more expressions
A follow-up to #6405. Note: this is not a big change; much of the diff comes from moving code around in `aggregates.scala` so that each aggregate function sits right below its corresponding aggregate expression. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6723 from cloud-fan/type-check and squashes the following commits: 2124301 [Wenchen Fan] fix tests 5a658bb [Wenchen Fan] add tests 287d3bb [Wenchen Fan] apply type check interface to more expressions
1 parent 7daa702 commit b71d325

File tree

21 files changed

+337
-290
lines changed

21 files changed

+337
-290
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,8 @@ class Analyzer(
587587
failAnalysis(
588588
s"""Expect multiple names given for ${g.getClass.getName},
589589
|but only single name '${name}' specified""".stripMargin)
590-
case Alias(g: Generator, name) => Some((g, name :: Nil))
591-
case MultiAlias(g: Generator, names) => Some(g, names)
590+
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
591+
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
592592
case _ => None
593593
}
594594
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
317317
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
318318

319319
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
320+
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
320321
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
321322
}
322323
}
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
590591
// Skip nodes whose children have not been resolved yet.
591592
case e if !e.childrenResolved => e
592593

593-
case a @ CreateArray(children) if !a.resolved =>
594-
val commonType = a.childTypes.reduce(
595-
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
596-
CreateArray(
597-
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
594+
case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
595+
val types = children.map(_.dataType)
596+
findTightestCommonTypeAndPromoteToString(types) match {
597+
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
598+
case None => a
599+
}
598600

599601
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
600602
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
@@ -620,12 +622,11 @@ trait HiveTypeCoercion {
620622
// Coalesce should return the first non-null value, which could be any column
621623
// from the list. So we need to make sure the return type is deterministic and
622624
// compatible with every child column.
623-
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
625+
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
624626
val types = es.map(_.dataType)
625627
findTightestCommonTypeAndPromoteToString(types) match {
626628
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
627-
case None =>
628-
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
629+
case None => c
629630
}
630631
}
631632
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
2222
import java.text.{DateFormat, SimpleDateFormat}
2323

2424
import org.apache.spark.Logging
25-
import org.apache.spark.sql.catalyst
25+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2626
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2727
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2828
import org.apache.spark.sql.types._
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
3131
/** Cast the child expression to the target data type. */
3232
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
3333

34-
override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
34+
override def checkInputDataTypes(): TypeCheckResult = {
35+
if (resolve(child.dataType, dataType)) {
36+
TypeCheckResult.TypeCheckSuccess
37+
} else {
38+
TypeCheckResult.TypeCheckFailure(
39+
s"cannot cast ${child.dataType} to $dataType")
40+
}
41+
}
3542

3643
override def foldable: Boolean = child.foldable
3744

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
162162
/**
163163
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
164164
* or returns a `TypeCheckResult` with an error message if invalid.
165-
* Note: it's not valid to call this method until `childrenResolved == true`
166-
* TODO: we should remove the default implementation and implement it for all
167-
* expressions with proper error message.
165+
* Note: it's not valid to call this method until `childrenResolved == true`.
168166
*/
169167
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
170168
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ object ExtractValue {
9696
}
9797
}
9898

99+
/**
100+
* A common interface of all kinds of extract value expressions.
101+
* Note: concrete extract value expressions are created only by `ExtractValue.apply`,
102+
* so we don't need to type-check them.
103+
*/
99104
trait ExtractValue extends UnaryExpression {
100105
self: Product =>
101106
}
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
179184

180185
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
181186

182-
override lazy val resolved = childrenResolved &&
183-
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
184-
185187
protected def evalNotNull(value: Any, ordinal: Any) = {
186188
// TODO: consider using Array[_] for ArrayType child to avoid
187189
// boxing of primitives
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)
203205

204206
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
205207

206-
override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
207-
208208
protected def evalNotNull(value: Any, ordinal: Any) = {
209209
val baseValue = value.asInstanceOf[Map[Any, _]]
210210
baseValue.get(ordinal).orNull

0 commit comments

Comments
 (0)