Commit 51aa135

use Row to infer schema
1 parent e9c0d5c commit 51aa135

File tree

1 file changed (+125, -31 lines)

python/pyspark/sql.py

Lines changed: 125 additions & 31 deletions
@@ -18,14 +18,14 @@

 import sys
 import types
-import array
 import itertools
 import warnings
 import decimal
 import datetime
-from operator import itemgetter
 import keyword
 import warnings
+from array import array
+from operator import itemgetter

 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer
@@ -441,7 +441,7 @@ def _infer_type(obj):
             raise ValueError("Can not infer type for empty dict")
         key, value = obj.iteritems().next()
         return MapType(_infer_type(key), _infer_type(value), True)
-    elif isinstance(obj, (list, array.array)):
+    elif isinstance(obj, (list, array)):
         if not obj:
             raise ValueError("Can not infer type for empty list/array")
         return ArrayType(_infer_type(obj[0]), True)
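The `array.array` to `array` change goes with the new `from array import array` import above; inference still recurses on the first element of a list or array. A minimal standalone sketch of that recursion, with a hypothetical `infer_type` and plain type tags standing in for pyspark's `_infer_type` and DataType classes:

    from array import array
    from datetime import datetime

    # Hypothetical stand-in for pyspark.sql._infer_type: map a Python value
    # to a type tag, recursing into containers the same way the diff does.
    def infer_type(obj):
        simple = {int: "int", float: "double", str: "string",
                  bool: "boolean", datetime: "timestamp"}
        if type(obj) in simple:
            return simple[type(obj)]
        if isinstance(obj, dict):
            key, value = next(iter(obj.items()))
            return ("map", infer_type(key), infer_type(value))
        if isinstance(obj, (list, array)):  # the branch changed above
            return ("array", infer_type(obj[0]))
        raise ValueError("Can not infer type for %r" % (obj,))

    assert infer_type(array('i', [1, 2])) == ("array", "int")
    assert infer_type({"a": 1.0}) == ("map", "string", "double")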
@@ -456,14 +456,20 @@ def _infer_schema(row):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
+
     elif isinstance(row, tuple):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in row):
+        elif hasattr(row, "__FIELDS__"):  # Row
+            items = zip(row.__FIELDS__, tuple(row))
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
             items = row
+        else:
+            raise ValueError("Can't infer schema from tuple")
+
     elif hasattr(row, "__dict__"):  # object
         items = sorted(row.__dict__.items())
+
     else:
         raise ValueError("Can not infer schema for type: %s" % type(row))

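The tuple branch now distinguishes namedtuples, Rows, and pre-paired (key, value) tuples. A hypothetical, simplified mirror of that dispatch (the real `_infer_schema` then builds a StructType from these pairs; `fields_of` is an illustrative name, not pyspark API):

    from collections import namedtuple

    # Simplified version of the tuple dispatch in _infer_schema:
    # reduce one row to (name, value) pairs before types are inferred.
    def fields_of(row):
        if isinstance(row, dict):
            return sorted(row.items())
        if isinstance(row, tuple):
            if hasattr(row, "_fields"):     # namedtuple
                return zip(row._fields, tuple(row))
            if hasattr(row, "__FIELDS__"):  # Row
                return zip(row.__FIELDS__, tuple(row))
            if all(isinstance(x, tuple) and len(x) == 2 for x in row):
                return row                  # already (key, value) pairs
            raise ValueError("Can't infer schema from tuple")
        if hasattr(row, "__dict__"):        # plain object
            return sorted(row.__dict__.items())
        raise ValueError("Can not infer schema for type: %s" % type(row))

    Person = namedtuple("Person", ["name", "age"])
    assert list(fields_of(Person("Alice", 11))) == [("name", "Alice"), ("age", 11)]
    assert list(fields_of((("a", 1), ("b", 2)))) == [("a", 1), ("b", 2)]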
@@ -494,9 +500,12 @@ def _create_converter(obj, dataType):
     elif isinstance(obj, tuple):
         if hasattr(obj, "_fields"):  # namedtuple
             conv = tuple
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in obj):
+        elif hasattr(obj, "__FIELDS__"):  # Row
+            conv = tuple
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
             conv = lambda o: tuple(v for k, v in o)
+        else:
+            raise ValueError("unexpected tuple")

     elif hasattr(obj, "__dict__"):  # object
         conv = lambda o: [o.__dict__.get(n, None) for n in names]
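For illustration, the two tuple converters reduce to the following hedged, simplified sketch (field names and error handling omitted):

    # namedtuple/Row values are already in field order, so plain tuple()
    # suffices; (key, value) pairs must drop their keys.
    conv_named = tuple
    conv_pairs = lambda o: tuple(v for k, v in o)

    assert conv_named((1, "row1")) == (1, "row1")
    assert conv_pairs((("field1", 1), ("field2", "row1"))) == (1, "row1")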
@@ -783,6 +792,7 @@ class Row(tuple):
         """ Row in SchemaRDD """
         __DATATYPE__ = dataType
         __FIELDS__ = tuple(f.name for f in dataType.fields)
+        __slots__ = ()

         # create property for fast access
         locals().update(_create_properties(dataType.fields))
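`__slots__ = ()` keeps instances of the generated class as bare tuples: no per-instance `__dict__`, so rows stay compact and cannot grow stray attributes. A minimal sketch of the effect (illustrative class names, not pyspark code):

    class SlottedRow(tuple):
        __slots__ = ()  # no per-instance __dict__

    class PlainRow(tuple):
        pass

    r = PlainRow((1, 2))
    r.extra = "allowed"            # plain tuple subclasses carry a dict
    try:
        SlottedRow((1, 2)).extra = "x"
    except AttributeError:
        pass                       # slotted instances reject new attributes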
@@ -814,7 +824,7 @@ def __init__(self, sparkContext, sqlContext=None):
         >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
-        ValueError:...
+        TypeError:...

         >>> bad_rdd = sc.parallelize([1,2,3])
         >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
@@ -823,9 +833,9 @@ def __init__(self, sparkContext, sqlContext=None):
         ValueError:...

         >>> from datetime import datetime
-        >>> allTypes = sc.parallelize([{"int": 1, "string": "string",
-        ...     "double": 1.0, "long": 1L, "boolean": True, "list": [1, 2, 3],
-        ...     "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},}])
+        >>> allTypes = sc.parallelize([Row(int=1, string="string",
+        ...     double=1.0, long=1L, boolean=True, list=[1, 2, 3],
+        ...     time=datetime(2010, 1, 1, 1, 1, 1), dict={"a": 1})])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string,
         ...     x.double, x.long, x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
@@ -851,33 +861,48 @@ def _ssql_ctx(self):
         return self._scala_SQLContext

     def inferSchema(self, rdd):
-        """Infer and apply a schema to an RDD of L{dict}s.
+        """Infer and apply a schema to an RDD of L{Row}s.
+
+        We peek at the first row of the RDD to determine the fields' names
+        and types. Nested collections are supported, including array,
+        dict, list, Row, tuple, namedtuple, and object.

-        We peek at the first row of the RDD to determine the fields names
-        and types, and then use that to extract all the dictionaries. Nested
-        collections are supported, which include array, dict, list, set, and
-        tuple.
+        Each row in `rdd` should be a Row object, a namedtuple, or a
+        plain object; using dict is deprecated.

+        >>> rdd = sc.parallelize(
+        ...     [Row(field1=1, field2="row1"),
+        ...      Row(field1=2, field2="row2"),
+        ...      Row(field1=3, field2="row3")])
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect()[0]
         Row(field1=1, field2=u'row1')

-        >>> from array import array
+        >>> NestedRow = Row("f1", "f2")
+        >>> nestedRdd1 = sc.parallelize([
+        ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
+        ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
         >>> srdd.collect()
         [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]

+        >>> nestedRdd2 = sc.parallelize([
+        ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
+        ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
         """
-        if (rdd.__class__ is SchemaRDD):
-            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+
+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")

         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated")

         schema = _infer_schema(first)
         rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
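A short usage sketch of the new guard and the deprecation path (assumes `sc` and `sqlCtx` bound as in the doctests):

    # preferred after this commit: rows as Row objects
    srdd = sqlCtx.inferSchema(sc.parallelize([Row(field1=1, field2="row1")]))

    # still accepted, but now warns:
    # "Using RDD of dict to inferSchema is deprecated"
    old = sqlCtx.inferSchema(sc.parallelize([{"field1": 1, "field2": "row1"}]))

    # re-inferring on a SchemaRDD now raises TypeError instead of ValueError
    try:
        sqlCtx.inferSchema(srdd)
    except TypeError:
        pass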
@@ -889,6 +914,7 @@ def applySchema(self, rdd, schema):

         The schema should be a StructType.

+        >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
         >>> srdd = sqlCtx.applySchema(rdd2, schema)
@@ -929,6 +955,9 @@ def applySchema(self, rdd, schema):
         [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
         """

+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")
+
         first = rdd.first()
         if not isinstance(first, (tuple, list)):
             raise ValueError("Can not apply schema to type: %s" % type(first))
@@ -1198,12 +1227,84 @@ def _get_hive_ctx(self):
         return self._jvm.TestHiveContext(self._jsc.sc())


-# a stub type, the real type is dynamic generated.
+def _create_row(fields, values):
+    row = Row(*values)
+    row.__FIELDS__ = fields
+    return row
+
+
 class Row(tuple):
     """
     A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+
+    Row can be used to create a row object by using named arguments;
+    the fields will be sorted by name.
+
+    >>> row = Row(name="Alice", age=11)
+    >>> row
+    Row(age=11, name='Alice')
+    >>> row.name, row.age
+    ('Alice', 11)
+
+    Row can also be used to create another Row-like class, which can
+    then be used to create Row objects, such as
+
+    >>> Person = Row("name", "age")
+    >>> Person
+    <Row(name, age)>
+    >>> Person("Alice", 11)
+    Row(name='Alice', age=11)
     """

+    def __new__(self, *args, **kwargs):
+        if args and kwargs:
+            raise ValueError("Can not use both args "
+                             "and kwargs to create Row")
+        if args:
+            # create row class or objects
+            return tuple.__new__(self, args)
+
+        elif kwargs:
+            # create row objects
+            names = sorted(kwargs.keys())
+            values = tuple(kwargs[n] for n in names)
+            row = tuple.__new__(self, values)
+            row.__FIELDS__ = names
+            return row
+
+        else:
+            raise ValueError("No args or kwargs")
+
+    # let a Row object act like a class
+    def __call__(self, *args):
+        """create new Row object"""
+        return _create_row(self, args)
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        try:
+            # it will be slow when it has many fields,
+            # but this will not be used in normal cases
+            idx = self.__FIELDS__.index(item)
+            return self[idx]
+        except ValueError:
+            raise AttributeError(item)
+
+    def __reduce__(self):
+        if hasattr(self, "__FIELDS__"):
+            return (_create_row, (self.__FIELDS__, tuple(self)))
+        else:
+            return tuple.__reduce__(self)
+
+    def __repr__(self):
+        if hasattr(self, "__FIELDS__"):
+            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+                                         for k, v in zip(self.__FIELDS__, self))
+        else:
+            return "<Row(%s)>" % ", ".join(self)


 class SchemaRDD(RDD):
     """An RDD of L{Row} objects that has an associated schema.
@@ -1424,19 +1525,18 @@ def _test():
     from pyspark.context import SparkContext
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
-    from pyspark.sql import SQLContext
+    from pyspark.sql import Row, SQLContext
     globs = pyspark.sql.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize(
-        [{"field1": 1, "field2": "row1"},
-         {"field1": 2, "field2": "row2"},
-         {"field1": 3, "field2": "row3"}]
+        [Row(field1=1, field2="row1"),
+         Row(field1=2, field2="row2"),
+         Row(field1=3, field2="row3")]
     )
-    globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
@@ -1446,12 +1546,6 @@ def _test():
     ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
-    globs['nestedRdd1'] = sc.parallelize([
-        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
-        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
-    globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
-        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
