[SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list

Fix createDataFrame() from pandas DataFrame (not tested by Jenkins; depends on SPARK-5693).

It also supports creating a DataFrame from plain tuples/lists without column names; `_1`, `_2`, ... will be used as the column names.
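For illustration, the resulting usage looks like this (a minimal sketch assuming a working SQLContext named `sqlCtx`, as in the doctests in the diff below):

rows = [('Alice', 1), ('Bob', 2)]

# Without column names, the columns default to _1, _2, ...
sqlCtx.createDataFrame(rows).collect()
# [Row(_1=u'Alice', _2=1), Row(_1=u'Bob', _2=2)]

# Explicit column names can still be passed as a list of strings.
sqlCtx.createDataFrame(rows, ['name', 'age']).collect()
# [Row(name=u'Alice', age=1), Row(name=u'Bob', age=2)]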

Author: Davies Liu <davies@databricks.com>

Closes apache#4679 from davies/pandas and squashes the following commits:

c0cbe0b [Davies Liu] fix tests
8466d1d [Davies Liu] fix create DataFrame from pandas
Davies Liu authored and marmbrus committed Feb 20, 2015
1 parent 4a17eed commit 5b0a42c
Showing 3 changed files with 20 additions and 20 deletions.
python/pyspark/sql/context.py (10 additions, 2 deletions)
@@ -351,6 +351,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
:return: a DataFrame
>>> l = [('Alice', 1)]
>>> sqlCtx.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
>>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
[Row(name=u'Alice', age=1)]
@@ -359,6 +361,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
[Row(age=1, name=u'Alice')]
>>> rdd = sc.parallelize(l)
>>> sqlCtx.createDataFrame(rdd).collect()
[Row(_1=u'Alice', _2=1)]
>>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name=u'Alice', age=1)]
@@ -377,14 +381,17 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
>>> df3 = sqlCtx.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name=u'Alice', age=1)]
>>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
[Row(name=u'Alice', age=1)]
"""
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")

if has_pandas and isinstance(data, pandas.DataFrame):
data = self._sc.parallelize(data.to_records(index=False))
if schema is None:
schema = list(data.columns)
data = [r.tolist() for r in data.to_records(index=False)]

if not isinstance(data, RDD):
try:
@@ -399,7 +406,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
if isinstance(schema, (list, tuple)):
first = data.first()
if not isinstance(first, (list, tuple)):
raise ValueError("each row in `rdd` should be list or tuple")
raise ValueError("each row in `rdd` should be list or tuple, "
"but got %r" % type(first))
row_cls = Row(*schema)
schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)

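The context.py change above stops parallelizing the raw numpy records; instead, the pandas column names become the default schema and each record is turned into plain Python values before the generic tuple/list path takes over. A pandas-only sketch of that conversion (illustrative, not the committed code verbatim):

import pandas

pdf = pandas.DataFrame([('Alice', 1)], columns=['name', 'age'])
schema = list(pdf.columns)                                  # ['name', 'age']
data = [r.tolist() for r in pdf.to_records(index=False)]    # e.g. [('Alice', 1)] with native Python values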
python/pyspark/sql/tests.py (1 addition, 1 deletion)
@@ -186,7 +186,7 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual("2", row.d)

def test_infer_schema(self):
d = [Row(l=[], d={}),
d = [Row(l=[], d={}, s=None),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
df = self.sqlCtx.createDataFrame(rdd)
python/pyspark/sql/types.py (9 additions, 17 deletions)
@@ -604,7 +604,7 @@ def _infer_type(obj):
ExamplePointUDT
"""
if obj is None:
raise ValueError("Can not infer type for None")
return NullType()

if hasattr(obj, '__UDT__'):
return obj.__UDT__
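With the change to _infer_type above, a None value now infers to NullType instead of raising an error. A quick check (a sketch only; _infer_type is an internal helper, assumed importable from pyspark.sql.types as in this commit):

from pyspark.sql.types import NullType, _infer_type

assert isinstance(_infer_type(None), NullType)  # before this commit, this call raised ValueError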
@@ -637,15 +637,14 @@ def _infer_schema(row):
if isinstance(row, dict):
items = sorted(row.items())

elif isinstance(row, tuple):
elif isinstance(row, (tuple, list)):
if hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
elif hasattr(row, "__FIELDS__"): # Row
items = zip(row.__FIELDS__, tuple(row))
elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
items = row
else:
raise ValueError("Can't infer schema from tuple")
names = ['_%d' % i for i in range(1, len(row) + 1)]
items = zip(names, row)

elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
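The fallback branch added to _infer_schema above is what gives plain tuples and lists their default `_1`, `_2`, ... field names. The naming rule in isolation (a standalone sketch, no Spark needed):

row = ('Alice', 1)                                    # a plain tuple with no field names
names = ['_%d' % i for i in range(1, len(row) + 1)]   # ['_1', '_2']
items = list(zip(names, row))                         # [('_1', 'Alice'), ('_2', 1)]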
@@ -812,17 +811,10 @@ def convert_struct(obj):
if obj is None:
return

if isinstance(obj, tuple):
if hasattr(obj, "_fields"):
d = dict(zip(obj._fields, obj))
elif hasattr(obj, "__FIELDS__"):
d = dict(zip(obj.__FIELDS__, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
d = dict(obj)
else:
raise ValueError("unexpected tuple: %s" % str(obj))
if isinstance(obj, (tuple, list)):
return tuple(conv(v) for v, conv in zip(obj, converters))

elif isinstance(obj, dict):
if isinstance(obj, dict):
d = obj
elif hasattr(obj, "__dict__"): # object
d = obj.__dict__
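With the simplified convert_struct above, a tuple or list is converted positionally by zipping its values with the per-field converters. A standalone sketch with stand-in converters (the `converters` list here is hypothetical; the real one is built from the struct's fields):

converters = [str, int]                                           # stand-ins for per-field converters
obj = ('Alice', '1')
converted = tuple(conv(v) for v, conv in zip(obj, converters))    # ('Alice', 1)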
@@ -1022,7 +1014,7 @@ def _verify_type(obj, dataType):
return

_type = type(dataType)
assert _type in _acceptable_types, "unkown datatype: %s" % dataType
assert _type in _acceptable_types, "unknown datatype: %s" % dataType

# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
@@ -1040,7 +1032,7 @@ def _verify_type(obj, dataType):

elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with"
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
