
Commit cb094ff

Author: Davies Liu

enable auto convert
1 parent d850b4b commit cb094ff

File tree: 8 files changed, +27 -49 lines


python/pyspark/context.py

Lines changed: 1 addition & 5 deletions
@@ -23,8 +23,6 @@
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
-from py4j.java_collections import ListConverter
-
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
@@ -643,7 +641,6 @@ def union(self, rdds):
         rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self._gateway._gateway_client)
         return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
@@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         """
         if partitions is None:
             partitions = range(rdd._jrdd.partitions().size())
-        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
 
         # Implementation note: This is implemented as a mapPartitions followed
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
                                           allowLocal)
         return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
 
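Note (not part of the commit): the deletions above follow directly from the gateway change in java_gateway.py below. With auto_convert=True, py4j converts a plain Python list into a java.util.ArrayList when the call crosses into the JVM, so the explicit ListConverter step becomes redundant. A minimal sketch of the old and new styles, assuming a local SparkContext; the variable names are illustrative, not taken from the diff:

from pyspark import SparkContext
from py4j.java_collections import ListConverter

sc = SparkContext("local", "auto-convert-demo")
rdds = [sc.parallelize(range(i, i + 3)) for i in range(3)]

# Old style (auto_convert=False): wrap the Python list by hand before the JVM call.
jrest = ListConverter().convert([r._jrdd for r in rdds[1:]],
                                sc._gateway._gateway_client)
unioned_old = sc._jsc.union(rdds[0]._jrdd, jrest)

# New style (auto_convert=True): pass the plain list; py4j converts it to a
# java.util.ArrayList at the gateway boundary.
unioned_new = sc._jsc.union(rdds[0]._jrdd, [r._jrdd for r in rdds[1:]])

sc.stop()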

python/pyspark/java_gateway.py

Lines changed: 10 additions & 1 deletion
@@ -23,11 +23,20 @@
 import socket
 import platform
 from subprocess import Popen, PIPE
+
 from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_collections import ListConverter
 
 from pyspark.serializers import read_int
 
 
+# patching ListConverter, or it will convert bytearray into Java ArrayList
+def can_convert_list(self, obj):
+    return isinstance(obj, list)
+
+ListConverter.can_convert = can_convert_list
+
+
 def launch_gateway():
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
@@ -92,7 +101,7 @@ def killChild():
     atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
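The can_convert patch above narrows py4j's list detection: per the comment in the diff, the stock predicate would also match a bytearray and auto-convert it into a java.util.ArrayList, whereas PySpark needs bytearrays to reach the JVM as binary data. A small sketch of the patched predicate (illustration only; it needs py4j installed but no running gateway):

from py4j.java_collections import ListConverter

# Same patch as in the diff: only real Python lists qualify for auto-conversion.
def can_convert_list(self, obj):
    return isinstance(obj, list)

ListConverter.can_convert = can_convert_list

converter = ListConverter()
print(converter.can_convert([1, 2, 3]))          # True  -> would become java.util.ArrayList
print(converter.can_convert(bytearray(b"raw")))  # False -> left for py4j's binary handling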

python/pyspark/mllib/common.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 import py4j.protocol
 from py4j.protocol import Py4JJavaError
 from py4j.java_gateway import JavaObject
-from py4j.java_collections import ListConverter, JavaArray, JavaList
+from py4j.java_collections import JavaArray, JavaList
 
 from pyspark import RDD, SparkContext
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
@@ -76,7 +76,7 @@ def _py2java(sc, obj):
     elif isinstance(obj, SparkContext):
         obj = obj._jsc
     elif isinstance(obj, list):
-        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+        obj = [_py2java(sc, x) for x in obj]
     elif isinstance(obj, JavaObject):
         pass
     elif isinstance(obj, (int, long, float, bool, bytes, unicode)):

python/pyspark/sql/context.py

Lines changed: 4 additions & 9 deletions
@@ -25,7 +25,6 @@
     from itertools import imap as map
 
 from py4j.protocol import Py4JError
-from py4j.java_collections import MapConverter
 
 from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options):
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.load(source, joptions)
+            df = self._ssql_ctx.load(source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.load(source, scala_datatype, joptions)
+            df = self._ssql_ctx.load(source, scala_datatype, options)
         return DataFrame(df, self)
 
     def createExternalTable(self, tableName, path=None, source=None,
@@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None,
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+            df = self._ssql_ctx.createExternalTable(tableName, source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
             df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
-                                                    joptions)
+                                                    options)
         return DataFrame(df, self)
 
     @ignore_unicode_prefix
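The same applies to dicts: the **options collected by load() and createExternalTable() now travel as plain Python dicts and reach the JVM-side SQLContext as java.util.HashMap instances. A usage sketch, assuming a local SparkContext and a hypothetical Parquet file people.parquet:

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local", "load-options-demo")
sqlContext = SQLContext(sc)

# `path` is folded into the **options dict inside load(); with auto_convert on,
# that dict is converted to a java.util.HashMap at the gateway.
df = sqlContext.load(path="people.parquet",            # hypothetical path
                     source="org.apache.spark.sql.parquet")
df.printSchema()
sc.stop()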

python/pyspark/sql/dataframe.py

Lines changed: 4 additions & 14 deletions
@@ -25,8 +25,6 @@
 else:
     from itertools import imap as map
 
-from py4j.java_collections import ListConverter, MapConverter
-
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self.sql_ctx._sc._gateway._gateway_client)
-        self._jdf.saveAsTable(tableName, source, jmode, joptions)
+        self._jdf.saveAsTable(tableName, source, jmode, options)
 
     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
-        self._jdf.save(source, jmode, joptions)
+        self._jdf.save(source, jmode, options)
 
     @property
     def schema(self):
@@ -819,7 +813,6 @@ def fillna(self, value, subset=None):
             value = float(value)
 
         if isinstance(value, dict):
-            value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
         elif subset is None:
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
@@ -932,9 +925,7 @@ def agg(self, *exprs):
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            jmap = MapConverter().convert(exprs[0],
-                                          self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(jmap)
+            jdf = self._jdf.agg(exprs[0])
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-    return sc._jvm.PythonUtils.toSeq(jcols)
+    return sc._jvm.PythonUtils.toSeq(cols)
 
 
 def _unary_op(name, doc="unary operator"):
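_to_seq now hands the Python list straight to PythonUtils.toSeq; py4j converts it to a java.util.List, which the Scala helper turns into a Seq. A quick way to observe the conversion, assuming a local SparkContext (PythonUtils is an internal Spark helper, used here only for illustration):

from pyspark import SparkContext

sc = SparkContext("local", "to-seq-demo")

# A plain Python list of strings crosses the gateway as a java.util.List and
# comes back wrapped as a Scala Seq; no ListConverter boilerplate is needed.
jseq = sc._jvm.PythonUtils.toSeq(["name", "age"])
print(jseq.size())   # 2
sc.stop()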

python/pyspark/streaming/context.py

Lines changed: 3 additions & 8 deletions
@@ -20,7 +20,6 @@
 import os
 import sys
 
-from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import, JavaObject
 
 from pyspark import RDD, SparkConf
@@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None):
             rdds = [self._sc.parallelize(input) for input in rdds]
         self._check_serializers(rdds)
 
-        jrdds = ListConverter().convert([r._jrdd for r in rdds],
-                                        SparkContext._gateway._gateway_client)
-        queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+        queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
         if default:
             default = default._reserialize(rdds[0]._jrdd_deserializer)
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
@@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc):
         the transform function parameter will be the same as the order
         of corresponding DStreams in the list.
         """
-        jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
-                                            SparkContext._gateway._gateway_client)
+        jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
                                  lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
@@ -346,6 +342,5 @@ def union(self, *dstreams):
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
         first = dstreams[0]
-        jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
-                                        SparkContext._gateway._gateway_client)
+        jrest = [d._jdstream for d in dstreams[1:]]
         return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)

python/pyspark/streaming/kafka.py

Lines changed: 2 additions & 5 deletions
@@ -15,8 +15,7 @@
 # limitations under the License.
 #
 
-from py4j.java_collections import MapConverter
-from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
+from py4j.java_gateway import Py4JJavaError
 
 from pyspark.storagelevel import StorageLevel
 from pyspark.serializers import PairDeserializer, NoOpSerializer
@@ -57,16 +56,14 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         })
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
-        jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
-        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
 
         try:
             # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
             helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
             helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
+            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         except Py4JJavaError as e:
             # TODO: use --jar once it also work on driver
             if 'ClassNotFoundException' in str(e.java_exception):
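On the user-facing side nothing changes: topics and kafkaParams stay ordinary Python dicts and are auto-converted to java.util.HashMap when the helper call crosses the gateway. A usage sketch, assuming the spark-streaming-kafka assembly is on the classpath and a Zookeeper at localhost:2181 (addresses and topic name are placeholders):

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext("local[2]", "kafka-demo")
ssc = StreamingContext(sc, batchDuration=2)

# Both dict arguments are passed through unchanged and converted to
# java.util.HashMap by py4j's auto-conversion.
stream = KafkaUtils.createStream(ssc, "localhost:2181", "demo-consumer",
                                 {"topic1": 1},
                                 kafkaParams={"auto.offset.reset": "smallest"})
stream.pprint()
# ssc.start(); ssc.awaitTermination()   # uncomment to actually consume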

python/pyspark/streaming/tests.py

Lines changed: 1 addition & 5 deletions
@@ -24,8 +24,6 @@
 import struct
 from functools import reduce
 
-from py4j.java_collections import MapConverter
-
 from pyspark.context import SparkConf, SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.kafka import KafkaUtils
@@ -581,11 +579,9 @@ def test_kafka_stream(self):
         """Test the Python Kafka stream API."""
         topic = "topic1"
         sendData = {"a": 3, "b": 5, "c": 10}
-        jSendData = MapConverter().convert(sendData,
-                                           self.ssc.sparkContext._gateway._gateway_client)
 
         self._kafkaTestUtils.createTopic(topic)
-        self._kafkaTestUtils.sendMessages(topic, jSendData)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
 
         stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                          "test-streaming-consumer", {topic: 1},