Skip to content

Commit f53b94c

Browse files
committed
Fix for SPARK-5722: infer long type in Python, similar to Java long
1 parent 53de237 commit f53b94c

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

python/pyspark/sql.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -605,6 +605,10 @@ def _infer_type(obj):
605605

606606
dataType = _type_mappings.get(type(obj))
607607
if dataType is not None:
608+
# Conform to Java int/long sizes SPARK-5722
609+
if dataType == IntegerType:
610+
if obj > 2**31 - 1 or obj < -2**31:
611+
dataType = LongType
608612
return dataType()
609613

610614
if isinstance(obj, dict):

python/pyspark/tests.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@
5151
CloudPickleSerializer, CompressedSerializer
5252
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
5353
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
54-
UserDefinedType, DoubleType
54+
UserDefinedType, DoubleType, LongType, _infer_type
5555
from pyspark import shuffle
5656

5757
_have_scipy = False
@@ -923,6 +923,20 @@ def test_infer_schema(self):
923923
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
924924
self.assertEqual(1, result.first()[0])
925925

926+
def test_infer_long_type(self):
927+
longrow = [Row(f1='a', f2=100000000000000)]
928+
lrdd = self.sc.parallelize(longrow)
929+
slrdd = self.sqlCtx.inferSchema(lrdd)
930+
self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
931+
932+
self.assertEqual(_infer_type(1), IntegerType())
933+
self.assertEqual(_infer_type(2**10), IntegerType())
934+
self.assertEqual(_infer_type(2**20), IntegerType())
935+
self.assertEqual(_infer_type(2**31 - 1), IntegerType())
936+
self.assertEqual(_infer_type(2**31), LongType())
937+
self.assertEqual(_infer_type(2**61), LongType())
938+
self.assertEqual(_infer_type(2**71), LongType())
939+
926940
def test_convert_row_to_dict(self):
927941
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
928942
self.assertEqual(1, row.asDict()['l'][0].a)

0 commit comments

Comments
 (0)