[SPARK-8060] Improve DataFrame Python test coverage and documentation. #6601

Closed · wants to merge 3 commits
1 change: 1 addition & 0 deletions .rat-excludes
@@ -82,3 +82,4 @@ local-1426633911242/*
local-1430917381534/*
DESCRIPTION
NAMESPACE
test_support/*
13 changes: 12 additions & 1 deletion python/pyspark/sql/__init__.py
@@ -45,11 +45,20 @@


def since(version):
    """
    A decorator that annotates a function to append the version of Spark the function was added.
    """
    import re
    indent_p = re.compile(r'\n( +)')

    def deco(f):
        f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version
        indents = indent_p.findall(f.__doc__)
        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
        f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f
    return deco
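In deco above, the first f.__doc__ assignment (the bare append) is this hunk's single deleted line; the three lines that follow it are the additions: they find the docstring's minimum indentation and append the .. versionadded:: directive at that indent, so it lines up with the docstring body in the generated docs. A minimal sketch of the new behavior (not part of the PR; the function is made up):

    @since(1.4)
    def greet(name):
        """Returns a greeting string.

        :param name: who to greet
        """
        return "hello, %s" % name

    print(greet.__doc__)
    # Returns a greeting string.
    #
    #     :param name: who to greet
    #
    #     .. versionadded:: 1.4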


from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.column import Column
@@ -58,7 +67,9 @@ def deco(f):
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec


__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
'DataFrameReader', 'DataFrameWriter'
]
89 changes: 37 additions & 52 deletions python/pyspark/sql/context.py
@@ -124,7 +124,10 @@ def getConf(self, key, defaultValue):
@property
@since("1.3.1")
def udf(self):
"""Returns a :class:`UDFRegistration` for UDF registration."""
"""Returns a :class:`UDFRegistration` for UDF registration.

:return: :class:`UDFRegistration`
"""
return UDFRegistration(self)
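For context, a hedged sketch of how a UDF is registered and used through this property (the register() signature appears at the bottom of this diff; the UDF name 'slen' is made up):

    from pyspark.sql.types import IntegerType

    # Register a Python function as a SQL UDF, then call it from SQL.
    sqlContext.udf.register("slen", lambda s: len(s), IntegerType())
    sqlContext.sql("SELECT slen('test')").collect()  # a single Row holding 4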

@since(1.4)
@@ -138,7 +141,7 @@ def range(self, start, end, step=1, numPartitions=None):
:param end: the end value (exclusive)
:param step: the incremental step (default: 1)
:param numPartitions: the number of partitions of the DataFrame
:return: A new DataFrame
:return: :class:`DataFrame`

>>> sqlContext.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
@@ -195,8 +198,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
raise ValueError("The first row in RDD is empty, "
"can not infer schema")
if type(first) is dict:
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")
warnings.warn("Using RDD of dict to inferSchema is deprecated. "
"Use pyspark.sql.Row instead")

if samplingRatio is None:
schema = _infer_schema(first)
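Since the new warning points users at pyspark.sql.Row, here is a short sketch of the recommended pattern (illustrative values only):

    from pyspark.sql import Row

    # Rows instead of dicts: field names come from the Row, types are inferred.
    rdd = sc.parallelize([Row(name='Alice', age=1), Row(name='Bob', age=2)])
    df = sqlContext.createDataFrame(rdd)
    df.dtypes  # roughly [('age', 'bigint'), ('name', 'string')]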
@@ -219,7 +222,7 @@ def inferSchema(self, rdd, samplingRatio=None):
"""
.. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")

if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
@@ -262,6 +265,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
:class:`list`, or :class:`pandas.DataFrame`.
:param schema: a :class:`StructType` or list of column names. default None.
:param samplingRatio: the sample ratio of rows used for inferring
:return: :class:`DataFrame`

>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
@@ -359,58 +363,31 @@ def registerDataFrameAsTable(self, df, tableName):
else:
raise ValueError("Can only register DataFrame as table")

@since(1.0)
def parquetFile(self, *paths):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.

>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
True
.. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.
Review comment (Contributor): This is a different implementation than the data source API and may have different behavior. Even if we deprecate it, we should still have tests for it (to make sure it's not broken).


>>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
gateway = self._sc._gateway
jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
for i in range(0, len(paths)):
jpaths[i] = paths[i]
jdf = self._ssql_ctx.parquetFile(jpaths)
return DataFrame(jdf, self)
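The replacement named in the warning is the DataFrameReader API; a sketch against the same test fixture as the doctest above:

    # read.parquet() is the non-deprecated route to the same data.
    df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
    df.dtypes  # [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]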

@since(1.0)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""Loads a text file storing one JSON object per line as a :class:`DataFrame`.

If the schema is provided, applies the given schema to this JSON dataset.
Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.

>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
>>> shutil.rmtree(jsonFile)
>>> with open(jsonFile, 'w') as f:
... f.writelines(jsonStrings)
>>> df1 = sqlContext.jsonFile(jsonFile)
>>> df1.printSchema()
root
|-- field1: long (nullable = true)
|-- field2: string (nullable = true)
|-- field3: struct (nullable = true)
| |-- field4: long (nullable = true)
.. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.
Review comment (Contributor): Same as parquetFile().


>>> from pyspark.sql.types import *
>>> schema = StructType([
... StructField("field2", StringType()),
... StructField("field3",
... StructType([StructField("field5", ArrayType(IntegerType()))]))])
>>> df2 = sqlContext.jsonFile(jsonFile, schema)
>>> df2.printSchema()
root
|-- field2: string (nullable = true)
|-- field3: struct (nullable = true)
| |-- field5: array (nullable = true)
| | |-- element: integer (containsNull = true)
>>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
[('age', 'bigint'), ('name', 'string')]
"""
warnings.warn("jsonFile is deprecated. Use read.json() instead.")
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
@@ -462,21 +439,16 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
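Likewise for JSON: a sketch of the read.json() replacement for the deprecated jsonFile(), using the fixture from the doctest above; the schema argument stays optional and is inferred by sampling when omitted:

    # read.json() replaces jsonFile(); pass a StructType to skip inference.
    df = sqlContext.read.json('python/test_support/sql/people.json')
    df.dtypes  # [('age', 'bigint'), ('name', 'string')]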

@since(1.3)
def load(self, path=None, source=None, schema=None, **options):
"""Returns the dataset in a data source as a :class:`DataFrame`.

The data source is specified by the ``source`` and a set of ``options``.
If ``source`` is not specified, the default data source configured by
``spark.sql.sources.default`` will be used.

Optionally, a schema can be provided as the schema of the returned DataFrame.
.. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
"""
warnings.warn("load is deprecated. Use read.load() instead.")
return self.read.load(path, source, schema, **options)
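A hedged sketch of the generic reader chain that replaces load(); the format name and path are illustrative:

    # format() selects the data source, load() materializes the DataFrame.
    df = sqlContext.read.format('json').load('python/test_support/sql/people.json')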

@since(1.3)
def createExternalTable(self, tableName, path=None, source=None,
schema=None, **options):
def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
"""Creates an external table based on the dataset in a data source.

It returns the DataFrame associated with the external table.
@@ -487,6 +459,8 @@ def createExternalTable(self, tableName, path=None, source=None,

Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
created external table.

:return: :class:`DataFrame`
"""
if path is not None:
options["path"] = path
@@ -508,6 +482,8 @@ def sql(self, sqlQuery):
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.

:return: :class:`DataFrame`

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -519,6 +495,8 @@ def table(self, tableName):
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.

:return: :class:`DataFrame`

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -536,6 +514,9 @@ def tables(self, dbName=None):
The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
(a column with :class:`BooleanType` indicating if a table is a temporary one or not).

:param dbName: string, name of the database to use.
:return: :class:`DataFrame`

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
@@ -550,7 +531,8 @@ def tableNames(self, dbName=None):
def tableNames(self, dbName=None):
"""Returns a list of names of tables in the database ``dbName``.

If ``dbName`` is not specified, the current database will be used.
:param dbName: string, name of the database to use. Default to the current database.
:return: list of table names, in string

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> "table1" in sqlContext.tableNames()
@@ -585,8 +567,7 @@ def read(self):
Returns a :class:`DataFrameReader` that can be used to read data
in as a :class:`DataFrame`.

>>> sqlContext.read
<pyspark.sql.readwriter.DataFrameReader object at ...>
:return: :class:`DataFrameReader`
"""
return DataFrameReader(self)

@@ -644,10 +625,14 @@ def register(self, name, f, returnType=StringType()):


def _test():
import os
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.context

os.chdir(os.environ["SPARK_HOME"])

globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
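For context (not part of the diff): these modules run their doctests when executed directly, via the usual guard at the bottom of the file, which is why the os.chdir(os.environ["SPARK_HOME"]) above is added, so that relative 'python/test_support/...' paths in the docstrings resolve. A sketch, assuming the standard PySpark layout:

    if __name__ == "__main__":
        _test()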