
Commit 65461b2

Merge branch 'sliding' into auc
2 parents: ca4bf8c + 5ee6001

File tree

3 files changed: 132 additions & 0 deletions


core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 16 additions & 0 deletions

@@ -951,6 +951,22 @@ abstract class RDD[T: ClassTag](
    */
   def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)
 
+  /**
+   * Returns an RDD formed by grouping items of its parent RDD into fixed-size blocks by passing
+   * a sliding window over them. The ordering is first based on the partition index and then the
+   * ordering of items within each partition. This is similar to sliding in Scala collections,
+   * except that it becomes an empty RDD if the window size is greater than the total number of
+   * items. It needs to trigger a Spark job if the parent RDD has more than one partition and
+   * the window size is greater than 1.
+   */
+  def sliding(windowSize: Int): RDD[Array[T]] = {
+    if (windowSize == 1) {
+      this.map(Array(_))
+    } else {
+      new SlidedRDD[T](this, windowSize)
+    }
+  }
+
   /**
    * Save this RDD as a text file, using string representations of elements.
    */
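For illustration, a minimal usage sketch of the new operator (not part of this commit; it assumes a live SparkContext named sc, as in the test suite below):

val rdd = sc.parallelize(1 to 5, 2)
// Windows are ordered across partition boundaries:
// Array(1, 2, 3), Array(2, 3, 4), Array(3, 4, 5)
rdd.sliding(3).collect().foreach(w => println(w.mkString(", ")))
// A window larger than the total number of items yields an empty RDD:
assert(rdd.sliding(6).collect().isEmpty)
// windowSize == 1 skips SlidedRDD entirely and just wraps each item:
assert(rdd.sliding(1).collect().map(_.toList).toList == (1 to 5).map(List(_)).toList)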
core/src/main/scala/org/apache/spark/rdd/SlidedRDD.scala

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.{TaskContext, Partition}
+
+private[spark]
+class SlidedRDDPartition[T](val idx: Int, val prev: Partition, val tail: Array[T])
+  extends Partition with Serializable {
+  override val index: Int = idx
+}
+
+/**
+ * Represents an RDD formed by grouping items of its parent RDD into fixed-size blocks by passing
+ * a sliding window over them. The ordering is first based on the partition index and then the
+ * ordering of items within each partition. This is similar to sliding in Scala collections,
+ * except that it becomes an empty RDD if the window size is greater than the total number of
+ * items. It needs to trigger a Spark job if the parent RDD has more than one partition.
+ *
+ * @param parent the parent RDD
+ * @param windowSize the window size, must be greater than 1
+ *
+ * @see [[org.apache.spark.rdd.RDD#sliding]]
+ */
+private[spark]
+class SlidedRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
+  extends RDD[Array[T]](parent) {
+
+  require(windowSize > 1, "Window size must be greater than 1.")
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
+    val part = split.asInstanceOf[SlidedRDDPartition[T]]
+    (firstParent[T].iterator(part.prev, context) ++ part.tail)
+      .sliding(windowSize)
+      .map(_.toArray)
+      .filter(_.size == windowSize)
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] =
+    firstParent[T].preferredLocations(split.asInstanceOf[SlidedRDDPartition[T]].prev)
+
+  override def getPartitions: Array[Partition] = {
+    val parentPartitions = parent.partitions
+    val n = parentPartitions.size
+    if (n == 0) {
+      Array.empty
+    } else if (n == 1) {
+      Array(new SlidedRDDPartition[T](0, parentPartitions(0), Array.empty))
+    } else {
+      val n1 = n - 1
+      val w1 = windowSize - 1
+      // Get the first w1 items of each partition, starting from the second partition.
+      val nextHeads =
+        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
+      val partitions = mutable.ArrayBuffer[SlidedRDDPartition[T]]()
+      var i = 0
+      var partitionIndex = 0
+      while (i < n1) {
+        var j = i
+        val tail = mutable.ArrayBuffer[T]()
+        // Keep appending to the current tail until we have appended a head of full size w1.
+        while (j < n1 && nextHeads(j).size < w1) {
+          tail ++= nextHeads(j)
+          j += 1
+        }
+        if (j < n1) {
+          tail ++= nextHeads(j)
+          j += 1
+        }
+        partitions += new SlidedRDDPartition[T](partitionIndex, parentPartitions(i), tail.toArray)
+        partitionIndex += 1
+        // Skip the heads we have already appended.
+        i = j
+      }
+      // If the head of the last partition has full size w1, we also need to add that partition.
+      if (nextHeads(n1 - 1).size == w1) {
+        partitions += new SlidedRDDPartition[T](partitionIndex, parentPartitions(n1), Array.empty)
+      }
+      partitions.toArray
+    }
+  }
+
+  // TODO: Override methods such as aggregate, which only require one Spark job.
+}
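The tail-stitching in getPartitions can be mirrored on plain Scala collections. Below is a local sketch, not part of the commit, that extends each "partition" with the first windowSize - 1 items drawn from the partitions after it and slides locally; it ignores the coalescing of undersized heads that the real code performs, but yields the same windows:

val windowSize = 3
val w1 = windowSize - 1
val parts = Seq(Seq(0, 1), Seq(2), Seq(3, 4, 5)) // stand-ins for parent partitions
val stitched = parts.indices.map { i =>
  parts(i) ++ parts.drop(i + 1).flatten.take(w1) // the "tail", cf. nextHeads
}
// Slide within each stitched block and keep only full windows, as in compute().
val windows = stitched.flatMap(_.sliding(windowSize).filter(_.size == windowSize))
assert(windows.map(_.toList) == (0 to 5).sliding(windowSize).map(_.toList).toList)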

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 14 additions & 0 deletions

@@ -553,4 +553,18 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     val ids = ranked.map(_._1).distinct().collect()
     assert(ids.length === n)
   }
+
+  test("sliding") {
+    val data = 0 until 6
+    for (numPartitions <- 1 to 8) {
+      val rdd = sc.parallelize(data, numPartitions)
+      for (windowSize <- 1 to 6) {
+        val slided = rdd.sliding(windowSize).collect().map(_.toList).toList
+        val expected = data.sliding(windowSize).map(_.toList).toList
+        assert(slided === expected)
+      }
+      assert(rdd.sliding(7).collect().isEmpty,
+        "Should return an empty RDD if the window size is greater than the number of items.")
+    }
+  }
 }
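As a closing illustration of what the operator enables, a sketch of a consecutive-pair computation, first differences of an ordered series (again assuming a SparkContext named sc; not part of this commit):

val series = sc.parallelize(Seq(1.0, 4.0, 9.0, 16.0), 2)
val diffs = series.sliding(2).map { case Array(a, b) => b - a }
assert(diffs.collect().toList == List(3.0, 5.0, 7.0))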
