Skip to content

[SPARK-23936][SQL] Implement map_concat #21073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f22d3df
Initial commit
bersprockets Apr 14, 2018
84eeec5
Remove unused variable in test
bersprockets Apr 15, 2018
f6fbbc8
Cleanup
bersprockets Apr 15, 2018
4ed8627
Checkpoint non-working codegen
bersprockets Apr 17, 2018
aaee5b8
Checkpoint somewhat working codegen
bersprockets Apr 17, 2018
e08362a
Checkpoint better working codegen
bersprockets Apr 17, 2018
2032801
Require at least two input maps
bersprockets Apr 17, 2018
e149d06
Small cleanup
bersprockets Apr 17, 2018
e4170cf
Remove redundant null check
bersprockets Apr 17, 2018
b3085f0
Any null input means null result (ala Presto)
bersprockets Apr 19, 2018
71f0151
Remove redundant null check
bersprockets Apr 19, 2018
006835d
Review feedback
bersprockets Apr 24, 2018
cf64d83
Check for null value in generated code
bersprockets Apr 27, 2018
fbe00b2
Add since to expression description
bersprockets Apr 27, 2018
83784cc
Allow valueContainsNull to vary; Make checkInputDataTypes more in lin…
bersprockets Apr 27, 2018
79f9304
Add a few more tests
bersprockets Apr 27, 2018
cb0f57f
One more test
bersprockets Apr 27, 2018
57d10cb
Fix import statement; Add two small tests
bersprockets Apr 30, 2018
cda1158
As per SPARK-9415, cannot compare Maps, therefore cannot support them…
bersprockets May 3, 2018
83deda4
Updates for some review comments
bersprockets May 6, 2018
370151e
Initial commit of use of splitExpressionsWithCurrentInputs
bersprockets May 7, 2018
3305827
Add test
bersprockets May 7, 2018
f967483
Fix indentation
bersprockets May 7, 2018
d437199
Fix after rebase
bersprockets May 24, 2018
206db97
Review feedback: use pre-existing empty collections
bersprockets May 31, 2018
549300f
Allow duplicate keys
bersprockets Jun 20, 2018
1b52dd1
Remove extra line added during rebase
bersprockets Jun 21, 2018
969c66e
Review comments
bersprockets Jun 24, 2018
47f0cf5
Review comments
bersprockets Jun 26, 2018
3c0da03
Initial implementation of type coercion for map_concat
bersprockets Jun 30, 2018
03328a4
Simplify type coercion for map_concat parameters
bersprockets Jul 6, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,28 @@ def arrays_zip(*cols):
return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))


@since(2.4)
def map_concat(*cols):
"""Returns the union of all the given maps.

:param cols: list of column names (string) or list of :class:`Column` expressions

>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+--------------------------------+
|map3 |
+--------------------------------+
|[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
+--------------------------------+
"""
sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]
jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column))
return Column(jc)


# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,12 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case (keys: Array[_], values: Array[_]) =>
// case for mapdata with duplicate keys
new ArrayBasedMapData(
new GenericArrayData(keys.map(convertToCatalyst)),
new GenericArrayData(values.map(convertToCatalyst))
)
case other => other
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
expression[MapConcat]("map_concat"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,14 @@ object TypeCoercion {
case None => s
}

case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
case None => m
}

case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def prettyName: String = "map_entries"
}

/**
* Returns the union of all the given maps.
*/
@ExpressionDescription(
usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
[[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
""", since = "2.4.0")
case class MapConcat(children: Seq[Expression]) extends Expression {

override def checkInputDataTypes(): TypeCheckResult = {
var funcName = s"function $prettyName"
if (children.exists(!_.dataType.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure(
s"input to $funcName should all be of type map, but it's " +
children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
}
}

override def dataType: MapType = {
val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
.getOrElse(MapType(StringType, StringType))
val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
.exists(_.valueContainsNull)
if (dt.valueContainsNull != valueContainsNull) {
dt.copy(valueContainsNull = valueContainsNull)
} else {
dt
}
}

override def nullable: Boolean = children.exists(_.nullable)

override def eval(input: InternalRow): Any = {
val maps = children.map(_.eval(input))
if (maps.contains(null)) {
return null
}
val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())

val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
s"elements due to exceeding the map size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
val finalKeyArray = new Array[AnyRef](numElements.toInt)
val finalValueArray = new Array[AnyRef](numElements.toInt)
var position = 0
for (i <- keyArrayDatas.indices) {
val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
position += keyArray.length
}

new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
new GenericArrayData(finalValueArray))
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mapCodes = children.map(_.genCode(ctx))
val keyType = dataType.keyType
val valueType = dataType.valueType
val argsName = ctx.freshName("args")
val hasNullName = ctx.freshName("hasNull")
val mapDataClass = classOf[MapData].getName
val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
val arrayDataClass = classOf[ArrayData].getName

val init =
s"""
|$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
|boolean ${ev.isNull}, $hasNullName = false;
|$mapDataClass ${ev.value} = null;
""".stripMargin

val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
| if (${m.isNull}) {
| $hasNullName = true;
| }
|}
""".stripMargin
}

val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = assignments,
funcName = "getMapConcatInputs",
extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil,
returnType = "boolean",
makeSplitFunction = body =>
s"""
|$body
|return $hasNullName;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n")
)

val idxName = ctx.freshName("idx")
val numElementsName = ctx.freshName("numElems")
val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues")

val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
genCodeForPrimitiveArrays(ctx, keyType, false)
} else {
genCodeForNonPrimitiveArrays(ctx, keyType)
}

val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}

val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs")

val mapMerge =
s"""
|${ev.isNull} = $hasNullName;
|if (!${ev.isNull}) {
| $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
| $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
| long $numElementsName = 0;
| for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
| $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
| $valArgsName[$idxName] = $argsName[$idxName].valueArray();
| $numElementsName += $argsName[$idxName].numElements();
| }
| if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful attempt to concat maps with " +
| $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
| }
| $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
| (int) $numElementsName);
| $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
| (int) $numElementsName);
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|}
""".stripMargin

ev.copy(
code = code"""
|$init
|$codes
|$mapMerge
""".stripMargin)
}

private def genCodeForPrimitiveArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

val setterCode1 =
s"""
|$arrayData.set$primitiveValueTypeName(
| $counter,
| ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
|);""".stripMargin

val setterCode = if (checkForNull) {
s"""
|if ($argsName[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
|} else {
| $setterCode1
|}""".stripMargin
} else {
setterCode1
}

s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $setterCode
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
}

private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayData = ctx.freshName("arrayObjects")
val counter = ctx.freshName("counter")
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")

s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {;
| Object[] $arrayData = new Object[$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
}

override def prettyName: String = "map_concat"
}

/**
* Returns a map created from the given array of entries.
*/
Expand Down
Loading