
Commit 17af727

BryanCutler authored and gatorsmile committed
[SPARK-21375][PYSPARK][SQL] Add Date and Timestamp support to ArrowConverters for toPandas() Conversion
## What changes were proposed in this pull request?

Adding date and timestamp support with Arrow for `toPandas()` and `pandas_udf`s. Timestamps are stored in Arrow as UTC and manifested to the user as timezone-naive localized to the Python system timezone.

## How was this patch tested?

Added Scala tests for date and timestamp types under ArrowConverters, ArrowUtils, and ArrowWriter suites. Added Python tests for `toPandas()` and `pandas_udf`s with date and timestamp types.

Author: Bryan Cutler <cutlerb@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>

Closes #18664 from BryanCutler/arrow-date-timestamp-SPARK-21375.
1 parent 5c3a1f3 commit 17af727
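
As a rough illustration of the behavior described in the commit message — a minimal sketch, not part of the change itself; it assumes a local PySpark session with pandas and pyarrow installed, and the column names are made up:

    from datetime import date, datetime
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, DateType, TimestampType

    spark = SparkSession.builder.getOrCreate()
    # Arrow-based conversion is opt-in via this conf
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    schema = StructType([StructField("d", DateType(), True),
                         StructField("ts", TimestampType(), True)])
    df = spark.createDataFrame([(date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2))], schema)

    # Timestamps travel through Arrow as UTC microseconds and come back to pandas
    # as timezone-naive values localized to the Python system timezone.
    pdf = df.toPandas()
    print(pdf)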

File tree

17 files changed: +417, -73 lines changed


python/pyspark/serializers.py

Lines changed: 20 additions & 4 deletions
@@ -214,6 +214,7 @@ def __repr__(self):
 
 
 def _create_batch(series):
+    from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
@@ -224,12 +225,25 @@ def _create_batch(series):
     # If a nullable integer series has been promoted to floating point with NaNs, need to cast
     # NOTE: this is not necessary with Arrow >= 0.7
     def cast_series(s, t):
-        if t is None or s.dtype == t.to_pandas_dtype():
+        if type(t) == pa.TimestampType:
+            # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
+            return _check_series_convert_timestamps_internal(s.fillna(0))\
+                .values.astype('datetime64[us]', copy=False)
+        elif t == pa.date32():
+            # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
+            return s.dt.date
+        elif t is None or s.dtype == t.to_pandas_dtype():
             return s
         else:
             return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
 
-    arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+    # Some object types don't support masks in Arrow, see ARROW-1721
+    def create_array(s, t):
+        casted = cast_series(s, t)
+        mask = None if casted.dtype == 'object' else s.isnull()
+        return pa.Array.from_pandas(casted, mask=mask, type=t)
+
+    arrs = [create_array(s, t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
@@ -260,11 +274,13 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
+        from pyspark.sql.types import _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
         for batch in reader:
-            table = pa.Table.from_batches([batch])
-            yield [c.to_pandas() for c in table.itercolumns()]
+            # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
+            pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
+            yield [c for _, c in pdf.iteritems()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"

python/pyspark/sql/dataframe.py

Lines changed: 6 additions & 1 deletion
@@ -1883,11 +1883,13 @@ def toPandas(self):
         import pandas as pd
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
+                from pyspark.sql.types import _check_dataframe_localize_timestamps
                 import pyarrow
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
-                    return table.to_pandas()
+                    pdf = table.to_pandas()
+                    return _check_dataframe_localize_timestamps(pdf)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
             except ImportError as e:
@@ -1955,6 +1957,7 @@ def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
     This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
+    NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
     """
     import numpy as np
     if type(dt) == ByteType:
@@ -1965,6 +1968,8 @@ def _to_corrected_pandas_type(dt):
         return np.int32
     elif type(dt) == FloatType:
         return np.float32
+    elif type(dt) == DateType:
+        return 'datetime64[ns]'
     else:
         return None
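
For the non-Arrow code path, a small sketch of what the new DateType correction addresses — illustrative only; the DataFrame and column name here are made up:

    from datetime import date
    import pandas as pd

    # Without the correction, a collected DateType column is inferred as 'object'
    pdf = pd.DataFrame.from_records([(date(2012, 2, 2),)], columns=["d"])
    print(pdf["d"].dtype)  # object

    # _to_corrected_pandas_type now maps DateType to 'datetime64[ns]'; toPandas() applies it
    # with astype(), and pd.to_datetime gives the equivalent result here
    pdf["d"] = pd.to_datetime(pdf["d"])
    print(pdf["d"].dtype)  # datetime64[ns]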

python/pyspark/sql/tests.py

Lines changed: 95 additions & 11 deletions
@@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase):
 
     @classmethod
     def setUpClass(cls):
+        from datetime import datetime
         ReusedPySparkTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
         cls.spark = SparkSession(cls.sc)
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         cls.schema = StructType([
             StructField("1_str_t", StringType(), True),
             StructField("2_int_t", IntegerType(), True),
             StructField("3_long_t", LongType(), True),
             StructField("4_float_t", FloatType(), True),
-            StructField("5_double_t", DoubleType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0),
-                    ("b", 2, 20, 0.4, 4.0),
-                    ("c", 3, 30, 0.8, 6.0)]
+            StructField("5_double_t", DoubleType(), True),
+            StructField("6_date_t", DateType(), True),
+            StructField("7_timestamp_t", TimestampType(), True)])
+        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
 
     def assertFramesEqual(self, df_with_arrow, df_without):
         msg = ("DataFrame from Arrow is not equal" +
@@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
     def test_unsupported_datatype(self):
-        schema = StructType([StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
+        schema = StructType([StructField("decimal", DecimalType(), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: df.toPandas())
 
@@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self):
 
     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
-        f = pandas_udf(lambda x: x, DateType())
+        schema = StructType([StructField("dt", DecimalType(), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
+        f = pandas_udf(lambda x: x, DecimalType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
                 df.select(f(col('dt'))).collect()
 
+    def test_vectorized_udf_null_date(self):
+        from pyspark.sql.functions import pandas_udf, col
+        from datetime import date
+        schema = StructType().add("date", DateType())
+        data = [(date(1969, 1, 1),),
+                (date(2012, 2, 2),),
+                (None,),
+                (date(2100, 4, 4),)]
+        df = self.spark.createDataFrame(data, schema=schema)
+        date_f = pandas_udf(lambda t: t, returnType=DateType())
+        res = df.select(date_f(col("date")))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_timestamps(self):
+        from pyspark.sql.functions import pandas_udf, col
+        from datetime import datetime
+        schema = StructType([
+            StructField("idx", LongType(), True),
+            StructField("timestamp", TimestampType(), True)])
+        data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
+                (1, datetime(2012, 2, 2, 2, 2, 2)),
+                (2, None),
+                (3, datetime(2100, 4, 4, 4, 4, 4))]
+        df = self.spark.createDataFrame(data, schema=schema)
+
+        # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
+        f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
+        df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
+
+        @pandas_udf(returnType=BooleanType())
+        def check_data(idx, timestamp, timestamp_copy):
+            is_equal = timestamp.isnull()  # use this array to check values are equal
+            for i in range(len(idx)):
+                # Check that timestamps are as expected in the UDF
+                is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
+                    timestamp[i].to_pydatetime() == data[idx[i]][1]
+            return is_equal
+
+        result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
+                                                      col("timestamp_copy"))).collect()
+        # Check that collection values are correct
+        self.assertEquals(len(data), len(result))
+        for i in range(len(result)):
+            self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected
+
+    def test_vectorized_udf_return_timestamp_tz(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+
+        @pandas_udf(returnType=TimestampType())
+        def gen_timestamps(id):
+            ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
+            return pd.Series(ts)
+
+        result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
+        spark_ts_t = TimestampType()
+        for r in result:
+            i, ts = r
+            ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
+            expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
+            self.assertEquals(expected, ts)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3550,8 +3634,8 @@ def test_wrong_args(self):
     def test_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
         schema = StructType(
-            [StructField("id", LongType(), True), StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
+            [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
+        df = self.spark.createDataFrame([(1, None,)], schema=schema)
         f = pandas_udf(lambda x: x, df.schema)
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):

python/pyspark/sql/types.py

Lines changed: 36 additions & 0 deletions
@@ -1619,11 +1619,47 @@ def to_arrow_type(dt):
         arrow_type = pa.decimal(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
+    elif type(dt) == DateType:
+        arrow_type = pa.date32()
+    elif type(dt) == TimestampType:
+        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
+        arrow_type = pa.timestamp('us', tz='UTC')
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
 
 
+def _check_dataframe_localize_timestamps(pdf):
+    """
+    Convert timezone aware timestamps to timezone-naive in local time
+
+    :param pdf: pandas.DataFrame
+    :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive
+    """
+    from pandas.api.types import is_datetime64tz_dtype
+    for column, series in pdf.iteritems():
+        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+        if is_datetime64tz_dtype(series.dtype):
+            pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
+    return pdf
+
+
+def _check_series_convert_timestamps_internal(s):
+    """
+    Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
+    :param s: a pandas.Series
+    :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
+    """
+    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64_dtype(s.dtype):
+        return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+    elif is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert('UTC')
+    else:
+        return s
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext
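
A rough sketch of the round trip the two new helpers implement — illustrative only; 'tzlocal()' is the dateutil-backed local-timezone spec that pandas accepts, as used in the code above:

    import pandas as pd

    # Toward Spark internal storage: tz-naive local timestamps -> UTC-normalized, tz-aware
    s = pd.Series(pd.to_datetime(["2012-02-02 02:02:02"]))
    utc = s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')

    # Toward the user: tz-aware (UTC) columns -> tz-naive values in local time
    back = utc.dt.tz_convert('tzlocal()').dt.tz_localize(None)

    assert (back == s).all()  # round trip preserves the wall-clock values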

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java

Lines changed: 34 additions & 0 deletions
@@ -320,6 +320,10 @@ public ArrowColumnVector(ValueVector vector) {
       accessor = new StringAccessor((NullableVarCharVector) vector);
     } else if (vector instanceof NullableVarBinaryVector) {
       accessor = new BinaryAccessor((NullableVarBinaryVector) vector);
+    } else if (vector instanceof NullableDateDayVector) {
+      accessor = new DateAccessor((NullableDateDayVector) vector);
+    } else if (vector instanceof NullableTimeStampMicroTZVector) {
+      accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector);
     } else if (vector instanceof ListVector) {
       ListVector listVector = (ListVector) vector;
       accessor = new ArrayAccessor(listVector);
@@ -575,6 +579,36 @@ final byte[] getBinary(int rowId) {
     }
   }
 
+  private static class DateAccessor extends ArrowVectorAccessor {
+
+    private final NullableDateDayVector.Accessor accessor;
+
+    DateAccessor(NullableDateDayVector vector) {
+      super(vector);
+      this.accessor = vector.getAccessor();
+    }
+
+    @Override
+    final int getInt(int rowId) {
+      return accessor.get(rowId);
+    }
+  }
+
+  private static class TimestampAccessor extends ArrowVectorAccessor {
+
+    private final NullableTimeStampMicroTZVector.Accessor accessor;
+
+    TimestampAccessor(NullableTimeStampMicroTZVector vector) {
+      super(vector);
+      this.accessor = vector.getAccessor();
+    }
+
+    @Override
+    final long getLong(int rowId) {
+      return accessor.get(rowId);
+    }
+  }
+
   private static class ArrayAccessor extends ArrowVectorAccessor {
 
     private final UInt4Vector.Accessor accessor;

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 1 deletion
@@ -3154,9 +3154,11 @@ class Dataset[T] private[sql](
   private[sql] def toArrowPayload: RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     queryExecution.toRdd.mapPartitionsInternal { iter =>
       val context = TaskContext.get()
-      ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
+      ArrowConverters.toPayloadIterator(
+        iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 2 additions & 1 deletion
@@ -74,9 +74,10 @@ private[sql] object ArrowConverters {
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Int,
+      timeZoneId: String,
      context: TaskContext): Iterator[ArrowPayload] = {
 
-    val arrowSchema = ArrowUtils.toArrowSchema(schema)
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator =
       ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
 