[SPARK-5859] [PySpark] [SQL] fix DataFrame Python API
1. added explain()
2. add isLocal()
3. do not call show() in __repr__
4. add foreach() and foreachPartition()
5. add distinct()
6. fix functions.col()/column()/lit()
7. fix unit tests in sql/functions.py
8. fix unicode in showString()

Author: Davies Liu <davies@databricks.com>

Closes #4645 from davies/df6 and squashes the following commits:

6b46a2c [Davies Liu] fix DataFrame Python API
Davies Liu authored and marmbrus committed Feb 17, 2015
1 parent c74b07f commit d8adefe
Showing 2 changed files with 59 additions and 18 deletions.
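
A rough usage sketch of the DataFrame methods introduced by this patch (explain, isLocal, distinct), built on the same two-row fixture the doctests below use. The app name and variable names (sqlCtx, df) are illustrative, not part of the patch itself.

from pyspark import SparkContext
from pyspark.sql import Row, SQLContext

sc = SparkContext('local[4]', 'df6-demo')        # illustrative app name
sqlCtx = SQLContext(sc)
df = sc.parallelize([Row(name='Alice', age=2),
                     Row(name='Bob', age=5)]).toDF()

df.explain()                  # physical plan only (extended defaults to False)
df.explain(extended=True)     # logical and physical plans

print(df.isLocal())           # expected False for an RDD-backed DataFrame
print(df.distinct().count())  # 2 -- duplicate rows would be dropped first

sc.stop()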
65 changes: 54 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -238,6 +238,22 @@ def printSchema(self):
"""
print (self._jdf.schema().treeString())

def explain(self, extended=False):
"""
Prints the plans (logical and physical) to the console for
debugging purposes.
If extended is False, only prints the physical plan.
"""
self._jdf.explain(extended)

def isLocal(self):
"""
Returns True if the `collect` and `take` methods can be run locally
(without any Spark executors).
"""
return self._jdf.isLocal()

def show(self):
"""
Print the first 20 rows.
@@ -247,14 +263,12 @@ def show(self):
2 Alice
5 Bob
>>> df
age name
2 Alice
5 Bob
DataFrame[age: int, name: string]
"""
print (self)
print self._jdf.showString().encode('utf8', 'ignore')

def __repr__(self):
return self._jdf.showString()
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

def count(self):
"""Return the number of elements in this RDD.
@@ -336,13 +350,40 @@ def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.
It's a shorthand for df.rdd.mapPartitions()
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
4
"""
return self.rdd.mapPartitions(f, preservesPartitioning)

def foreach(self, f):
"""
Applies a function to all rows of this DataFrame.
It's a shorthand for df.rdd.foreach()
>>> def f(person):
... print person.name
>>> df.foreach(f)
"""
return self.rdd.foreach(f)

def foreachPartition(self, f):
"""
Applies a function to each partition of this DataFrame.
It's a shorthand for df.rdd.foreachPartition()
>>> def f(people):
... for person in people:
... print person.name
>>> df.foreachPartition(f)
"""
return self.rdd.foreachPartition(f)

def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
@@ -377,8 +418,13 @@ def repartition(self, numPartitions):
""" Return a new :class:`DataFrame` that has exactly `numPartitions`
partitions.
"""
rdd = self._jdf.repartition(numPartitions, None)
return DataFrame(rdd, self.sql_ctx)
return DataFrame(self._jdf.repartition(numPartitions, None), self.sql_ctx)

def distinct(self):
"""
Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
"""
return DataFrame(self._jdf.distinct(), self.sql_ctx)

def sample(self, withReplacement, fraction, seed=None):
"""
@@ -957,10 +1003,7 @@ def cast(self, dataType):
return Column(jc, self.sql_ctx)

def __repr__(self):
if self._jdf.isComputable():
return self._jdf.samples()
else:
return 'Column<%s>' % self._jdf.toString()
return 'Column<%s>' % self._jdf.toString().encode('utf8')

def toPandas(self):
"""
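The net effect of the dataframe.py changes above, sketched against the same illustrative fixture; the repr string in the comment follows the doctest output above and is indicative, not captured from a run.

# df is the two-row fixture from the sketch above.

print(repr(df))   # 'DataFrame[age: int, name: string]' -- schema summary only;
                  # echoing df at the REPL no longer triggers showString()

df.show()         # actually prints the rows; showString() output is encoded
                  # to UTF-8 first, so non-ASCII values no longer raise

def print_name(person):
    print(person.name)

df.foreach(print_name)                # shorthand for df.rdd.foreach(...)

def print_partition(people):
    for person in people:
        print(person.name)

df.foreachPartition(print_partition)  # shorthand for df.rdd.foreachPartition(...)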
12 changes: 5 additions & 7 deletions python/pyspark/sql/functions.py
@@ -37,7 +37,7 @@ def _create_function(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
@@ -140,6 +140,7 @@ def __call__(self, *cols):
def udf(f, returnType=StringType()):
"""Create a user defined function (UDF)
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
@@ -151,17 +152,14 @@ def _test():
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.dataframe
globs = pyspark.sql.dataframe.__dict__.copy()
import pyspark.sql.functions
globs = pyspark.sql.functions.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
pyspark.sql.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
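And a short sketch of the functions.py side, again against the same illustrative df: col() and lit() now go through the fixed _create_function path whether they are handed a name, a literal, or an existing Column, and udf() behaves as in the doctest added above. The alias names here are made up.

from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import IntegerType

# col() takes a column name, lit() a literal; both reach the JVM helpers now
df.select(col('name'), lit(1).alias('one')).show()

# same UDF as the doctest above
slen = udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()   # [Row(slen=5), Row(slen=3)]

With _test() now importing pyspark.sql.functions instead of pyspark.sql.dataframe, running the module as a script exercises exactly these doctests.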
