diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index a50dad7c8cdb8..00abdf4ee19d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveS import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, DataType, StringType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType} object CollationTypeCasts extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { @@ -85,6 +85,11 @@ object CollationTypeCasts extends TypeCoercionRule { private def extractStringType(dt: DataType): StringType = dt match { case st: StringType => st case ArrayType(et, _) => extractStringType(et) + case MapType(kt, vt, _) => if (hasStringType(kt)) { + extractStringType(kt) + } else { + extractStringType(vt) + } } /** @@ -102,6 +107,14 @@ object CollationTypeCasts extends TypeCoercionRule { case st: StringType if st.collationId != castType.collationId => castType case ArrayType(arrType, nullable) => castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull + case MapType(keyType, valueType, nullable) => + val newKeyType = castStringType(keyType, castType).getOrElse(keyType) + val newValueType = castStringType(valueType, castType).getOrElse(valueType) + if (newKeyType != keyType || newValueType != valueType) { + MapType(newKeyType, newValueType, nullable) + } else { + null + } case _ => null } Option(ret) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 936bb22baa467..7866f47c28b13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractStringType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence @@ -1048,6 +1048,15 @@ object TypeCoercion extends TypeCoercionBase { } } + case (MapType(fromKeyType, fromValueType, fn), AbstractMapType(toKeyType, toValueType)) => + val newKeyType = implicitCast(fromKeyType, toKeyType).orNull + val newValueType = implicitCast(fromValueType, toValueType).orNull + if (newKeyType != null && newValueType != null) { + MapType(newKeyType, newValueType, fn) + } else { + null + } + case _ => null } Option(ret) @@ -1080,10 +1089,10 @@ object TypeCoercion extends TypeCoercionBase { /** * Whether the data type contains StringType. */ - @tailrec def hasStringType(dt: DataType): Boolean = dt match { case _: StringType => true case ArrayType(et, _) => hasStringType(et) + case MapType(kt, vt, _) => hasStringType(kt) || hasStringType(vt) // Add StructType if we support string promotion for struct fields in the future. case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e9fa362de14cd..d9d7cd2cd0c1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + Seq(StringTypeAnyCollation, AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation)) override def left: Expression = errorClass override def right: Expression = errorParms diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 828245bb3fdd6..f3d07ba47b715 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -21,8 +21,8 @@ import java.text.SimpleDateFormat import scala.collection.immutable.Seq -import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException} -import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} +import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -1636,3 +1636,9 @@ class CollationSQLExpressionsSuite } // scalastyle:on nonascii + +class CollationSQLExpressionsANSIOffSuite extends CollationSQLExpressionsSuite { + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.ANSI_ENABLED, false) + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b22a762a29547..657fd4504cac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { @@ -954,10 +955,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", parameters = Map( "functionName" -> "`=`", - "dataType" -> toSQLType(MapType( - StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")), - StringType - )), + "dataType" -> toSQLType(AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation)), "sqlExpr" -> "\"(m = m)\""), context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1)) } @@ -1010,25 +1008,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |select map('a' collate utf8_binary_lcase, 1, 'b' collate utf8_binary_lcase, 2) |['A' collate utf8_binary_lcase] |""".stripMargin), Seq(Row(1))) - val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']" - val query = s"select $ctx" - checkError( - exception = intercept[AnalysisException](sql(query)), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"", - "paramIndex" -> "second", - "inputSql" -> "\"AaA\"", - "inputType" -> toSQLType(StringType), - "requiredType" -> toSQLType(StringType( - CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) - ), - context = ExpectedContext( - fragment = ctx, - start = query.length - ctx.length, - stop = query.length - 1 - ) - ) } test("window aggregates should respect collation") {