Skip to content

Commit 6eaadff

Browse files
committed
add equal type constraint to EqualTo
1 parent 3affbd8 commit 6eaadff

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,12 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
178178
s"differing types in ${this.getClass.getSimpleName} " +
179179
s"(${left.dataType} and ${right.dataType}).")
180180
} else {
181-
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
181+
checkTypesInternal(dataType)
182182
}
183183
}
184184

185+
protected def checkTypesInternal(t: DataType): TypeCheckResult
186+
185187
override def eval(input: Row): Any = {
186188
val evalE1 = left.eval(input)
187189
if (evalE1 == null) {
@@ -203,8 +205,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
203205
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
204206
override def symbol: String = "="
205207

206-
// EqualTo don't need 2 equal orderable types
207-
override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
208+
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success
208209

209210
protected override def evalInternal(l: Any, r: Any) = {
210211
if (left.dataType != BinaryType) l == r
@@ -216,8 +217,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
216217
override def symbol: String = "<=>"
217218
override def nullable: Boolean = false
218219

219-
// EqualNullSafe don't need 2 equal orderable types
220-
override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
220+
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success
221221

222222
override def eval(input: Row): Any = {
223223
val l = left.eval(input)
@@ -235,6 +235,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
235235
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
236236
override def symbol: String = "<"
237237

238+
override protected def checkTypesInternal(t: DataType) =
239+
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
240+
238241
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
239242

240243
protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2)
@@ -243,6 +246,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
243246
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
244247
override def symbol: String = "<="
245248

249+
override protected def checkTypesInternal(t: DataType) =
250+
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
251+
246252
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
247253

248254
protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2)
@@ -251,6 +257,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
251257
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
252258
override def symbol: String = ">"
253259

260+
override protected def checkTypesInternal(t: DataType) =
261+
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
262+
254263
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
255264

256265
protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2)
@@ -259,6 +268,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
259268
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
260269
override def symbol: String = ">="
261270

271+
override protected def checkTypesInternal(t: DataType) =
272+
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
273+
262274
private lazy val ordering = TypeUtils.getOrdering(left.dataType)
263275

264276
protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,16 @@ class ExpressionTypeCheckingSuite extends FunSuite {
9696
}
9797

9898
test("check types for predicates") {
99-
// EqualTo don't have type constraint
100-
assertSuccess(EqualTo('intField, 'booleanField))
101-
assertSuccess(EqualNullSafe('intField, 'booleanField))
102-
10399
// We will cast String to Double for binary comparison
100+
assertSuccess(EqualTo('intField, 'stringField))
101+
assertSuccess(EqualNullSafe('intField, 'stringField))
104102
assertSuccess(LessThan('intField, 'stringField))
105103
assertSuccess(LessThanOrEqual('intField, 'stringField))
106104
assertSuccess(GreaterThan('intField, 'stringField))
107105
assertSuccess(GreaterThanOrEqual('intField, 'stringField))
108106

107+
assertErrorForDifferingTypes(EqualTo('intField, 'booleanField))
108+
assertErrorForDifferingTypes(EqualNullSafe('intField, 'booleanField))
109109
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
110110
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
111111
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))

0 commit comments

Comments
 (0)