Skip to content

Commit

Permalink
[SPARK-48186][SQL] Add support for AbstractMapType
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Addition of an abstract MapType (similar to abstract ArrayType in sql internal types) which accepts `StringTypeCollated` as `keyType` & `valueType`. Apart from extending this interface for all Spark functions, this PR also introduces collation awareness for json expression: schema_of_json.

### Why are the changes needed?
This is needed in order to enable collation support for functions that use collated maps.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for json function: schema_of_json.

### How was this patch tested?
E2e sql tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46458 from uros-db/abstract-map.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and cloud-fan committed May 9, 2024
1 parent 6cc3dc2 commit a4ab82b
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal.types

import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType}


/**
* Use AbstractMapType(AbstractDataType, AbstractDataType)
* for defining expected types for expression parameters.
*/
case class AbstractMapType(
keyType: AbstractDataType,
valueType: AbstractDataType
) extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType =
MapType(keyType.defaultConcreteType, valueType.defaultConcreteType, valueContainsNull = true)

override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType] &&
keyType.acceptsType(other.asInstanceOf[MapType].keyType) &&
valueType.acceptsType(other.asInstanceOf[MapType].valueType)
}

override private[spark] def simpleString: String =
s"map<${keyType.simpleString}, ${valueType.simpleString}>"
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -57,7 +58,7 @@ object ExprUtils extends QueryErrorsBase {

def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) =>
if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
key.toString -> value.toString
Expand All @@ -77,7 +78,7 @@ object ExprUtils extends QueryErrorsBase {
columnNameOfCorruptRecord: String): Unit = {
schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
val f = schema(corruptFieldIndex)
if (f.dataType != StringType || !f.nullable) {
if (!f.dataType.isInstanceOf[StringType] || !f.nullable) {
throw QueryCompilationErrors.invalidFieldTypeForCorruptRecordError()
}
}
Expand Down Expand Up @@ -110,7 +111,7 @@ object ExprUtils extends QueryErrorsBase {
*/
def checkJsonSchema(schema: DataType): TypeCheckResult = {
val isInvalid = schema.existsRecursively {
case MapType(keyType, _, _) if keyType != StringType => true
case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true
case _ => false
}
if (isInvalid) {
Expand All @@ -133,7 +134,7 @@ object ExprUtils extends QueryErrorsBase {
def checkXmlSchema(schema: DataType): TypeCheckResult = {
val isInvalid = schema.existsRecursively {
// XML field names must be StringType
case MapType(keyType, _, _) if keyType != StringType => true
case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true
case _ => false
}
if (isInvalid) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ case class SchemaOfJson(
child = child,
options = ExprUtils.convertToMapData(options))

override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType

override def nullable: Boolean = false

Expand Down Expand Up @@ -921,7 +921,8 @@ case class SchemaOfJson(
.map(ArrayType(_, containsNull = at.containsNull))
.getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
case other: DataType =>
jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(StringType)
jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(
SQLConf.get.defaultStringType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,40 @@ class CollationSQLExpressionsSuite
})
}

test("Support SchemaOfJson json expression with collation") {
case class SchemaOfJsonTestCase(
input: String,
collationName: String,
result: Row
)

val testCases = Seq(
SchemaOfJsonTestCase("'[{\"col\":0}]'",
"UTF8_BINARY", Row("ARRAY<STRUCT<col: BIGINT>>")),
SchemaOfJsonTestCase("'[{\"col\":01}]', map('allowNumericLeadingZeros', 'true')",
"UTF8_BINARY_LCASE", Row("ARRAY<STRUCT<col: BIGINT>>")),
SchemaOfJsonTestCase("'[]'",
"UNICODE", Row("ARRAY<STRING>")),
SchemaOfJsonTestCase("''",
"UNICODE_CI", Row("STRING"))
)

// Supported collations
testCases.foreach(t => {
val query =
s"""
|SELECT schema_of_json(${t.input})
|""".stripMargin
// Result & data type
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
val testQuery = sql(query)
checkAnswer(testQuery, t.result)
val dataType = StringType(t.collationName)
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
}
})
}

test("Support StringToMap expression with collation") {
// Supported collations
case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R)
Expand Down

0 comments on commit a4ab82b

Please sign in to comment.