Skip to content

Commit 3c373f3

Browse files
author
Davies Liu
committed
support date and datetime by auto_convert
1 parent cb094ff commit 3c373f3

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

python/pyspark/rdd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,6 +2267,8 @@ def _prepare_for_python_RDD(sc, command, obj=None):
22672267
# The broadcast will have same life cycle as created PythonRDD
22682268
broadcast = sc.broadcast(pickled_command)
22692269
pickled_command = ser.dumps(broadcast)
2270+
# There is a bug in py4j.java_gateway.JavaClass with auto_convert
2271+
# TODO: use auto_convert once py4j fix the bug
22702272
broadcast_vars = ListConverter().convert(
22712273
[x._jbroadcast for x in sc._pickled_broadcast_vars],
22722274
sc._gateway._gateway_client)

python/pyspark/sql/_types.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import sys
1919
import decimal
20+
import time
2021
import datetime
2122
import keyword
2223
import warnings
@@ -30,6 +31,9 @@
3031
long = int
3132
unicode = str
3233

34+
from py4j.protocol import register_input_converter
35+
from py4j.java_gateway import JavaClass
36+
3337
__all__ = [
3438
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
3539
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -1237,6 +1241,29 @@ def __repr__(self):
12371241
return "<Row(%s)>" % ", ".join(self)
12381242

12391243

1244+
class DateConverter(object):
1245+
def can_convert(self, obj):
1246+
return isinstance(obj, datetime.date)
1247+
1248+
def convert(self, obj, gateway_client):
1249+
Date = JavaClass("java.sql.Date", gateway_client)
1250+
return Date.valueOf(obj.strftime("%Y-%m-%d"))
1251+
1252+
1253+
class DatetimeConverter(object):
1254+
def can_convert(self, obj):
1255+
return isinstance(obj, datetime.datetime)
1256+
1257+
def convert(self, obj, gateway_client):
1258+
Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
1259+
return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
1260+
1261+
1262+
# datetime is a subclass of date, we should register DatetimeConverter first
1263+
register_input_converter(DatetimeConverter())
1264+
register_input_converter(DateConverter())
1265+
1266+
12401267
def _test():
12411268
import doctest
12421269
from pyspark.context import SparkContext

python/pyspark/sql/tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import tempfile
2727
import pickle
2828
import functools
29+
import datetime
2930

3031
import py4j
3132

@@ -464,6 +465,16 @@ def test_infer_long_type(self):
464465
self.assertEqual(_infer_type(2**61), LongType())
465466
self.assertEqual(_infer_type(2**71), LongType())
466467

468+
def test_filter_with_datetime(self):
469+
time = datetime.datetime(2015, 4, 17, 23, 01, 02, 3000)
470+
date = time.date()
471+
row = Row(date=date, time=time)
472+
df = self.sqlCtx.createDataFrame([row])
473+
self.assertEqual(1, df.filter(df.date == date).count())
474+
self.assertEqual(1, df.filter(df.time == time).count())
475+
self.assertEqual(0, df.filter(df.date > date).count())
476+
self.assertEqual(0, df.filter(df.time > time).count())
477+
467478
def test_dropna(self):
468479
schema = StructType([
469480
StructField("name", StringType(), True),

0 commit comments

Comments
 (0)