Spark 615 map partitions with index callable from java #16

Closed
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,7 @@

package org.apache.spark.api.java

-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Iterator => JIterator}

import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -72,11 +72,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
    * of the original partition.
    */
-  def mapPartitionsWithIndex[R: ClassTag](
-      f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]],
-      preservesPartitioning: Boolean = false): JavaRDD[R] =
-    new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
-        preservesPartitioning))
+  def mapPartitionsWithIndex[R](f: JFunction2[Integer, JIterator[T], JIterator[R]],
Contributor

Just wondering - why isn't this version callable from Java? It seems better to me not to define a new Function class... I'm trying to understand why the existing one can't be called as-is (modulo the classtag stuff).

Contributor Author

I believe it was just the ClassTag stuff. The new function class was for convenience and in keeping with the other functions being defined (FlatMapFunction etc.). I can get rid of the new function class if you want.

Contributor

Hey @holdenk, now that I look more I'm confused, because it looks like FlatMapFunction is used for e.g. mapPartitions despite its name. I'll have to defer to @mateiz on this one, who has looked at the Java API more recently.

Contributor

Yeah I agree, did you try using the old one? It lacked a test but maybe it can work. Throughout the Java API we only pass the ClassTag for java.lang.Object, we don't try to pass a real class tag.
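
The "ClassTag for java.lang.Object" approach mentioned here can be illustrated with a rough Java analogue. This is only a sketch under stated assumptions: `fakeClassToken` and `newBackingArray` are hypothetical helpers, and a Java `Class<T>` token stands in for Scala's `ClassTag[T]`. The idea is that the runtime token can be `Object` for every `T`, as long as nothing depends on the real element class.

```java
import java.lang.reflect.Array;

class FakeClassTokenSketch {
  // Hypothetical helper: hands back the Object class token typed as Class<T>.
  @SuppressWarnings("unchecked")
  static <T> Class<T> fakeClassToken() {
    return (Class<T>) (Class<?>) Object.class;
  }

  // Allocates backing storage using whatever token it is given.
  static <T> Object newBackingArray(Class<T> token, int length) {
    return Array.newInstance(token, length);
  }

  public static void main(String[] args) {
    Object real = newBackingArray(Integer.class, 3);
    Object fake = newBackingArray(FakeClassTokenSketch.<Integer>fakeClassToken(), 3);
    System.out.println(real.getClass().getSimpleName()); // Integer[]
    System.out.println(fake.getClass().getSimpleName()); // Object[]
  }
}
```

The array built from the fake token is an `Object[]`, which is fine while elements are only handled as boxed objects; code that needs the real element class would need the genuine token.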

Contributor Author

I tried calling the old one, but that was pre the big Java API refactoring, and it didn't work.

Contributor

Try updating this to just take a Function2<Integer, java.util.Iterator<A>, java.util.Iterator<B>>. I'm pretty sure it would work and we won't need a new type of class.
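
A minimal sketch of this suggestion, written in plain Java without Spark so that it runs on its own. The `Function2` interface and the `mapPartitionsWithIndex` helper below are hypothetical stand-ins modeled on the names in this PR, not Spark's actual classes, and partitions are simulated as a list of lists.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for a generic two-argument Java function type.
interface Function2<A, B, R> {
  R call(A a, B b);
}

class MapPartitionsWithIndexSketch {
  // Applies f to each simulated partition, passing the partition index first.
  static <T, R> List<R> mapPartitionsWithIndex(
      List<List<T>> partitions,
      Function2<Integer, Iterator<T>, Iterator<R>> f) {
    List<R> out = new ArrayList<>();
    for (int i = 0; i < partitions.size(); i++) {
      Iterator<R> results = f.call(i, partitions.get(i).iterator());
      while (results.hasNext()) {
        out.add(results.next());
      }
    }
    return out;
  }

  // Same per-partition logic as the test in this PR, run over one partition.
  static List<Integer> example() {
    List<List<Integer>> partitions = new ArrayList<>();
    partitions.add(Arrays.asList(1, 2, 3, 4, 5));
    return mapPartitionsWithIndex(partitions,
        new Function2<Integer, Iterator<Integer>, Iterator<Integer>>() {
          @Override
          public Iterator<Integer> call(Integer start, Iterator<Integer> iter) {
            List<Integer> list = new ArrayList<>();
            int pos = start;
            while (iter.hasNext()) {
              list.add(iter.next() * pos);
              pos += 1;
            }
            return list.iterator();
          }
        });
  }

  public static void main(String[] args) {
    System.out.println(example()); // [0, 2, 6, 12, 20]
  }
}
```

With a single partition the index is 0, so each element is multiplied by its position, which matches the values the test expects. No dedicated function class is needed; the generic two-argument interface is enough.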

+      preservesPartitioning: Boolean = false): JavaRDD[R] = {
+    import scala.collection.JavaConverters._
+    def fn = (a: Int, b: Iterator[T]) => f.call(a, asJavaIterator(b)).asScala
+    val newRdd = rdd.mapPartitionsWithIndex(fn, preservesPartitioning)(fakeClassTag[R])
+    new JavaRDD(newRdd)(fakeClassTag)
+  }

   /**
    * Return a new RDD by applying a function to all elements of this RDD.
24 changes: 24 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -414,6 +414,30 @@ public void javaDoubleRDDHistoGram() {
    Assert.assertArrayEquals(expected_counts, histogram);
  }

  @Test
  public void mapPartitionsWithIndex() {
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
    JavaRDD<Integer> rddByIndex =
      rdd.mapPartitionsWithIndex(new Function2<Integer,
                                 java.util.Iterator<Integer>,
                                 java.util.Iterator<Integer>>() {
        @Override
        public Iterator<Integer> call(Integer start, java.util.Iterator<Integer> iter) {
          List<Integer> list = new ArrayList<Integer>();
          int pos = start;
          while (iter.hasNext()) {
            list.add(iter.next() * pos);
            pos += 1;
          }
          return list.iterator();
        }
      }, false);
    Assert.assertEquals(0, rddByIndex.first().intValue());
    Integer[] values = {0, 2, 6, 12, 20};
    Assert.assertEquals(Arrays.asList(values), rddByIndex.collect());
  }


  @Test
  public void map() {
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
@@ -245,6 +245,24 @@ public void mapPartitions() {
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
}

@Test
public void mapPartitionsWithIndex() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> rddByIndex = rdd.mapPartitionsWithIndex((start, iter) -> {
List<Integer> list = new ArrayList<Integer>();
int sum = 0;
int pos = start;
while (iter.hasNext()) {
sum += (pos * iter.next());
pos += 1;
Contributor

Indents seem messed up here, should both be 2 spaces

Contributor

Also space before *

}
return list.iterator();
});
Assert.assertEquals(0, rddByIndex.first().intValue());
Integer[] values = {0, 2, 6, 12, 20};
Assert.assertEquals(Arrays.asList(values), rddByIndex.collect());
Contributor

Actually this test is just broken as is, there's no rddByIndex variable. You have to run it on Java 8, otherwise SBT will not build this project.
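
For reference, a lambda-based version of this kind of test could look roughly like the sketch below, which needs Java 8 or later. It runs without Spark: `run` is a hypothetical helper, the standard `java.util.function.BiFunction` stands in for Spark's function type, and the split into partitions `[1, 2]` and `[3, 4]` is an assumption about how the data would be divided. Unlike the lambda shown in the diff, it adds each partition's sum to the returned list.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;

class LambdaPartitionSumSketch {
  // Runs f over each simulated partition, passing the partition index first.
  static List<Integer> run(
      List<List<Integer>> partitions,
      BiFunction<Integer, Iterator<Integer>, Iterator<Integer>> f) {
    List<Integer> out = new ArrayList<>();
    for (int i = 0; i < partitions.size(); i++) {
      f.apply(i, partitions.get(i).iterator()).forEachRemaining(out::add);
    }
    return out;
  }

  static List<Integer> example() {
    // Assumed split of [1, 2, 3, 4] into two partitions.
    List<List<Integer>> partitions =
        Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3, 4));
    return run(partitions, (start, iter) -> {
      List<Integer> list = new ArrayList<>();
      int sum = 0;
      int pos = start;
      while (iter.hasNext()) {
        sum += pos * iter.next();
        pos += 1;
      }
      list.add(sum); // emit one weighted sum per partition
      return list.iterator();
    });
  }

  public static void main(String[] args) {
    System.out.println(example()); // [2, 11]
  }
}
```

Under the assumed split, partition 0 yields 0*1 + 1*2 = 2 and partition 1 yields 1*3 + 2*4 = 11, so an assertion for this variant would compare against `[2, 11]` rather than the five-element list used in the other test.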

}

@Test
public void sequenceFile() {
File tempDir = Files.createTempDir();