Skip to content

Commit cf6c9ca

Browse files
chenghao-inteldavies
authored andcommitted
[SPARK-8232] [SQL] Add sort_array support
This PR is based on #7581 , just fix the conflict. Author: Cheng Hao <hao.cheng@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7851 from davies/sort_array and squashes the following commits: a80ef66 [Davies Liu] fix conflict 7cfda65 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sort_array 664c960 [Cheng Hao] update the sort_array by using the ArrayData 276d2d5 [Cheng Hao] add empty line 0edab9c [Cheng Hao] Add asending/descending support for sort_array 80fc0f8 [Cheng Hao] Add type checking a42b678 [Cheng Hao] Add sort_array support
1 parent 8765665 commit cf6c9ca

File tree

6 files changed

+187
-7
lines changed

6 files changed

+187
-7
lines changed

python/pyspark/sql/functions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
'sha1',
5252
'sha2',
5353
'size',
54+
'sort_array',
5455
'sparkPartitionId',
5556
'struct',
5657
'udf',
@@ -570,8 +571,10 @@ def length(col):
570571
def format_number(col, d):
571572
"""Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
572573
and returns the result as a string.
574+
573575
:param col: the column name of the numeric value to be formatted
574576
:param d: the N decimal places
577+
575578
>>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
576579
[Row(v=u'5.0000')]
577580
"""
@@ -954,6 +957,23 @@ def size(col):
954957
return Column(sc._jvm.functions.size(_to_java_column(col)))
955958

956959

960+
@since(1.5)
961+
def sort_array(col, asc=True):
962+
"""
963+
Collection function: sorts the input array for the given column in ascending order.
964+
965+
:param col: name of column or expression
966+
967+
>>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
968+
>>> df.select(sort_array(df.data).alias('r')).collect()
969+
[Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
970+
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
971+
[Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
972+
"""
973+
sc = SparkContext._active_spark_context
974+
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
975+
976+
957977
@since
958978
@ignore_unicode_prefix
959979
def soundex(col):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ object FunctionRegistry {
233233

234234
// collection functions
235235
expression[Size]("size"),
236+
expression[SortArray]("sort_array"),
236237

237238
// misc functions
238239
expression[Crc32]("crc32"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.expressions
1818

19-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
19+
import java.util.Comparator
20+
21+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
2023
import org.apache.spark.sql.types._
2124

2225
/**
@@ -35,3 +38,79 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
3538
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).numElements();")
3639
}
3740
}
41+
42+
/**
43+
* Sorts the input array in ascending / descending order according to the natural ordering of
44+
* the array elements and returns it.
45+
*/
46+
case class SortArray(base: Expression, ascendingOrder: Expression)
47+
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
48+
49+
def this(e: Expression) = this(e, Literal(true))
50+
51+
override def left: Expression = base
52+
override def right: Expression = ascendingOrder
53+
override def dataType: DataType = base.dataType
54+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
55+
56+
override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
57+
case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
58+
case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
59+
s"Type $n is not the AtomicType, we can not perform the ordering operations")
60+
case other =>
61+
TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
62+
}
63+
64+
@transient
65+
private lazy val lt = {
66+
val ordering = base.dataType match {
67+
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
68+
}
69+
70+
new Comparator[Any]() {
71+
override def compare(o1: Any, o2: Any): Int = {
72+
if (o1 == null && o2 == null) {
73+
0
74+
} else if (o1 == null) {
75+
-1
76+
} else if (o2 == null) {
77+
1
78+
} else {
79+
ordering.compare(o1, o2)
80+
}
81+
}
82+
}
83+
}
84+
85+
@transient
86+
private lazy val gt = {
87+
val ordering = base.dataType match {
88+
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
89+
}
90+
91+
new Comparator[Any]() {
92+
override def compare(o1: Any, o2: Any): Int = {
93+
if (o1 == null && o2 == null) {
94+
0
95+
} else if (o1 == null) {
96+
1
97+
} else if (o2 == null) {
98+
-1
99+
} else {
100+
-ordering.compare(o1, o2)
101+
}
102+
}
103+
}
104+
}
105+
106+
override def nullSafeEval(array: Any, ascending: Any): Any = {
107+
val elementType = base.dataType.asInstanceOf[ArrayType].elementType
108+
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
109+
java.util.Arrays.sort(
110+
data,
111+
if (ascending.asInstanceOf[Boolean]) lt else gt)
112+
new GenericArrayData(data.asInstanceOf[Array[Any]])
113+
}
114+
115+
override def prettyName: String = "sort_array"
116+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,26 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
4343
checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
4444
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
4545
}
46+
47+
test("Sort Array") {
48+
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
49+
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
50+
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
51+
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
52+
53+
checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
54+
checkEvaluation(new SortArray(a1), Seq[Integer]())
55+
checkEvaluation(new SortArray(a2), Seq("a", "b"))
56+
checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
57+
checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
58+
checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
59+
checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
60+
checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
61+
checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
62+
checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
63+
checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
64+
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
65+
66+
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
67+
}
4668
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,19 +2223,30 @@ object functions {
22232223
//////////////////////////////////////////////////////////////////////////////////////////////
22242224

22252225
/**
2226-
* Returns length of array or map
2226+
* Returns length of array or map.
2227+
*
22272228
* @group collection_funcs
22282229
* @since 1.5.0
22292230
*/
2230-
def size(columnName: String): Column = size(Column(columnName))
2231+
def size(e: Column): Column = Size(e.expr)
22312232

22322233
/**
2233-
* Returns length of array or map
2234+
* Sorts the input array for the given column in ascending order,
2235+
* according to the natural ordering of the array elements.
2236+
*
22342237
* @group collection_funcs
22352238
* @since 1.5.0
22362239
*/
2237-
def size(column: Column): Column = Size(column.expr)
2240+
def sort_array(e: Column): Column = sort_array(e, true)
22382241

2242+
/**
2243+
* Sorts the input array for the given column in ascending / descending order,
2244+
* according to the natural ordering of the array elements.
2245+
*
2246+
* @group collection_funcs
2247+
* @since 1.5.0
2248+
*/
2249+
def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)
22392250

22402251
//////////////////////////////////////////////////////////////////////////////////////////////
22412252
//////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,61 @@ class DataFrameFunctionsSuite extends QueryTest {
267267
)
268268
}
269269

270+
test("sort_array function") {
271+
val df = Seq(
272+
(Array[Int](2, 1, 3), Array("b", "c", "a")),
273+
(Array[Int](), Array[String]()),
274+
(null, null)
275+
).toDF("a", "b")
276+
checkAnswer(
277+
df.select(sort_array($"a"), sort_array($"b")),
278+
Seq(
279+
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
280+
Row(Seq[Int](), Seq[String]()),
281+
Row(null, null))
282+
)
283+
checkAnswer(
284+
df.select(sort_array($"a", false), sort_array($"b", false)),
285+
Seq(
286+
Row(Seq(3, 2, 1), Seq("c", "b", "a")),
287+
Row(Seq[Int](), Seq[String]()),
288+
Row(null, null))
289+
)
290+
checkAnswer(
291+
df.selectExpr("sort_array(a)", "sort_array(b)"),
292+
Seq(
293+
Row(Seq(1, 2, 3), Seq("a", "b", "c")),
294+
Row(Seq[Int](), Seq[String]()),
295+
Row(null, null))
296+
)
297+
checkAnswer(
298+
df.selectExpr("sort_array(a, true)", "sort_array(b, false)"),
299+
Seq(
300+
Row(Seq(1, 2, 3), Seq("c", "b", "a")),
301+
Row(Seq[Int](), Seq[String]()),
302+
Row(null, null))
303+
)
304+
305+
val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
306+
assert(intercept[AnalysisException] {
307+
df2.selectExpr("sort_array(a)").collect()
308+
}.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " +
309+
"we can not perform the ordering operations"))
310+
311+
val df3 = Seq(("xxx", "x")).toDF("a", "b")
312+
assert(intercept[AnalysisException] {
313+
df3.selectExpr("sort_array(a)").collect()
314+
}.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
315+
}
316+
270317
test("array size function") {
271318
val df = Seq(
272319
(Array[Int](1, 2), "x"),
273320
(Array[Int](), "y"),
274321
(Array[Int](1, 2, 3), "z")
275322
).toDF("a", "b")
276323
checkAnswer(
277-
df.select(size("a")),
324+
df.select(size($"a")),
278325
Seq(Row(2), Row(0), Row(3))
279326
)
280327
checkAnswer(
@@ -290,7 +337,7 @@ class DataFrameFunctionsSuite extends QueryTest {
290337
(Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
291338
).toDF("a", "b")
292339
checkAnswer(
293-
df.select(size("a")),
340+
df.select(size($"a")),
294341
Seq(Row(2), Row(0), Row(3))
295342
)
296343
checkAnswer(

0 commit comments

Comments
 (0)