
Commit cad3002

Merge pull request apache#501 from JoshRosen/cartesian-rdd-fixes
Fix two bugs in PySpark cartesian(): SPARK-978 and SPARK-1034

This pull request fixes two bugs in PySpark's `cartesian()` method:

- [SPARK-978](https://spark-project.atlassian.net/browse/SPARK-978): PySpark's cartesian method throws ClassCastException exception
- [SPARK-1034](https://spark-project.atlassian.net/browse/SPARK-1034): Py4JException on PySpark Cartesian Result

The JIRAs have more details describing the fixes.
2 parents fad6aac + 6156990 commit cad3002
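
For reference, here is a minimal PySpark sketch of the two usage patterns these bugs affected, mirroring the regression tests added to `python/pyspark/tests.py` below. It assumes an already-initialized `SparkContext` named `sc` and a small local text file `hello.txt`; both names are illustrative and not part of this change.

```python
# SPARK-1034: transforming the result of cartesian() used to fail with a Py4JException.
rdd1 = sc.parallelize([1, 2])
rdd2 = sc.parallelize([3, 4])
sums = rdd1.cartesian(rdd2).map(lambda pair: pair[0] + pair[1]).collect()
# With the fix, sums == [4, 5, 5, 6]

# SPARK-978: calling cartesian() on a textFile-backed RDD used to throw a ClassCastException.
lines = sc.textFile("hello.txt")
pairs = lines.cartesian(lines).collect()
```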

File tree

3 files changed (+56 -22 lines):

- core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
- python/pyspark/tests.py


core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 1 addition & 2 deletions
@@ -49,8 +49,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
 
   override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
 
-  override val classTag: ClassTag[(K, V)] =
-    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
+  override val classTag: ClassTag[(K, V)] = rdd.elementClassTag
 
   import JavaPairRDD._
 

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 39 additions & 20 deletions
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
           dataOut.writeInt(command.length)
           dataOut.write(command)
           // Data values
-          for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeToStream(elem, dataOut)
-          }
+          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
          dataOut.flush()
          worker.shutdownOutput()
        } catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeToStream(elem: Any, dataOut: DataOutputStream) {
-    elem match {
-      case bytes: Array[Byte] =>
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      case pair: (Array[Byte], Array[Byte]) =>
-        dataOut.writeInt(pair._1.length)
-        dataOut.write(pair._1)
-        dataOut.writeInt(pair._2.length)
-        dataOut.write(pair._2)
-      case str: String =>
-        dataOut.writeUTF(str)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
+    if (iter.hasNext) {
+      val first = iter.next()
+      val newIter = Seq(first).iterator ++ iter
+      first match {
+        case arr: Array[Byte] =>
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+            dataOut.writeInt(bytes.length)
+            dataOut.write(bytes)
+          }
+        case string: String =>
+          newIter.asInstanceOf[Iterator[String]].foreach { str =>
+            dataOut.writeUTF(str)
+          }
+        case pair: Tuple2[_, _] =>
+          pair._1 match {
+            case bytePair: Array[Byte] =>
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+                dataOut.writeInt(pair._1.length)
+                dataOut.write(pair._1)
+                dataOut.writeInt(pair._2.length)
+                dataOut.write(pair._2)
+              }
+            case stringPair: String =>
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+                dataOut.writeUTF(pair._1)
+                dataOut.writeUTF(pair._2)
+              }
+            case other =>
+              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+          }
+        case other =>
+          throw new SparkException("Unexpected element type " + first.getClass)
+      }
     }
   }

@@ -230,9 +251,7 @@ private[spark] object PythonRDD {
 
   def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
-    for (item <- items) {
-      writeToStream(item, file)
-    }
+    writeIteratorToStream(items, file)
     file.close()
   }
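
As the inline comment in the new `writeIteratorToStream` notes, the element type `T` is not available at runtime without TypeTags, so the method peeks at the first element, re-prepends it to the iterator (`Seq(first).iterator ++ iter`), and dispatches on that element's runtime class. A rough Python analogue of that peek-and-rechain idiom, purely for illustration (the function name and the simplified byte framing are assumptions, not part of this change):

```python
import itertools
import struct

def write_iterator_to_stream(it, out):
    """Illustrative sketch: decide the wire format from the first element,
    then stream the whole iterator, including the element already consumed."""
    sentinel = object()
    first = next(it, sentinel)
    if first is sentinel:
        return  # empty iterator: nothing to write
    # Re-attach the consumed element, like Seq(first).iterator ++ iter in Scala.
    whole = itertools.chain([first], it)
    if isinstance(first, bytes):
        for b in whole:
            out.write(struct.pack(">i", len(b)))  # length-prefixed, like writeInt
            out.write(b)
    elif isinstance(first, tuple):
        for k, v in whole:
            for part in (k, v):
                out.write(struct.pack(">i", len(part)))
                out.write(part)
    else:
        raise TypeError("Unexpected element type %r" % type(first))
```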

python/pyspark/tests.py

Lines changed: 16 additions & 0 deletions
@@ -152,6 +152,22 @@ def test_save_as_textfile_with_unicode(self):
         raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
         self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
 
+    def test_transforming_cartesian_result(self):
+        # Regression test for SPARK-1034
+        rdd1 = self.sc.parallelize([1, 2])
+        rdd2 = self.sc.parallelize([3, 4])
+        cart = rdd1.cartesian(rdd2)
+        result = cart.map(lambda (x, y): x + y).collect()
+
+    def test_cartesian_on_textfile(self):
+        # Regression test for SPARK-978
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        a = self.sc.textFile(path)
+        result = a.cartesian(a).collect()
+        (x, y) = result[0]
+        self.assertEqual("Hello World!", x.strip())
+        self.assertEqual("Hello World!", y.strip())
+
 
 class TestIO(PySparkTestCase):
