Skip to content

Commit b374ddf

Browse files
committed
make stringcomparison extends ExpectsInputTypes
1 parent 8aa5aea commit b374ddf

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
154154
}
155155

156156
/** A base trait for functions that compare two strings, returning a boolean. */
157-
trait StringComparison {
157+
trait StringComparison extends ExpectsInputTypes {
158158
self: BinaryExpression =>
159159

160160
def compare(l: UTF8String, r: UTF8String): Boolean
@@ -163,6 +163,8 @@ trait StringComparison {
163163

164164
override def nullable: Boolean = left.nullable || right.nullable
165165

166+
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
167+
166168
override def eval(input: Row): Any = {
167169
val leftEval = left.eval(input)
168170
if(leftEval == null) {
@@ -183,27 +185,24 @@ trait StringComparison {
183185
* A function that returns true if the string `left` contains the string `right`.
184186
*/
185187
case class Contains(left: Expression, right: Expression)
186-
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
188+
extends BinaryExpression with Predicate with StringComparison {
187189
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
188-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
189190
}
190191

191192
/**
192193
* A function that returns true if the string `left` starts with the string `right`.
193194
*/
194195
case class StartsWith(left: Expression, right: Expression)
195-
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
196+
extends BinaryExpression with Predicate with StringComparison {
196197
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
197-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
198198
}
199199

200200
/**
201201
* A function that returns true if the string `left` ends with the string `right`.
202202
*/
203203
case class EndsWith(left: Expression, right: Expression)
204-
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
204+
extends BinaryExpression with Predicate with StringComparison {
205205
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
206-
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
207206
}
208207

209208
/**

0 commit comments

Comments
 (0)