Skip to content

Commit 6931022

Browse files
mgaido91gatorsmile
authored andcommitted
[SPARK-23917][SQL] Add array_max function
## What changes were proposed in this pull request? The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it. ## How was this patch tested? added UTs Author: Marco Gaido <marcogaido91@gmail.com> Closes #21024 from mgaido91/SPARK-23917.
1 parent c096493 commit 6931022

File tree

8 files changed

+133
-6
lines changed

8 files changed

+133
-6
lines changed

python/pyspark/sql/functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,21 @@ def size(col):
20802080
return Column(sc._jvm.functions.size(_to_java_column(col)))
20812081

20822082

2083+
@since(2.4)
2084+
def array_max(col):
2085+
"""
2086+
Collection function: returns the maximum value of the array.
2087+
2088+
:param col: name of column or expression
2089+
2090+
>>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
2091+
>>> df.select(array_max(df.data).alias('max')).collect()
2092+
[Row(max=3), Row(max=10)]
2093+
"""
2094+
sc = SparkContext._active_spark_context
2095+
return Column(sc._jvm.functions.array_max(_to_java_column(col)))
2096+
2097+
20832098
@since(1.5)
20842099
def sort_array(col, asc=True):
20852100
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ object FunctionRegistry {
409409
expression[MapValues]("map_values"),
410410
expression[Size]("size"),
411411
expression[SortArray]("sort_array"),
412+
expression[ArrayMax]("array_max"),
412413
CreateStruct.registryEntry,
413414

414415
// misc functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
674674
val evals = evalChildren.map(eval =>
675675
s"""
676676
|${eval.code}
677-
|if (!${eval.isNull} && (${ev.isNull} ||
678-
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
679-
| ${ev.isNull} = false;
680-
| ${ev.value} = ${eval.value};
681-
|}
677+
|${ctx.reassignIfGreater(dataType, ev, eval)}
682678
""".stripMargin
683679
)
684680

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,23 @@ class CodegenContext {
699699
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
700700
}
701701

702+
/**
703+
* Generates code for updating `partialResult` if `item` is greater than it.
704+
*
705+
* @param dataType data type of the expressions
706+
* @param partialResult `ExprCode` representing the partial result which has to be updated
707+
* @param item `ExprCode` representing the new expression to evaluate for the result
708+
*/
709+
def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
710+
s"""
711+
|if (!${item.isNull} && (${partialResult.isNull} ||
712+
| ${genGreater(dataType, item.value, partialResult.value)})) {
713+
| ${partialResult.isNull} = false;
714+
| ${partialResult.value} = ${item.value};
715+
|}
716+
""".stripMargin
717+
}
718+
702719
/**
703720
* Generates code to do null safe execution, i.e. only execute the code when the input is not
704721
* null by adding null check if necessary.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.util.Comparator
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.expressions.codegen._
24-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
24+
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
2525
import org.apache.spark.sql.types._
2626

2727
/**
@@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression)
287287

288288
override def prettyName: String = "array_contains"
289289
}
290+
291+
292+
/**
293+
* Returns the maximum value in the array.
294+
*/
295+
@ExpressionDescription(
296+
usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
297+
examples = """
298+
Examples:
299+
> SELECT _FUNC_(array(1, 20, null, 3));
300+
20
301+
""", since = "2.4.0")
302+
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
303+
304+
override def nullable: Boolean = true
305+
306+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
307+
308+
private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
309+
310+
override def checkInputDataTypes(): TypeCheckResult = {
311+
val typeCheckResult = super.checkInputDataTypes()
312+
if (typeCheckResult.isSuccess) {
313+
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
314+
} else {
315+
typeCheckResult
316+
}
317+
}
318+
319+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
320+
val childGen = child.genCode(ctx)
321+
val javaType = CodeGenerator.javaType(dataType)
322+
val i = ctx.freshName("i")
323+
val item = ExprCode("",
324+
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
325+
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
326+
ev.copy(code =
327+
s"""
328+
|${childGen.code}
329+
|boolean ${ev.isNull} = true;
330+
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
331+
|if (!${childGen.isNull}) {
332+
| for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
333+
| ${ctx.reassignIfGreater(dataType, ev, item)}
334+
| }
335+
|}
336+
""".stripMargin)
337+
}
338+
339+
override protected def nullSafeEval(input: Any): Any = {
340+
var max: Any = null
341+
input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
342+
if (item != null && (max == null || ordering.gt(item, max))) {
343+
max = item
344+
}
345+
)
346+
max
347+
}
348+
349+
override def dataType: DataType = child.dataType match {
350+
case ArrayType(dt, _) => dt
351+
case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
352+
}
353+
354+
override def prettyName: String = "array_max"
355+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
105105
checkEvaluation(ArrayContains(a3, Literal("")), null)
106106
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
107107
}
108+
109+
test("Array max") {
110+
checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
111+
checkEvaluation(
112+
ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
113+
checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
114+
checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
115+
checkEvaluation(
116+
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
117+
}
108118
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,6 +3300,14 @@ object functions {
33003300
*/
33013301
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
33023302

3303+
/**
3304+
* Returns the maximum value in the array.
3305+
*
3306+
* @group collection_funcs
3307+
* @since 2.4.0
3308+
*/
3309+
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
3310+
33033311
/**
33043312
* Returns an unordered array containing the keys of the map.
33053313
* @group collection_funcs

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
413413
)
414414
}
415415

416+
test("array_max function") {
417+
val df = Seq(
418+
Seq[Option[Int]](Some(1), Some(3), Some(2)),
419+
Seq.empty[Option[Int]],
420+
Seq[Option[Int]](None),
421+
Seq[Option[Int]](None, Some(1), Some(-100))
422+
).toDF("a")
423+
424+
val answer = Seq(Row(3), Row(null), Row(null), Row(1))
425+
426+
checkAnswer(df.select(array_max(df("a"))), answer)
427+
checkAnswer(df.selectExpr("array_max(a)"), answer)
428+
}
429+
416430
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
417431
import DataFrameFunctionsSuite.CodegenFallbackExpr
418432
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)