Skip to content

Commit 12c61f4

Browse files
committed
Accept only BinaryType for Md5
1 parent 1df0b5b commit 12c61f4

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,22 @@ import org.apache.spark.unsafe.types.UTF8String
2626

2727
/**
2828
* A function that calculates an MD5 128-bit checksum and returns it as a hex string
29-
* For input of type [[StringType]] or [[BinaryType]]
29+
* For input of type [[BinaryType]]
3030
*/
31-
case class Md5(child: Expression) extends UnaryExpression {
31+
case class Md5(child: Expression)
32+
extends UnaryExpression with ExpectsInputTypes {
3233

3334
override def dataType: DataType = StringType
3435

36+
override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
37+
3538
override def checkInputDataTypes(): TypeCheckResult =
36-
if (child.dataType == StringType || child.dataType == BinaryType) {
39+
if (child.dataType == BinaryType) {
3740
TypeCheckResult.TypeCheckSuccess
3841
} else {
3942
TypeCheckResult.TypeCheckFailure(
4043
s"types error in ${this.getClass.getSimpleName} " +
41-
s"get (${child.dataType}, expect StringType or BinaryType).")
44+
s"get (${child.dataType}, expect BinaryType).")
4245
}
4346

4447
override def children: Seq[Expression] = child :: Nil
@@ -47,10 +50,8 @@ case class Md5(child: Expression) extends UnaryExpression {
4750
val value = child.eval(input)
4851
if (value == null) {
4952
null
50-
} else if (child.dataType == BinaryType) {
53+
} else{
5154
UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]]))
52-
} else {
53-
UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[UTF8String].getBytes))
5455
}
5556
}
5657

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ import org.apache.spark.sql.types.{StringType, BinaryType}
2323
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
2424

2525
test("md5") {
26-
checkEvaluation(Md5(Literal("ABC")), "902fbdd2b1df0c4f70b4a5d23525e932")
26+
checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
2727
checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
2828
"6ac1e56bc78f031059be7be854522c4c")
2929
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
30-
checkEvaluation(Md5(Literal.create(null, StringType)), null)
3130
}
3231

3332
}

0 commit comments

Comments
 (0)