 import sys
 import types
-import array
 import itertools
 import warnings
 import decimal
 import datetime
-from operator import itemgetter
 import keyword
 import warnings
+from array import array
+from operator import itemgetter

 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer
@@ -441,7 +441,7 @@ def _infer_type(obj):
             raise ValueError("Can not infer type for empty dict")
         key, value = obj.iteritems().next()
         return MapType(_infer_type(key), _infer_type(value), True)
-    elif isinstance(obj, (list, array.array)):
+    elif isinstance(obj, (list, array)):
         if not obj:
             raise ValueError("Can not infer type for empty list/array")
         return ArrayType(_infer_type(obj[0]), True)
@@ -456,14 +456,20 @@ def _infer_schema(row):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
+
     elif isinstance(row, tuple):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in row):
+        elif hasattr(row, "__FIELDS__"):  # Row
+            items = zip(row.__FIELDS__, tuple(row))
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
             items = row
+        else:
+            raise ValueError("Can't infer schema from tuple")
+
     elif hasattr(row, "__dict__"):  # object
         items = sorted(row.__dict__.items())
+
     else:
         raise ValueError("Can not infer schema for type: %s" % type(row))

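For orientation, a minimal sketch (not part of the patch) of which record shapes each branch of _infer_schema above accepts; Point and records are illustrative names, and Row is assumed to be importable from pyspark.sql:

from collections import namedtuple
from pyspark.sql import Row

Point = namedtuple("Point", ["x", "y"])

records = [
    {"x": 1, "y": 2},        # dict: items sorted by key
    Point(1, 2),             # namedtuple: names come from _fields
    Row(x=1, y=2),           # Row: names come from __FIELDS__ (sorted kwargs)
    (("x", 1), ("y", 2)),    # bare tuple of (name, value) pairs
]
# Any other tuple now fails fast with ValueError("Can't infer schema from tuple").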
@@ -494,9 +500,12 @@ def _create_converter(obj, dataType):
     elif isinstance(obj, tuple):
         if hasattr(obj, "_fields"):  # namedtuple
             conv = tuple
-        elif all(isinstance(x, tuple) and len(x) == 2
-                 for x in obj):
+        elif hasattr(obj, "__FIELDS__"):
+            conv = tuple
+        elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
             conv = lambda o: tuple(v for k, v in o)
+        else:
+            raise ValueError("unexpected tuple")

     elif hasattr(obj, "__dict__"):  # object
         conv = lambda o: [o.__dict__.get(n, None) for n in names]
@@ -783,6 +792,7 @@ class Row(tuple):
         """ Row in SchemaRDD """
         __DATATYPE__ = dataType
         __FIELDS__ = tuple(f.name for f in dataType.fields)
+        __slots__ = ()

         # create property for fast access
         locals().update(_create_properties(dataType.fields))
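The new __slots__ = () keeps instances of the generated Row class from allocating a per-instance __dict__, so each row stays as compact as a plain tuple; a standalone illustration (not pyspark code):

class Plain(tuple):
    pass

class Slotted(tuple):
    __slots__ = ()

p = Plain((1, 2))
p.extra = "ok"        # Plain instances carry a __dict__, so this is allowed

s = Slotted((1, 2))
# s.extra = "ok"      # would raise AttributeError: no __dict__ is allocated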
@@ -814,7 +824,7 @@ def __init__(self, sparkContext, sqlContext=None):
         >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
             ...
-        ValueError:...
+        TypeError:...

         >>> bad_rdd = sc.parallelize([1,2,3])
         >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
@@ -823,9 +833,9 @@ def __init__(self, sparkContext, sqlContext=None):
         ValueError:...

         >>> from datetime import datetime
-        >>> allTypes = sc.parallelize([{"int": 1, "string": "string",
-        ... "double": 1.0, "long": 1L, "boolean": True, "list": [1, 2, 3],
-        ... "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},}])
+        >>> allTypes = sc.parallelize([Row(int=1, string="string",
+        ... double=1.0, long=1L, boolean=True, list=[1, 2, 3],
+        ... time=datetime(2010, 1, 1, 1, 1, 1), dict={"a": 1})])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string,
         ... x.double, x.long, x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
@@ -851,33 +861,48 @@ def _ssql_ctx(self):
         return self._scala_SQLContext

     def inferSchema(self, rdd):
-        """Infer and apply a schema to an RDD of L{dict}s.
+        """Infer and apply a schema to an RDD of L{Row}s.
+
+        We peek at the first row of the RDD to determine the fields' names
+        and types. Nested collections are supported, including array, dict,
+        list, Row, tuple, namedtuple, and object.

-        We peek at the first row of the RDD to determine the fields names
-        and types, and then use that to extract all the dictionaries. Nested
-        collections are supported, which include array, dict, list, set, and
-        tuple.
+        Each row in `rdd` should be a Row object, a namedtuple, or a plain
+        object; using a dict is deprecated.

+        >>> rdd = sc.parallelize(
+        ...     [Row(field1=1, field2="row1"),
+        ...      Row(field1=2, field2="row2"),
+        ...      Row(field1=3, field2="row3")])
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect()[0]
         Row(field1=1, field2=u'row1')

-        >>> from array import array
+        >>> NestedRow = Row("f1", "f2")
+        >>> nestedRdd1 = sc.parallelize([
+        ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
+        ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
         >>> srdd.collect()
         [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]

+        >>> nestedRdd2 = sc.parallelize([
+        ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
+        ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
         """
-        if (rdd.__class__ is SchemaRDD):
-            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
+
+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")

         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated")

         schema = _infer_schema(first)
         rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
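The doctests above exercise Row and nested collections but not the plain-object branch of _infer_schema; a hedged sketch of that case, reusing the sc/sqlCtx doctest fixtures with an illustrative Person class:

class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age

people = sc.parallelize([Person("Alice", 1), Person("Bob", 2)])
srdd = sqlCtx.inferSchema(people)
# Field names come from __dict__ and are sorted, so rows should come back as
# Row(age=1, name=u'Alice'), Row(age=2, name=u'Bob')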
@@ -889,6 +914,7 @@ def applySchema(self, rdd, schema):

         The schema should be a StructType.

+        >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
         >>> srdd = sqlCtx.applySchema(rdd2, schema)
@@ -929,6 +955,9 @@ def applySchema(self, rdd, schema):
         [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
         """

+        if isinstance(rdd, SchemaRDD):
+            raise TypeError("Cannot apply schema to SchemaRDD")
+
         first = rdd.first()
         if not isinstance(first, (tuple, list)):
             raise ValueError("Can not apply schema to type: %s" % type(first))
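With the new guard, applySchema (like inferSchema above) rejects an RDD that already carries a schema up front; a small sketch, assuming the rdd and schema values defined in the surrounding doctests:

srdd = sqlCtx.inferSchema(rdd)
try:
    sqlCtx.applySchema(srdd, schema)    # srdd is already a SchemaRDD
except TypeError:
    pass                                # raised by the isinstance check above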
@@ -1198,12 +1227,84 @@ def _get_hive_ctx(self):
         return self._jvm.TestHiveContext(self._jsc.sc())


-# a stub type, the real type is dynamic generated.
+def _create_row(fields, values):
+    row = Row(*values)
+    row.__FIELDS__ = fields
+    return row
+
+
 class Row(tuple):
     """
     A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+
+    Row can be used to create a row object by using named arguments;
+    the fields will be sorted by name.
+
+    >>> row = Row(name="Alice", age=11)
+    >>> row
+    Row(age=11, name='Alice')
+    >>> row.name, row.age
+    ('Alice', 11)
+
+    Row can also be used to create another Row-like class, which can then
+    be used to create Row objects, such as
+
+    >>> Person = Row("name", "age")
+    >>> Person
+    <Row(name, age)>
+    >>> Person("Alice", 11)
+    Row(name='Alice', age=11)
     """

+    def __new__(self, *args, **kwargs):
+        if args and kwargs:
+            raise ValueError("Can not use both args "
+                             "and kwargs to create Row")
+        if args:
+            # create row class or objects
+            return tuple.__new__(self, args)
+
+        elif kwargs:
+            # create row objects
+            names = sorted(kwargs.keys())
+            values = tuple(kwargs[n] for n in names)
+            row = tuple.__new__(self, values)
+            row.__FIELDS__ = names
+            return row
+
+        else:
+            raise ValueError("No args or kwargs")
+
+    # let a Row created with positional args act like a class
+    def __call__(self, *args):
+        """create new Row object"""
+        return _create_row(self, args)
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        try:
+            # it will be slow when it has many fields,
+            # but this will not be used in normal cases
+            idx = self.__FIELDS__.index(item)
+            return self[idx]
+        except ValueError:
+            # list.index() raises ValueError for an unknown field name
+            raise AttributeError(item)
+
+    def __reduce__(self):
+        if hasattr(self, "__FIELDS__"):
+            return (_create_row, (self.__FIELDS__, tuple(self)))
+        else:
+            return tuple.__reduce__(self)
+
+    def __repr__(self):
+        if hasattr(self, "__FIELDS__"):
+            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
+                                         for k, v in zip(self.__FIELDS__, self))
+        else:
+            return "<Row(%s)>" % ", ".join(self)
+

 class SchemaRDD(RDD):
     """An RDD of L{Row} objects that has an associated schema.
@@ -1424,19 +1525,18 @@ def _test():
     from pyspark.context import SparkContext
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
-    from pyspark.sql import SQLContext
+    from pyspark.sql import Row, SQLContext
     globs = pyspark.sql.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize(
-        [{"field1": 1, "field2": "row1"},
-         {"field1": 2, "field2": "row2"},
-         {"field1": 3, "field2": "row3"}]
+        [Row(field1=1, field2="row1"),
+         Row(field1=2, field2="row2"),
+         Row(field1=3, field2="row3")]
     )
-    globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
@@ -1446,12 +1546,6 @@ def _test():
     ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
-    globs['nestedRdd1'] = sc.parallelize([
-        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
-        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
-    globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
-        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()