
Commit 8e7ae47

staple authored and marmbrus committed
[SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions.
SchemaRDD overrides RDD functions, including collect, count, and take, with optimized versions making use of the query optimizer. The java and python interface classes wrapping SchemaRDD need to ensure the optimized versions are called as well. This patch overrides relevant calls in the python and java interfaces with optimized versions.

Adds a new Row serialization pathway between python and java, based on JList[Array[Byte]] versus the existing RDD[Array[Byte]]. I wasn’t overjoyed about doing this, but I noticed that some QueryPlans implement optimizations in executeCollect(), which outputs an Array[Row] rather than the typical RDD[Row] that can be shipped to python using the existing serialization code. To me it made sense to ship the Array[Row] over to python directly instead of converting it back to an RDD[Row] just for the purpose of sending the Rows to python using the existing serialization code.

Author: Aaron Staple <aaron.staple@gmail.com>

Closes #1592 from staple/SPARK-2314 and squashes the following commits:

89ff550 [Aaron Staple] Merge with master.
6bb7b6c [Aaron Staple] Fix typo.
b56d0ac [Aaron Staple] [SPARK-2314][SQL] Override count in JavaSchemaRDD, forwarding to SchemaRDD's count.
0fc9d40 [Aaron Staple] Fix comment typos.
f03cdfa [Aaron Staple] [SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect.
1 parent 30f288a commit 8e7ae47
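For context, a minimal PySpark sketch of the user-visible effect of this patch. The setup below is illustrative, not part of the commit: it assumes a local SparkContext, a hypothetical app name, and the same three-row fixture the doctests in sql.py use.

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local", "spark-2314-demo")  # hypothetical app name
sqlCtx = SQLContext(sc)
rdd = sc.parallelize([Row(field1=1, field2="row1"),
                      Row(field1=2, field2="row2"),
                      Row(field1=3, field2="row3")])
srdd = sqlCtx.inferSchema(rdd)

print(srdd.count())    # forwarded through JavaSchemaRDD to the Scala SchemaRDD's optimized count
print(srdd.take(2))    # now limit(2).collect(): only two rows cross the JVM/Python boundary
print(srdd.collect())  # uses the new collectToPython pathway (a JList[Array[Byte]] of pickled batches)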

File tree

4 files changed, +71 -17 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -776,7 +776,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   /**
-   * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
    * PySpark.
    */
   def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {

python/pyspark/sql.py

Lines changed: 42 additions & 5 deletions
@@ -30,6 +30,7 @@
 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
 
 from itertools import chain, ifilter, imap
 
@@ -1550,6 +1551,18 @@ def id(self):
             self._id = self._jrdd.id()
         return self._id
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.limit(2).collect()
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        >>> srdd.limit(0).collect()
+        []
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
+        return SchemaRDD(rdd, self.sql_ctx)
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
@@ -1626,15 +1639,39 @@ def count(self):
         return self._jschema_rdd.count()
 
     def collect(self):
-        """
-        Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows in this RDD.
 
-        Each object in the list is on Row, the fields can be accessed as
+        Each object in the list is a Row, the fields can be accessed as
         attributes.
+
+        Unlike the base RDD implementation of collect, this implementation
+        leverages the query optimizer to perform a collect on the SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()
+        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        rows = RDD.collect(self)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, rows)
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    def take(self, num):
+        """Take the first num rows of the RDD.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        Unlike the base RDD implementation of take, this implementation
+        leverages the query optimizer to perform a collect on a SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.take(2)
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        """
+        return self.limit(num).collect()
 
     # Convert each object in the RDD to a Row with the right class
     # for this SchemaRDD, so that fields can be accessed as attributes.
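Since the new take is just limit followed by collect, the two calls below return the same rows. This is a small illustrative check, not part of the patch, and it continues from the srdd defined in the sketch after the commit message above.

# Continuing from the srdd defined in the earlier sketch.
assert srdd.take(2) == srdd.limit(2).collect()   # take(num) is implemented as limit(num).collect()
assert srdd.take(0) == []                        # limit(0) collects nothing, as in the doctest above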

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 26 additions & 11 deletions
@@ -377,15 +377,15 @@ class SchemaRDD(
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
   /**
-   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
    */
-  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
     import scala.collection.Map
 
     def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
       case (null, _) => null
 
-      case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
 
       case (seq: Seq[Any], array: ArrayType) =>
         seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,22 +402,37 @@
       case (other, _) => other
     }
 
-    def rowToArray(row: Row, structType: StructType): Array[Any] = {
-      val fields = structType.fields.map(field => field.dataType)
-      row.zip(fields).map {
-        case (obj, dataType) => toJava(obj, dataType)
-      }.toArray
-    }
+    val fields = structType.fields.map(field => field.dataType)
+    row.zip(fields).map {
+      case (obj, dataType) => toJava(obj, dataType)
+    }.toArray
+  }
 
+  /**
+   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   */
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToArray(row, rowSchema)
+        rowToJArray(row, rowSchema)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
 
+  /**
+   * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+   * format as javaToPython. It is used by pyspark.
+   */
+  private[sql] def collectToPython: JList[Array[Byte]] = {
+    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val pickle = new Pickler
+    new java.util.ArrayList(collect().map { row =>
+      rowToJArray(row, rowSchema)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
   /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
@@ -433,7 +448,7 @@ class SchemaRDD(
   }
 
   // =======================================================================
-  // Overriden RDD actions
+  // Overridden RDD actions
   // =======================================================================
 
   override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
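collectToPython returns a JList[Array[Byte]] whose elements are pickled batches of up to 100 converted rows, the same framing javaToPython emits per partition. Below is a rough Python-side sketch of that batch layout, illustrative only: pyspark actually routes the list through _collect_iterator_through_file and its serializers, and the sketch assumes the Pyrolite-pickled batches can be read with the standard pickle module.

import pickle

def rows_from_batches(batches):
    """Yield individual rows from pickled batches (each Array[Byte] holds up to 100 rows)."""
    for blob in batches:                # one element of the JList[Array[Byte]]
        for row in pickle.loads(blob):  # a batch unpickles to a sequence of converted rows
            yield row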

sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
     new java.util.ArrayList(arr)
   }
 
+  override def count(): Long = baseSchemaRDD.count
+
   override def take(num: Int): JList[Row] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))
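With this override in place, a pyspark count no longer needs a generic RDD scan: SchemaRDD.count() in sql.py already calls self._jschema_rdd.count(), which now lands here and forwards to the Scala SchemaRDD. A short recap of the chain, continuing from the earlier sketch; the comments only restate what the diffs and commit message above say.

# Call chain after this patch (names taken from the diffs above):
#   pyspark SchemaRDD.count()  ->  self._jschema_rdd.count()          (sql.py)
#   JavaSchemaRDD.count()      ->  baseSchemaRDD.count                (this file)
#   Scala SchemaRDD.count()    ->  optimized count via the query plan (per the commit message)
print(srdd.count())  # srdd from the earlier sketch; prints 3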
