Skip to content

Commit d03e0af

Browse files
dilipbiswal authored and cloud-fan committed
[SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function
## What changes were proposed in this pull request?

In ElementAt, when the first argument is MapType, we should coerce the key type and the second argument based on findTightestCommonType. This is not happening currently. We may produce wrong output as we will incorrectly downcast the right hand side double expression to int.

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2.2);
two
```

Also, when the first argument is ArrayType, the second argument should be an integer type or a smaller integral type that can be safely cast to an integer type. Currently we may do an unsafe cast. In the following case, we should fail with an error as 1.24 is not an integer index. But instead we currently downcast it to int and return a result.

```SQL
spark-sql> select element_at(array(1,2), 1.24D);
1
```

This PR also supports implicit cast between two MapTypes. I have followed similar logic that exists today to do implicit casts between two array types.

## How was this patch tested?

Added new tests in DataFrameFunctionsSuite, TypeCoercionSuite.

Closes #22544 from dilipbiswal/SPARK-25522.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ff87613 commit d03e0af

File tree

5 files changed

+154
-22
lines changed

5 files changed

+154
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,25 @@ object TypeCoercion {
950950
if !Cast.forceNullable(fromType, toType) =>
951951
implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
952952

953+
// Implicit cast between Map types.
954+
// Follows the same semantics of implicit casting between two array types.
955+
// Refer to documentation above. Make sure that both key and values
956+
// can not be null after the implicit cast operation by calling forceNullable
957+
// method.
958+
case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn))
959+
if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) =>
960+
if (Cast.forceNullable(fromValueType, toValueType) && !tn) {
961+
null
962+
} else {
963+
val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
964+
val newValueType = implicitCast(fromValueType, toValueType).orNull
965+
if (newKeyType != null && newValueType != null) {
966+
MapType(newKeyType, newValueType, tn)
967+
} else {
968+
null
969+
}
970+
}
971+
953972
case _ => null
954973
}
955974
Option(ret)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ object Cast {
183183
case _ => false
184184
}
185185

186-
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
186+
def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
187187
}
188188

189189
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
21542154
}
21552155

21562156
override def inputTypes: Seq[AbstractDataType] = {
2157-
Seq(TypeCollection(ArrayType, MapType),
2158-
left.dataType match {
2159-
case _: ArrayType => IntegerType
2160-
case _: MapType => mapKeyType
2161-
case _ => AnyDataType // no match for a wrong 'left' expression type
2162-
}
2163-
)
2157+
(left.dataType, right.dataType) match {
2158+
case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
2159+
Seq(arr, IntegerType)
2160+
case (MapType(keyType, valueType, hasNull), e2) =>
2161+
TypeCoercion.findTightestCommonType(keyType, e2) match {
2162+
case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
2163+
case _ => Seq.empty
2164+
}
2165+
case (l, r) => Seq.empty
2166+
2167+
}
21642168
}
21652169

21662170
override def checkInputDataTypes(): TypeCheckResult = {
2167-
super.checkInputDataTypes() match {
2168-
case f: TypeCheckResult.TypeCheckFailure => f
2169-
case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
2170-
TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
2171-
case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
2171+
(left.dataType, right.dataType) match {
2172+
case (_: ArrayType, e2) if e2 != IntegerType =>
2173+
TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
2174+
s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
2175+
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
2176+
case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
2177+
TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
2178+
s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
2179+
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
2180+
case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
2181+
TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
2182+
s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " +
2183+
s"${left.dataType.catalogString} type.")
2184+
case _ => TypeCheckResult.TypeCheckSuccess
21722185
}
21732186
}
21742187

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest {
257257
shouldNotCast(checkedType, IntegralType)
258258
}
259259

260-
test("implicit type cast - MapType(StringType, StringType)") {
261-
val checkedType = MapType(StringType, StringType)
262-
checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
263-
shouldNotCast(checkedType, DecimalType)
264-
shouldNotCast(checkedType, NumericType)
265-
shouldNotCast(checkedType, IntegralType)
260+
test("implicit type cast between two Map types") {
261+
val sourceType = MapType(IntegerType, IntegerType, true)
262+
val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _))
263+
val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t =>
264+
MapType(t, sourceType.valueType, valueContainsNull = true)
265+
}
266+
val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t =>
267+
MapType(t, sourceType.valueType, valueContainsNull = true)
268+
}
269+
270+
// Tests that it's possible to set up implicit casts between two map types when
271+
// source map's key type is integer and the target map's key type is either Byte, Short,
272+
// Long, Double, Float, Decimal(38, 18) or String.
273+
targetTypes.foreach { targetType =>
274+
shouldCast(sourceType, targetType, targetType)
275+
}
276+
277+
// Tests that it's not possible to set up implicit casts between two map types when
278+
// source map's key type is integer and the target map's key type is either Binary,
279+
// Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType.
280+
nonCastableTargetTypes.foreach { targetType =>
281+
shouldNotCast(sourceType, targetType)
282+
}
283+
284+
// Tests that it's not possible to cast from nullable map type to not nullable map type.
285+
val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t =>
286+
MapType(t, sourceType.valueType, valueContainsNull = false)
287+
}
288+
val sourceMapExprWithValueNull =
289+
CreateMap(Seq(Literal.default(sourceType.keyType),
290+
Literal.create(null, sourceType.valueType)))
291+
targetNotNullableTypes.foreach { targetType =>
292+
val castDefault =
293+
TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType)
294+
assert(castDefault.isEmpty,
295+
s"Should not be able to cast $sourceType to $targetType, but got $castDefault")
296+
}
266297
}
267298

268299
test("implicit type cast - StructType().add(\"a1\", StringType)") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
12111211
Seq(Row("3"), Row(""), Row(null))
12121212
)
12131213

1214-
val e = intercept[AnalysisException] {
1214+
val e1 = intercept[AnalysisException] {
12151215
Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)")
12161216
}
1217-
assert(e.message.contains(
1218-
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
1217+
val errorMsg1 =
1218+
s"""
1219+
|The first argument to function element_at should have been array or map type, but
1220+
|its string type.
1221+
""".stripMargin.replace("\n", " ").trim()
1222+
assert(e1.message.contains(errorMsg1))
1223+
1224+
checkAnswer(
1225+
OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"),
1226+
Seq(Row(1))
1227+
)
1228+
1229+
checkAnswer(
1230+
OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"),
1231+
Seq(Row("a"))
1232+
)
1233+
1234+
checkAnswer(
1235+
OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"),
1236+
Seq(Row(3))
1237+
)
1238+
1239+
val e2 = intercept[AnalysisException] {
1240+
OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)")
1241+
}
1242+
val errorMsg2 =
1243+
s"""
1244+
|Input to function element_at should have been array followed by a int, but it's
1245+
|[array<string>, bigint].
1246+
""".stripMargin.replace("\n", " ").trim()
1247+
assert(e2.message.contains(errorMsg2))
1248+
1249+
checkAnswer(
1250+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"),
1251+
Seq(Row("b"))
1252+
)
1253+
1254+
checkAnswer(
1255+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"),
1256+
Seq(Row("a"))
1257+
)
1258+
1259+
checkAnswer(
1260+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"),
1261+
Seq(Row("b"))
1262+
)
1263+
1264+
checkAnswer(
1265+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"),
1266+
Seq(Row("b"))
1267+
)
1268+
1269+
checkAnswer(
1270+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"),
1271+
Seq(Row("a"))
1272+
)
1273+
1274+
checkAnswer(
1275+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
1276+
Seq(Row(null))
1277+
)
1278+
1279+
val e3 = intercept[AnalysisException] {
1280+
OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')")
1281+
}
1282+
val errorMsg3 =
1283+
s"""
1284+
|Input to function element_at should have been map followed by a value of same
1285+
|key type, but it's [map<int,string>, string].
1286+
""".stripMargin.replace("\n", " ").trim()
1287+
assert(e3.message.contains(errorMsg3))
12191288
}
12201289

12211290
test("array_union functions") {

0 commit comments

Comments
 (0)