Commit ab9128f

Davies Liu authored and rxin committed
[SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression
This PR enables auto_convert in JavaGateway, so that we can register converters for given types, for example, date and datetime. There are two bugs related to auto_convert, see [1] and [2]; we work around them in this PR.

[1] py4j/py4j#160
[2] py4j/py4j#161

cc rxin JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes apache#5570 from davies/py4j_date and squashes the following commits:

eb4fa53 [Davies Liu] fix tests in python 3
d17d634 [Davies Liu] rollback changes in mllib
2e7566d [Davies Liu] convert tuple into ArrayList
ceb3779 [Davies Liu] Update rdd.py
3c373f3 [Davies Liu] support date and datetime by auto_convert
cb094ff [Davies Liu] enable auto convert
1 parent 8136810 commit ab9128f
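
A minimal usage sketch of what this enables, mirroring the test added to python/pyspark/sql/tests.py below. It assumes a running SparkContext and an existing SQLContext bound to the name sqlCtx (both names are illustrative, not part of this change):

    import datetime
    from pyspark.sql import Row

    # Python date/datetime values can now appear directly in Column expressions;
    # the registered input converters turn them into java.sql.Date / java.sql.Timestamp
    # before they cross the py4j gateway.
    row = Row(date=datetime.date(2015, 4, 17),
              time=datetime.datetime(2015, 4, 17, 23, 1, 2))
    df = sqlCtx.createDataFrame([row])
    df.filter(df.date == datetime.date(2015, 4, 17)).count()               # 1
    df.filter(df.time > datetime.datetime(2015, 4, 17, 23, 1, 2)).count()  # 0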

Showing 10 changed files with 70 additions and 47 deletions.

python/pyspark/context.py

Lines changed: 1 addition & 5 deletions
@@ -23,8 +23,6 @@
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
-from py4j.java_collections import ListConverter
-
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
@@ -643,7 +641,6 @@ def union(self, rdds):
         rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self._gateway._gateway_client)
         return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
@@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         """
         if partitions is None:
             partitions = range(rdd._jrdd.partitions().size())
-        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
 
         # Implementation note: This is implemented as a mapPartitions followed
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
                                           allowLocal)
         return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))

python/pyspark/java_gateway.py

Lines changed: 14 additions & 1 deletion
@@ -17,17 +17,30 @@
 
 import atexit
 import os
+import sys
 import select
 import signal
 import shlex
 import socket
 import platform
 from subprocess import Popen, PIPE
+
+if sys.version >= '3':
+    xrange = range
+
 from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_collections import ListConverter
 
 from pyspark.serializers import read_int
 
 
+# patching ListConverter, or it will convert bytearray into Java ArrayList
+def can_convert_list(self, obj):
+    return isinstance(obj, (list, tuple, xrange))
+
+ListConverter.can_convert = can_convert_list
+
+
 def launch_gateway():
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
@@ -92,7 +105,7 @@ def killChild():
             atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
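
Why the ListConverter patch above is needed, as a small self-contained illustration (pure Python, no JVM involved; this demonstrates only the predicate, not py4j's conversion machinery): the patched can_convert claims only list, tuple, and (x)range, so bytearrays are no longer wrapped in a java.util.ArrayList and still reach the JVM as raw bytes, while lists and ranges keep being converted automatically.

    # Stand-alone sketch of the patched predicate; 'range' plays the role of xrange on Python 3.
    def can_convert_list(obj):
        return isinstance(obj, (list, tuple, range))

    print(can_convert_list([1, 2, 3]))           # True  -> auto-converted to a Java list
    print(can_convert_list(range(4)))            # True  -> e.g. partition ids in SparkContext.runJob
    print(can_convert_list(bytearray(b"\x00")))  # False -> left alone, sent as raw bytes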

python/pyspark/rdd.py

Lines changed: 3 additions & 0 deletions
@@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
+    # There is a bug in py4j.java_gateway.JavaClass with auto_convert
+    # https://github.com/bartdag/py4j/issues/161
+    # TODO: use auto_convert once py4j fix the bug
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in sc._pickled_broadcast_vars],
         sc._gateway._gateway_client)

python/pyspark/sql/_types.py

Lines changed: 27 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import sys
 import decimal
+import time
 import datetime
 import keyword
 import warnings
@@ -30,6 +31,9 @@
     long = int
     unicode = str
 
+from py4j.protocol import register_input_converter
+from py4j.java_gateway import JavaClass
+
 __all__ = [
     "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
     "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -1237,6 +1241,29 @@ def __repr__(self):
         return "<Row(%s)>" % ", ".join(self)
 
 
+class DateConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.date)
+
+    def convert(self, obj, gateway_client):
+        Date = JavaClass("java.sql.Date", gateway_client)
+        return Date.valueOf(obj.strftime("%Y-%m-%d"))
+
+
+class DatetimeConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.datetime)
+
+    def convert(self, obj, gateway_client):
+        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
+        return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
+
+
+# datetime is a subclass of date, we should register DatetimeConverter first
+register_input_converter(DatetimeConverter())
+register_input_converter(DateConverter())
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext
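
The registration order called out in the comment above matters because datetime.datetime is a subclass of datetime.date, so DateConverter.can_convert also matches datetime objects; a quick check (plain Python, no Spark or py4j required):

    import datetime

    now = datetime.datetime(2015, 4, 17, 23, 1, 2)
    print(isinstance(now, datetime.date))      # True  -> DateConverter would also claim it
    print(isinstance(now, datetime.datetime))  # True  -> DatetimeConverter, registered first, wins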

python/pyspark/sql/context.py

Lines changed: 4 additions & 9 deletions
@@ -25,7 +25,6 @@
     from itertools import imap as map
 
 from py4j.protocol import Py4JError
-from py4j.java_collections import MapConverter
 
 from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options):
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.load(source, joptions)
+            df = self._ssql_ctx.load(source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.load(source, scala_datatype, joptions)
+            df = self._ssql_ctx.load(source, scala_datatype, options)
         return DataFrame(df, self)
 
     def createExternalTable(self, tableName, path=None, source=None,
@@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None,
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+            df = self._ssql_ctx.createExternalTable(tableName, source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
             df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
-                                                    joptions)
+                                                    options)
         return DataFrame(df, self)
 
     @ignore_unicode_prefix

python/pyspark/sql/dataframe.py

Lines changed: 4 additions & 14 deletions
@@ -25,8 +25,6 @@
 else:
     from itertools import imap as map
 
-from py4j.java_collections import ListConverter, MapConverter
-
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self.sql_ctx._sc._gateway._gateway_client)
-        self._jdf.saveAsTable(tableName, source, jmode, joptions)
+        self._jdf.saveAsTable(tableName, source, jmode, options)
 
     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
-        self._jdf.save(source, jmode, joptions)
+        self._jdf.save(source, jmode, options)
 
     @property
     def schema(self):
@@ -819,7 +813,6 @@ def fillna(self, value, subset=None):
             value = float(value)
 
         if isinstance(value, dict):
-            value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
         elif subset is None:
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
@@ -932,9 +925,7 @@ def agg(self, *exprs):
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            jmap = MapConverter().convert(exprs[0],
-                                          self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(jmap)
+            jdf = self._jdf.agg(exprs[0])
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-    return sc._jvm.PythonUtils.toSeq(jcols)
+    return sc._jvm.PythonUtils.toSeq(cols)
 
 
 def _unary_op(name, doc="unary operator"):

python/pyspark/sql/tests.py

Lines changed: 11 additions & 0 deletions
@@ -26,6 +26,7 @@
 import tempfile
 import pickle
 import functools
+import datetime
 
 import py4j
 
@@ -464,6 +465,16 @@ def test_infer_long_type(self):
         self.assertEqual(_infer_type(2**61), LongType())
         self.assertEqual(_infer_type(2**71), LongType())
 
+    def test_filter_with_datetime(self):
+        time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
+        date = time.date()
+        row = Row(date=date, time=time)
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1, df.filter(df.date == date).count())
+        self.assertEqual(1, df.filter(df.time == time).count())
+        self.assertEqual(0, df.filter(df.date > date).count())
+        self.assertEqual(0, df.filter(df.time > time).count())
+
     def test_dropna(self):
         schema = StructType([
             StructField("name", StringType(), True),

python/pyspark/streaming/context.py

Lines changed: 3 additions & 8 deletions
@@ -20,7 +20,6 @@
 import os
 import sys
 
-from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import, JavaObject
 
 from pyspark import RDD, SparkConf
@@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None):
             rdds = [self._sc.parallelize(input) for input in rdds]
         self._check_serializers(rdds)
 
-        jrdds = ListConverter().convert([r._jrdd for r in rdds],
-                                        SparkContext._gateway._gateway_client)
-        queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+        queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
         if default:
             default = default._reserialize(rdds[0]._jrdd_deserializer)
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
@@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc):
         the transform function parameter will be the same as the order
         of corresponding DStreams in the list.
         """
-        jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
-                                            SparkContext._gateway._gateway_client)
+        jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
                                  lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
@@ -346,6 +342,5 @@ def union(self, *dstreams):
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
         first = dstreams[0]
-        jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
-                                        SparkContext._gateway._gateway_client)
+        jrest = [d._jdstream for d in dstreams[1:]]
         return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)

python/pyspark/streaming/kafka.py

Lines changed: 2 additions & 5 deletions
@@ -15,8 +15,7 @@
 # limitations under the License.
 #
 
-from py4j.java_collections import MapConverter
-from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
+from py4j.java_gateway import Py4JJavaError
 
 from pyspark.storagelevel import StorageLevel
 from pyspark.serializers import PairDeserializer, NoOpSerializer
@@ -57,16 +56,14 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
         })
         if not isinstance(topics, dict):
            raise TypeError("topics should be dict")
-        jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
-        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
 
         try:
             # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
             helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
             helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
+            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         except Py4JJavaError as e:
             # TODO: use --jar once it also work on driver
             if 'ClassNotFoundException' in str(e.java_exception):

python/pyspark/streaming/tests.py

Lines changed: 1 addition & 5 deletions
@@ -24,8 +24,6 @@
 import struct
 from functools import reduce
 
-from py4j.java_collections import MapConverter
-
 from pyspark.context import SparkConf, SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.kafka import KafkaUtils
@@ -581,11 +579,9 @@ def test_kafka_stream(self):
         """Test the Python Kafka stream API."""
         topic = "topic1"
         sendData = {"a": 3, "b": 5, "c": 10}
-        jSendData = MapConverter().convert(sendData,
-                                           self.ssc.sparkContext._gateway._gateway_client)
 
         self._kafkaTestUtils.createTopic(topic)
-        self._kafkaTestUtils.sendMessages(topic, jSendData)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
 
         stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                          "test-streaming-consumer", {topic: 1},
