Skip to content

Commit 0edab9c

Browse files
Add ascending/descending support for sort_array
1 parent 80fc0f8 commit 0edab9c

File tree

5 files changed

+86
-22
lines changed

5 files changed

+86
-22
lines changed

python/pyspark/sql/functions.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -904,17 +904,19 @@ def size(col):
904904

905905

906906
@since(1.5)
907-
def sort_array(col):
907+
def sort_array(col, asc=True):
908908
"""
909909
Collection function: sorts the input array for the given column in ascending order.
910910
:param col: name of column or expression
911911
912912
>>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
913-
>>> df.select(sort_array(df.data)).collect()
914-
[Row(sort_array(data)=[1, 2, 3]), Row(sort_array(data)=[1]), Row(sort_array(data)=[])]
915-
"""
913+
>>> df.select(sort_array(df.data).alias('r')).collect()
914+
[Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
915+
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
916+
[Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
917+
"""
916918
sc = SparkContext._active_spark_context
917-
return Column(sc._jvm.functions.sort_array(_to_java_column(col)))
919+
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
918920

919921

920922
class UserDefinedFunction(object):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,30 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
4242
}
4343

4444
/**
45-
* Sorts the input array in ascending order according to the natural ordering of
45+
* Sorts the input array in ascending / descending order according to the natural ordering of
4646
* the array elements and returns it.
4747
*/
48-
case class SortArray(child: Expression)
49-
extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
48+
case class SortArray(base: Expression, ascendingOrder: Expression)
49+
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
5050

51-
override def dataType: DataType = child.dataType
52-
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
51+
def this(e: Expression) = this(e, Literal(true))
5352

54-
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
53+
override def left: Expression = base
54+
override def right: Expression = ascendingOrder
55+
override def dataType: DataType = base.dataType
56+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
57+
58+
override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
5559
case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
56-
case other => TypeCheckResult.TypeCheckFailure(
57-
s"Type $other is not supported for ordering operations")
60+
case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
61+
s"Type $n is not the AtomicType, we can not perform the ordering operations")
62+
case other =>
63+
TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
5864
}
5965

6066
@transient
6167
private lazy val lt: (Any, Any) => Boolean = {
62-
val ordering = child.dataType match {
68+
val ordering = base.dataType match {
6369
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
6470
}
6571

@@ -76,8 +82,27 @@ case class SortArray(child: Expression)
7682
}
7783
}
7884

79-
override def nullSafeEval(value: Any): Seq[Any] = {
80-
value.asInstanceOf[Seq[Any]].sortWith(lt)
85+
@transient
86+
private lazy val gt: (Any, Any) => Boolean = {
87+
val ordering = base.dataType match {
88+
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
89+
}
90+
91+
(left, right) => {
92+
if (left == null && right == null) {
93+
true
94+
} else if (left == null) {
95+
false
96+
} else if (right == null) {
97+
true
98+
} else {
99+
ordering.compare(left, right) > 0
100+
}
101+
}
102+
}
103+
104+
override def nullSafeEval(array: Any, ascending: Any): Seq[Any] = {
105+
array.asInstanceOf[Seq[Any]].sortWith(if (ascending.asInstanceOf[Boolean]) lt else gt)
81106
}
82107

83108
override def prettyName: String = "sort_array"

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
4848
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
4949
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
5050
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
51+
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
5152

52-
checkEvaluation(SortArray(a0), Seq(1, 2, 3))
53-
checkEvaluation(SortArray(a1), Seq[Integer]())
54-
checkEvaluation(SortArray(a2), Seq("a", "b"))
53+
checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
54+
checkEvaluation(new SortArray(a1), Seq[Integer]())
55+
checkEvaluation(new SortArray(a2), Seq("a", "b"))
56+
checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
57+
checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
58+
checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
59+
checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
60+
checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
61+
checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
62+
checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
63+
checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
64+
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
5565

5666
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
5767
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2200,8 +2200,16 @@ object functions {
22002200
* @group collection_funcs
22012201
* @since 1.5.0
22022202
*/
2203-
def sort_array(e: Column): Column = SortArray(e.expr)
2203+
def sort_array(e: Column): Column = sort_array(e, true)
22042204

2205+
/**
2206+
* Sorts the input array for the given column in ascending / descending order,
2207+
* according to the natural ordering of the array elements.
2208+
*
2209+
* @group collection_funcs
2210+
* @since 1.5.0
2211+
*/
2212+
def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)
22052213

22062214
//////////////////////////////////////////////////////////////////////////////////////////////
22072215
//////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,19 +280,38 @@ class DataFrameFunctionsSuite extends QueryTest {
280280
Row(Seq[Int](), Seq[String]()),
281281
Row(null, null))
282282
)
283+
checkAnswer(
284+
df.select(sort_array($"a", false), sort_array($"b", false)),
285+
Seq(
286+
Row(Seq(3, 2, 1), Seq("c", "b", "a")),
287+
Row(Seq[Int](), Seq[String]()),
288+
Row(null, null))
289+
)
283290
checkAnswer(
284291
df.selectExpr("sort_array(a)", "sort_array(b)"),
285292
Seq(
286293
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
287294
Row(Seq[Int](), Seq[String]()),
288295
Row(null, null))
289296
)
297+
checkAnswer(
298+
df.selectExpr("sort_array(a, true)", "sort_array(b, false)"),
299+
Seq(
300+
Row(Seq(1, 2, 3), Seq("c", "b", "a")),
301+
Row(Seq[Int](), Seq[String]()),
302+
Row(null, null))
303+
)
290304

291305
val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
292306
assert(intercept[AnalysisException] {
293307
df2.selectExpr("sort_array(a)").collect()
294-
}.getMessage().contains("Type ArrayType(ArrayType(IntegerType,false),true) " +
295-
"is not supported for ordering operations"))
308+
}.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " +
309+
"we can not perform the ordering operations"))
310+
311+
val df3 = Seq(("xxx", "x")).toDF("a", "b")
312+
assert(intercept[AnalysisException] {
313+
df3.selectExpr("sort_array(a)").collect()
314+
}.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
296315
}
297316

298317
test("array size function") {

0 commit comments

Comments
 (0)