Skip to content

Commit 69ca3fe

Browse files
committed
add error message and tests
1 parent c71d02c commit 69ca3fe

File tree

8 files changed

+210
-155
lines changed

8 files changed

+210
-155
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -619,18 +619,13 @@ trait HiveTypeCoercion {
619619
*/
620620
object Division extends Rule[LogicalPlan] {
621621
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
622-
// Skip nodes who's children have not been resolved yet.
623-
case e if !e.childrenResolved => e
622+
// Skip nodes who's children have not been resolved yet or input types do not match.
623+
case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e
624624

625625
// Decimal and Double remain the same
626626
case d: Divide if d.resolved && d.dataType == DoubleType => d
627627
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
628628

629-
case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
630-
Divide(l, Cast(r, DecimalType.Unlimited))
631-
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
632-
Divide(Cast(l, DecimalType.Unlimited), r)
633-
634629
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
635630
}
636631
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ abstract class Expression extends TreeNode[Expression] {
8888
}
8989

9090
/**
91-
* todo
91+
* Check the input data types, returns `TypeCheckResult.success` if it's valid,
92+
* or return a `TypeCheckResult` with an error message if invalid.
9293
*/
93-
def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success
94+
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
9495
}
9596

9697
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 33 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,6 @@ abstract class UnaryArithmetic extends UnaryExpression {
2828
override def nullable: Boolean = child.nullable
2929
override def dataType: DataType = child.dataType
3030

31-
override def checkInputDataTypes: TypeCheckResult = {
32-
if (TypeUtils.validForNumericExpr(child.dataType)) {
33-
TypeCheckResult.success
34-
} else {
35-
TypeCheckResult.fail("todo")
36-
}
37-
}
38-
3931
override def eval(input: Row): Any = {
4032
val evalE = child.eval(input)
4133
if (evalE == null) {
@@ -52,6 +44,9 @@ abstract class UnaryArithmetic extends UnaryExpression {
5244
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5345
override def toString: String = s"-$child"
5446

47+
override def checkInputDataTypes(): TypeCheckResult =
48+
TypeUtils.checkForNumericExpr(child.dataType, "operator -")
49+
5550
private lazy val numeric = TypeUtils.getNumeric(dataType)
5651

5752
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
@@ -62,6 +57,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
6257
override def nullable: Boolean = true
6358
override def toString: String = s"SQRT($child)"
6459

60+
override def checkInputDataTypes(): TypeCheckResult =
61+
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")
62+
6563
private lazy val numeric = TypeUtils.getNumeric(child.dataType)
6664

6765
protected override def evalInternal(evalE: Any) = {
@@ -77,6 +75,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
7775
case class Abs(child: Expression) extends UnaryArithmetic {
7876
override def toString: String = s"Abs($child)"
7977

78+
override def checkInputDataTypes(): TypeCheckResult =
79+
TypeUtils.checkForNumericExpr(child.dataType, "function abs")
80+
8081
private lazy val numeric = TypeUtils.getNumeric(dataType)
8182

8283
protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
@@ -87,10 +88,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
8788

8889
override def dataType: DataType = left.dataType
8990

90-
override def checkInputDataTypes: TypeCheckResult = {
91+
override def checkInputDataTypes(): TypeCheckResult = {
9192
if (left.dataType != right.dataType) {
9293
TypeCheckResult.fail(
93-
s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}")
94+
s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}")
9495
} else {
9596
checkTypesInternal(dataType)
9697
}
@@ -123,13 +124,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
123124
// for `Add` in `HiveTypeCoercion`
124125
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
125126

126-
protected def checkTypesInternal(t: DataType) = {
127-
if (TypeUtils.validForNumericExpr(t)) {
128-
TypeCheckResult.success
129-
} else {
130-
TypeCheckResult.fail("todo")
131-
}
132-
}
127+
protected def checkTypesInternal(t: DataType) =
128+
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
133129

134130
private lazy val numeric = TypeUtils.getNumeric(dataType)
135131

@@ -143,13 +139,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
143139
// for `Subtract` in `HiveTypeCoercion`
144140
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
145141

146-
protected def checkTypesInternal(t: DataType) = {
147-
if (TypeUtils.validForNumericExpr(t)) {
148-
TypeCheckResult.success
149-
} else {
150-
TypeCheckResult.fail("todo")
151-
}
152-
}
142+
protected def checkTypesInternal(t: DataType) =
143+
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
153144

154145
private lazy val numeric = TypeUtils.getNumeric(dataType)
155146

@@ -163,13 +154,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
163154
// for `Multiply` in `HiveTypeCoercion`
164155
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
165156

166-
protected def checkTypesInternal(t: DataType) = {
167-
if (TypeUtils.validForNumericExpr(t)) {
168-
TypeCheckResult.success
169-
} else {
170-
TypeCheckResult.fail("todo")
171-
}
172-
}
157+
protected def checkTypesInternal(t: DataType) =
158+
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
173159

174160
private lazy val numeric = TypeUtils.getNumeric(dataType)
175161

@@ -184,13 +170,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
184170
// for `Divide` in `HiveTypeCoercion`
185171
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
186172

187-
protected def checkTypesInternal(t: DataType) = {
188-
if (TypeUtils.validForNumericExpr(t)) {
189-
TypeCheckResult.success
190-
} else {
191-
TypeCheckResult.fail("todo")
192-
}
193-
}
173+
protected def checkTypesInternal(t: DataType) =
174+
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
194175

195176
private lazy val div: (Any, Any) => Any = dataType match {
196177
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
@@ -220,13 +201,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
220201
// for `Remainder` in `HiveTypeCoercion`
221202
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
222203

223-
protected def checkTypesInternal(t: DataType) = {
224-
if (TypeUtils.validForNumericExpr(t)) {
225-
TypeCheckResult.success
226-
} else {
227-
TypeCheckResult.fail("todo")
228-
}
229-
}
204+
protected def checkTypesInternal(t: DataType) =
205+
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
230206

231207
private lazy val integral = dataType match {
232208
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
@@ -254,13 +230,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
254230
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
255231
override def symbol: String = "&"
256232

257-
protected def checkTypesInternal(t: DataType) = {
258-
if (TypeUtils.validForBitwiseExpr(t)) {
259-
TypeCheckResult.success
260-
} else {
261-
TypeCheckResult.fail("todo")
262-
}
263-
}
233+
protected def checkTypesInternal(t: DataType) =
234+
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
264235

265236
private lazy val and: (Any, Any) => Any = dataType match {
266237
case ByteType =>
@@ -282,13 +253,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
282253
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
283254
override def symbol: String = "|"
284255

285-
protected def checkTypesInternal(t: DataType) = {
286-
if (TypeUtils.validForBitwiseExpr(t)) {
287-
TypeCheckResult.success
288-
} else {
289-
TypeCheckResult.fail("todo")
290-
}
291-
}
256+
protected def checkTypesInternal(t: DataType) =
257+
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
292258

293259
private lazy val or: (Any, Any) => Any = dataType match {
294260
case ByteType =>
@@ -310,13 +276,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
310276
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
311277
override def symbol: String = "^"
312278

313-
protected def checkTypesInternal(t: DataType) = {
314-
if (TypeUtils.validForBitwiseExpr(t)) {
315-
TypeCheckResult.success
316-
} else {
317-
TypeCheckResult.fail("todo")
318-
}
319-
}
279+
protected def checkTypesInternal(t: DataType) =
280+
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
320281

321282
private lazy val xor: (Any, Any) => Any = dataType match {
322283
case ByteType =>
@@ -338,13 +299,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
338299
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
339300
override def toString: String = s"~$child"
340301

341-
override def checkInputDataTypes: TypeCheckResult = {
342-
if (TypeUtils.validForBitwiseExpr(dataType)) {
343-
TypeCheckResult.success
344-
} else {
345-
TypeCheckResult.fail("todo")
346-
}
347-
}
302+
override def checkInputDataTypes(): TypeCheckResult =
303+
TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
348304

349305
private lazy val not: (Any) => Any = dataType match {
350306
case ByteType =>
@@ -363,13 +319,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
363319
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
364320
override def nullable: Boolean = left.nullable && right.nullable
365321

366-
protected def checkTypesInternal(t: DataType) = {
367-
if (TypeUtils.validForOrderingExpr(t)) {
368-
TypeCheckResult.success
369-
} else {
370-
TypeCheckResult.fail("todo")
371-
}
372-
}
322+
protected def checkTypesInternal(t: DataType) =
323+
TypeUtils.checkForOrderingExpr(t, "function maxOf")
373324

374325
private lazy val ordering = TypeUtils.getOrdering(dataType)
375326

@@ -395,13 +346,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
395346
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
396347
override def nullable: Boolean = left.nullable && right.nullable
397348

398-
protected def checkTypesInternal(t: DataType) = {
399-
if (TypeUtils.validForOrderingExpr(t)) {
400-
TypeCheckResult.success
401-
} else {
402-
TypeCheckResult.fail("todo")
403-
}
404-
}
349+
protected def checkTypesInternal(t: DataType) =
350+
TypeUtils.checkForOrderingExpr(t, "function minOf")
405351

406352
private lazy val ordering = TypeUtils.getOrdering(dataType)
407353

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
3131

3232
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
3333
override def dataType: DataType = DoubleType
34+
override def foldable: Boolean = child.foldable
3435
override def nullable: Boolean = true
3536
override def toString: String = s"$name($child)"
3637

0 commit comments

Comments
 (0)