Skip to content

Commit f79410c

Browse files
committed
[SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes.
Author: Reynold Xin <rxin@databricks.com> Closes apache#7109 from rxin/auto-cast and squashes the following commits: a914cc3 [Reynold Xin] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes.
1 parent ea775b0 commit f79410c

File tree

6 files changed

+71
-79
lines changed

6 files changed

+71
-79
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ trait HiveTypeCoercion {
116116
IfCoercion ::
117117
Division ::
118118
PropagateTypes ::
119-
ExpectedInputConversion ::
119+
AddCastForAutoCastInputTypes ::
120120
Nil
121121

122122
/**
@@ -709,15 +709,15 @@ trait HiveTypeCoercion {
709709

710710
/**
711711
* Casts types according to the expected input types for Expressions that have the trait
712-
* `ExpectsInputTypes`.
712+
* [[AutoCastInputTypes]].
713713
*/
714-
object ExpectedInputConversion extends Rule[LogicalPlan] {
714+
object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
715715

716716
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
717717
// Skip nodes whose children have not been resolved yet.
718718
case e if !e.childrenResolved => e
719719

720-
case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
720+
case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
721721
val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
722722
case (child, actual, expected) =>
723723
if (actual == expected) child else Cast(child, expected)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
261261
* Expressions that require a specific `DataType` as input should implement this trait
262262
* so that the proper type conversions can be performed in the analyzer.
263263
*/
264-
trait ExpectsInputTypes {
264+
trait AutoCastInputTypes {
265265
self: Expression =>
266266

267267
def expectedChildTypes: Seq[DataType]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 55 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String)
5656
* @param name The short name of the function
5757
*/
5858
abstract class UnaryMathExpression(f: Double => Double, name: String)
59-
extends UnaryExpression with Serializable with ExpectsInputTypes {
59+
extends UnaryExpression with Serializable with AutoCastInputTypes {
6060
self: Product =>
6161

6262
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
@@ -99,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
9999
* @param name The short name of the function
100100
*/
101101
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
102-
extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
102+
extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
103103

104104
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
105105

@@ -211,19 +211,11 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
211211
}
212212

213213
case class Bin(child: Expression)
214-
extends UnaryExpression with Serializable with ExpectsInputTypes {
215-
216-
val name: String = "BIN"
217-
218-
override def foldable: Boolean = child.foldable
219-
override def nullable: Boolean = true
220-
override def toString: String = s"$name($child)"
214+
extends UnaryExpression with Serializable with AutoCastInputTypes {
221215

222216
override def expectedChildTypes: Seq[DataType] = Seq(LongType)
223217
override def dataType: DataType = StringType
224218

225-
def funcName: String = name.toLowerCase
226-
227219
override def eval(input: InternalRow): Any = {
228220
val evalE = child.eval(input)
229221
if (evalE == null) {
@@ -239,61 +231,13 @@ case class Bin(child: Expression)
239231
}
240232
}
241233

242-
////////////////////////////////////////////////////////////////////////////////////////////////////
243-
////////////////////////////////////////////////////////////////////////////////////////////////////
244-
// Binary math functions
245-
////////////////////////////////////////////////////////////////////////////////////////////////////
246-
////////////////////////////////////////////////////////////////////////////////////////////////////
247-
248-
249-
case class Atan2(left: Expression, right: Expression)
250-
extends BinaryMathExpression(math.atan2, "ATAN2") {
251-
252-
override def eval(input: InternalRow): Any = {
253-
val evalE1 = left.eval(input)
254-
if (evalE1 == null) {
255-
null
256-
} else {
257-
val evalE2 = right.eval(input)
258-
if (evalE2 == null) {
259-
null
260-
} else {
261-
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
262-
val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
263-
evalE2.asInstanceOf[Double] + 0.0)
264-
if (result.isNaN) null else result
265-
}
266-
}
267-
}
268-
269-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
270-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
271-
if (Double.valueOf(${ev.primitive}).isNaN()) {
272-
${ev.isNull} = true;
273-
}
274-
"""
275-
}
276-
}
277-
278-
case class Pow(left: Expression, right: Expression)
279-
extends BinaryMathExpression(math.pow, "POWER") {
280-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
281-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
282-
if (Double.valueOf(${ev.primitive}).isNaN()) {
283-
${ev.isNull} = true;
284-
}
285-
"""
286-
}
287-
}
288234

289235
/**
290236
* If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
291-
* Otherwise if the number is a STRING,
292-
* it converts each character into its hexadecimal representation and returns the resulting STRING.
293-
* Negative numbers would be treated as two's complement.
237+
* Otherwise if the number is a STRING, it converts each character into its hex representation
238+
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
294239
*/
295-
case class Hex(child: Expression)
296-
extends UnaryExpression with Serializable {
240+
case class Hex(child: Expression) extends UnaryExpression with Serializable {
297241

298242
override def dataType: DataType = StringType
299243

@@ -337,7 +281,7 @@ case class Hex(child: Expression)
337281
private def doHex(bytes: Array[Byte], length: Int): UTF8String = {
338282
val value = new Array[Byte](length * 2)
339283
var i = 0
340-
while(i < length) {
284+
while (i < length) {
341285
value(i * 2) = Character.toUpperCase(Character.forDigit(
342286
(bytes(i) & 0xF0) >>> 4, 16)).toByte
343287
value(i * 2 + 1) = Character.toUpperCase(Character.forDigit(
@@ -362,6 +306,54 @@ case class Hex(child: Expression)
362306
}
363307
}
364308

309+
310+
////////////////////////////////////////////////////////////////////////////////////////////////////
311+
////////////////////////////////////////////////////////////////////////////////////////////////////
312+
// Binary math functions
313+
////////////////////////////////////////////////////////////////////////////////////////////////////
314+
////////////////////////////////////////////////////////////////////////////////////////////////////
315+
316+
317+
case class Atan2(left: Expression, right: Expression)
318+
extends BinaryMathExpression(math.atan2, "ATAN2") {
319+
320+
override def eval(input: InternalRow): Any = {
321+
val evalE1 = left.eval(input)
322+
if (evalE1 == null) {
323+
null
324+
} else {
325+
val evalE2 = right.eval(input)
326+
if (evalE2 == null) {
327+
null
328+
} else {
329+
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
330+
val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0,
331+
evalE2.asInstanceOf[Double] + 0.0)
332+
if (result.isNaN) null else result
333+
}
334+
}
335+
}
336+
337+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
338+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
339+
if (Double.valueOf(${ev.primitive}).isNaN()) {
340+
${ev.isNull} = true;
341+
}
342+
"""
343+
}
344+
}
345+
346+
case class Pow(left: Expression, right: Expression)
347+
extends BinaryMathExpression(math.pow, "POWER") {
348+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
349+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
350+
if (Double.valueOf(${ev.primitive}).isNaN()) {
351+
${ev.isNull} = true;
352+
}
353+
"""
354+
}
355+
}
356+
365357
case class Hypot(left: Expression, right: Expression)
366358
extends BinaryMathExpression(math.hypot, "HYPOT")
367359

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String
3131
* For input of type [[BinaryType]]
3232
*/
3333
case class Md5(child: Expression)
34-
extends UnaryExpression with ExpectsInputTypes {
34+
extends UnaryExpression with AutoCastInputTypes {
3535

3636
override def dataType: DataType = StringType
3737

@@ -61,7 +61,7 @@ case class Md5(child: Expression)
6161
* the hash length is not one of the permitted values, the return value is NULL.
6262
*/
6363
case class Sha2(left: Expression, right: Expression)
64-
extends BinaryExpression with Serializable with ExpectsInputTypes {
64+
extends BinaryExpression with Serializable with AutoCastInputTypes {
6565

6666
override def dataType: DataType = StringType
6767

@@ -146,7 +146,7 @@ case class Sha2(left: Expression, right: Expression)
146146
* A function that calculates a sha1 hash value and returns it as a hex string
147147
* For input of type [[BinaryType]] or [[StringType]]
148148
*/
149-
case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
149+
case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes {
150150

151151
override def dataType: DataType = StringType
152152

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ trait PredicateHelper {
7070
}
7171

7272

73-
case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
73+
case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
7474
override def foldable: Boolean = child.foldable
7575
override def nullable: Boolean = child.nullable
7676
override def toString: String = s"NOT $child"
@@ -123,7 +123,7 @@ case class InSet(value: Expression, hset: Set[Any])
123123
}
124124

125125
case class And(left: Expression, right: Expression)
126-
extends BinaryExpression with Predicate with ExpectsInputTypes {
126+
extends BinaryExpression with Predicate with AutoCastInputTypes {
127127

128128
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
129129

@@ -172,7 +172,7 @@ case class And(left: Expression, right: Expression)
172172
}
173173

174174
case class Or(left: Expression, right: Expression)
175-
extends BinaryExpression with Predicate with ExpectsInputTypes {
175+
extends BinaryExpression with Predicate with AutoCastInputTypes {
176176

177177
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
178178

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2424
import org.apache.spark.sql.types._
2525
import org.apache.spark.unsafe.types.UTF8String
2626

27-
trait StringRegexExpression extends ExpectsInputTypes {
27+
trait StringRegexExpression extends AutoCastInputTypes {
2828
self: BinaryExpression =>
2929

3030
def escape(v: String): String
@@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
111111
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
112112
}
113113

114-
trait CaseConversionExpression extends ExpectsInputTypes {
114+
trait CaseConversionExpression extends AutoCastInputTypes {
115115
self: UnaryExpression =>
116116

117117
def convert(v: UTF8String): UTF8String
@@ -158,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
158158
}
159159

160160
/** A base trait for functions that compare two strings, returning a boolean. */
161-
trait StringComparison extends ExpectsInputTypes {
161+
trait StringComparison extends AutoCastInputTypes {
162162
self: BinaryExpression =>
163163

164164
def compare(l: UTF8String, r: UTF8String): Boolean
@@ -221,7 +221,7 @@ case class EndsWith(left: Expression, right: Expression)
221221
* Defined for String and Binary types.
222222
*/
223223
case class Substring(str: Expression, pos: Expression, len: Expression)
224-
extends Expression with ExpectsInputTypes {
224+
extends Expression with AutoCastInputTypes {
225225

226226
def this(str: Expression, pos: Expression) = {
227227
this(str, pos, Literal(Integer.MAX_VALUE))
@@ -295,7 +295,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
295295
/**
296296
* A function that returns the length of the given string expression.
297297
*/
298-
case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes {
298+
case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
299299
override def dataType: DataType = IntegerType
300300
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
301301

0 commit comments

Comments
 (0)