
Commit b515768

rxin authored and hvanhovell committed
[SPARK-17844] Simplify DataFrame API for defining frame boundaries in window functions
## What changes were proposed in this pull request?

When I was creating the example code for SPARK-10496, I realized it was pretty convoluted to define the frame boundaries for window functions when there is no partition column or ordering column. The reason is that we don't provide a way to create a WindowSpec directly with the frame boundaries. We can trivially improve this by adding rowsBetween and rangeBetween to the Window object.

As an example, to compute a cumulative sum using the natural ordering, before this PR:
```
df.select('key, sum("value").over(Window.partitionBy(lit(1)).rowsBetween(Long.MinValue, 0)))
```
After this PR:
```
df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0)))
```

One could argue there is no point specifying a window frame without partitionBy/orderBy, but it is strange that rowsBetween and rangeBetween are the only two APIs not available directly on the Window object.

This also fixes https://issues.apache.org/jira/browse/SPARK-17656 (removing _root_.scala).

## How was this patch tested?

Added test cases that compute a cumulative sum: in DataFrameWindowSuite for Scala/Java and in tests.py for Python.

Author: Reynold Xin <rxin@databricks.com>

Closes apache#15412 from rxin/SPARK-17844.
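
For illustration, here is a minimal PySpark sketch of the same cumulative sum using the new API, mirroring the test added to tests.py below; it assumes an active SparkSession bound to the name `spark` (that name is an assumption, not part of the commit):

```
import sys
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Running (cumulative) sum with no partitioning or ordering columns,
# using the Window.rowsBetween introduced by this commit.
df = spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
w = Window.rowsBetween(-sys.maxsize, 0)  # unbounded preceding to current row
df.select(df.key, F.sum(df.value).over(w)).show()
# The new test asserts (after sorting) the rows ("one", 1) and ("two", 3).
```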
1 parent 0c0ad43 commit b515768

File tree: 6 files changed (+119, -10 lines)


python/pyspark/sql/tests.py (9 additions, 0 deletions)

@@ -1859,6 +1859,15 @@ def test_window_functions_without_partitionBy(self):
         for r, ex in zip(rs, expected):
             self.assertEqual(tuple(r), ex[:len(r)])
 
+    def test_window_functions_cumulative_sum(self):
+        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
+        from pyspark.sql import functions as F
+        sel = df.select(df.key, F.sum(df.value).over(Window.rowsBetween(-sys.maxsize, 0)))
+        rs = sorted(sel.collect())
+        expected = [("one", 1), ("two", 3)]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
     def test_collect_functions(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql import functions

python/pyspark/sql/window.py (48 additions, 0 deletions)

@@ -66,6 +66,54 @@ def orderBy(*cols):
         jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
         return WindowSpec(jspec)
 
+    @staticmethod
+    @since(2.1)
+    def rowsBetween(start, end):
+        """
+        Creates a :class:`WindowSpec` with the frame boundaries defined,
+        from `start` (inclusive) to `end` (inclusive).
+
+        Both `start` and `end` are relative positions from the current row.
+        For example, "0" means "current row", while "-1" means the row before
+        the current row, and "5" means the fifth row after the current row.
+
+        :param start: boundary start, inclusive.
+                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
+        :param end: boundary end, inclusive.
+                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
+        """
+        if start <= -sys.maxsize:
+            start = WindowSpec._JAVA_MIN_LONG
+        if end >= sys.maxsize:
+            end = WindowSpec._JAVA_MAX_LONG
+        sc = SparkContext._active_spark_context
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end)
+        return WindowSpec(jspec)
+
+    @staticmethod
+    @since(2.1)
+    def rangeBetween(start, end):
+        """
+        Creates a :class:`WindowSpec` with the frame boundaries defined,
+        from `start` (inclusive) to `end` (inclusive).
+
+        Both `start` and `end` are relative from the current row. For example,
+        "0" means "current row", while "-1" means one off before the current row,
+        and "5" means the five off after the current row.
+
+        :param start: boundary start, inclusive.
+                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
+        :param end: boundary end, inclusive.
+                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
+        """
+        if start <= -sys.maxsize:
+            start = WindowSpec._JAVA_MIN_LONG
+        if end >= sys.maxsize:
+            end = WindowSpec._JAVA_MAX_LONG
+        sc = SparkContext._active_spark_context
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
+        return WindowSpec(jspec)
+
 
 class WindowSpec(object):
     """

sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala (42 additions, 4 deletions)

@@ -42,7 +42,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     spec.partitionBy(colName, colNames : _*)
   }
@@ -51,7 +51,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     spec.partitionBy(cols : _*)
   }
@@ -60,7 +60,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     spec.orderBy(colName, colNames : _*)
   }
@@ -69,11 +69,49 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     spec.orderBy(cols : _*)
   }
 
+  /**
+   * Creates a [[WindowSpec]] with the frame boundaries defined,
+   * from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative positions from the current row. For example, "0" means
+   * "current row", while "-1" means the row before the current row, and "5" means the fifth row
+   * after the current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+   * @since 2.1.0
+   */
+  // Note: when updating the doc for this method, also update WindowSpec.rowsBetween.
+  def rowsBetween(start: Long, end: Long): WindowSpec = {
+    spec.rowsBetween(start, end)
+  }
+
+  /**
+   * Creates a [[WindowSpec]] with the frame boundaries defined,
+   * from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative from the current row. For example, "0" means
+   * "current row", while "-1" means one off before the current row, and "5" means the five off
+   * after the current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+   * @since 2.1.0
+   */
+  // Note: when updating the doc for this method, also update WindowSpec.rangeBetween.
+  def rangeBetween(start: Long, end: Long): WindowSpec = {
+    spec.rangeBetween(start, end)
+  }
+
   private[sql] def spec: WindowSpec = {
     new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame)
   }

sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala (6 additions, 4 deletions)

@@ -39,7 +39,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     partitionBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     new WindowSpec(cols.map(_.expr), orderSpec, frame)
   }
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     orderBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     val sortOrder: Seq[SortOrder] = cols.map { col =>
       col.expr match {
@@ -92,6 +92,7 @@ class WindowSpec private[sql](
    * The frame is unbounded if this is the maximum long value.
    * @since 1.4.0
    */
+  // Note: when updating the doc for this method, also update Window.rowsBetween.
   def rowsBetween(start: Long, end: Long): WindowSpec = {
     between(RowFrame, start, end)
   }
@@ -109,6 +110,7 @@ class WindowSpec private[sql](
    * The frame is unbounded if this is the maximum long value.
    * @since 1.4.0
    */
+  // Note: when updating the doc for this method, also update Window.rangeBetween.
   def rangeBetween(start: Long, end: Long): WindowSpec = {
     between(RangeFrame, start, end)
   }

sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala (2 additions, 2 deletions)

@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /**
    * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def apply(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /**
    * Creates a [[Column]] for this UDAF using the distinct values of the given
    * [[Column]]s as input arguments.
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def distinct(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala (12 additions, 0 deletions)

@@ -22,6 +22,9 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DataType, LongType, StructType}
 
+/**
+ * Window function testing for DataFrame API.
+ */
 class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -47,6 +50,15 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
       Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
   }
 
+  test("Window.rowsBetween") {
+    val df = Seq(("one", 1), ("two", 2)).toDF("key", "value")
+    // Running (cumulative) sum
+    checkAnswer(
+      df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0))),
+      Row("one", 1) :: Row("two", 3) :: Nil
+    )
+  }
+
   test("lead") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.createOrReplaceTempView("window_table")
