Skip to content

Commit cffb67c

Browse files
committed
to have resolved call the data type check function
1 parent 6eaadff commit cffb67c

File tree

4 files changed

+11
-29
lines changed

4 files changed

+11
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ trait HiveTypeCoercion {
407407
Union(newLeft, newRight)
408408

409409
// fix decimal precision for expressions
410-
case q => q.transformExpressionsUp {
410+
case q => q.transformExpressions {
411411
// Skip nodes whose children have not been resolved yet
412412
case e if !e.childrenResolved => e
413413

@@ -619,12 +619,13 @@ trait HiveTypeCoercion {
619619
*/
620620
object Division extends Rule[LogicalPlan] {
621621
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
622-
// Skip nodes who's children have not been resolved yet or input types do not match.
623-
case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e
622+
// Skip Divisions who has not been resolved yet,
623+
// as this is an extra rule which should be applied at last.
624+
case e if !e.resolved => e
624625

625626
// Decimal and Double remain the same
626-
case d: Divide if d.resolved && d.dataType == DoubleType => d
627-
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
627+
case d: Divide if d.dataType == DoubleType => d
628+
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
628629

629630
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
630631
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ abstract class Expression extends TreeNode[Expression] {
4949
* should override this if the resolution of this type of expression involves more than just
5050
* the resolution of its children.
5151
*/
52-
lazy val resolved: Boolean = childrenResolved
52+
lazy val resolved: Boolean = childrenResolved && !checkInputDataTypes().hasError
5353

5454
/**
5555
* Returns the [[DataType]] of the result of evaluating this expression. It is

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,6 @@ abstract class BinaryArithmetic extends BinaryExpression {
121121
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
122122
override def symbol: String = "+"
123123

124-
// We will always cast fixed decimal to unlimited decimal
125-
// for `Add` in `HiveTypeCoercion`
126-
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
127-
128124
protected def checkTypesInternal(t: DataType) =
129125
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
130126

@@ -136,10 +132,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
136132
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
137133
override def symbol: String = "-"
138134

139-
// We will always cast fixed decimal to unlimited decimal
140-
// for `Subtract` in `HiveTypeCoercion`
141-
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
142-
143135
protected def checkTypesInternal(t: DataType) =
144136
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
145137

@@ -151,10 +143,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
151143
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
152144
override def symbol: String = "*"
153145

154-
// We will always cast fixed decimal to unlimited decimal
155-
// for `Multiply` in `HiveTypeCoercion`
156-
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
157-
158146
protected def checkTypesInternal(t: DataType) =
159147
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
160148

@@ -167,10 +155,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
167155
override def symbol: String = "/"
168156
override def nullable: Boolean = true
169157

170-
// We will always cast fixed decimal to unlimited decimal
171-
// for `Divide` in `HiveTypeCoercion`
172-
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
173-
174158
protected def checkTypesInternal(t: DataType) =
175159
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
176160

@@ -198,10 +182,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
198182
override def symbol: String = "%"
199183
override def nullable: Boolean = true
200184

201-
// We will always cast fixed decimal to unlimited decimal
202-
// for `Remainder` in `HiveTypeCoercion`
203-
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)
204-
205185
protected def checkTypesInternal(t: DataType) =
206186
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
207187

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
22-
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
23-
import org.apache.spark.sql.catalyst.dsl.plans._
2422
import org.apache.spark.sql.catalyst.dsl.expressions._
25-
import org.apache.spark.sql.types.{BooleanType, StringType}
23+
import org.apache.spark.sql.catalyst.dsl.plans._
24+
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
25+
import org.apache.spark.sql.types.StringType
26+
2627
import org.scalatest.FunSuite
2728

2829

0 commit comments

Comments
 (0)