Commit 073c2b5

yhuai authored and nemccarthy committed
[SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode
https://issues.apache.org/jira/browse/SPARK-8532

This PR makes two changes. First, it fixes the bug that the save actions (i.e. `save`/`saveAsTable`/`json`/`parquet`/`jdbc`) always override the save mode configured on the writer. Second, it adds an input argument `partitionBy` to `save`, `saveAsTable`, and `parquet`.

Author: Yin Huai <yhuai@databricks.com>

Closes apache#6937 from yhuai/SPARK-8532 and squashes the following commits:

f972d5d [Yin Huai] davies's comment.
d37abd2 [Yin Huai] style.
d21290a [Yin Huai] Python doc.
889eb25 [Yin Huai] Minor refactoring and add partitionBy to save, saveAsTable, and parquet.
7fbc24b [Yin Huai] Use None instead of "error" as the default value of mode since JVM-side already uses "error" as the default value.
d696dff [Yin Huai] Python style.
88eb6c4 [Yin Huai] If mode is "error", do not call mode method.
c40c461 [Yin Huai] Regression test.

(cherry picked from commit 5ab9fcf)
Signed-off-by: Yin Huai <yhuai@databricks.com>
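Concretely, the first fix addresses the situation below (a minimal sketch, not from the patch itself; `df` is any DataFrame and the path is illustrative):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), 'data')
    df.write.parquet(path)  # first write creates the directory

    # Before this commit, parquet() defaulted to mode="error" and always
    # called mode() on the JVM writer again, so the 'overwrite' configured
    # below was silently reset and the call raised "path already exists".
    # After this commit the configured mode survives and the data is replaced.
    df.write.mode('overwrite').parquet(path)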
1 parent 49dcd88 commit 073c2b5

File tree

2 files changed: +51 -11 lines changed


python/pyspark/sql/readwriter.py

Lines changed: 19 additions & 11 deletions
@@ -218,7 +218,10 @@ def mode(self, saveMode):
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite = self._jwrite.mode(saveMode)
+        # At the JVM side, the default value of mode is already set to "error".
+        # So, if the given saveMode is None, we will not call JVM-side's mode method.
+        if saveMode is not None:
+            self._jwrite = self._jwrite.mode(saveMode)
         return self
 
     @since(1.4)
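The practical effect of the new None default (a sketch; `path` is illustrative, as above):

    df.write.json(path)                  # mode=None: the JVM-side default, "error", applies
    df.write.mode('ignore').json(path)   # an explicit mode is still forwarded to the JVM writer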
@@ -253,11 +256,12 @@ def partitionBy(self, *cols):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        if len(cols) > 0:
+            self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
         return self
 
     @since(1.4)
-    def save(self, path=None, format=None, mode="error", **options):
+    def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
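Both existing call styles still pass through the normalization above, and an empty argument now skips the JVM call entirely (a sketch; 'year' and 'month' are assumed columns of `df`):

    df.write.partitionBy('year', 'month').parquet(path)    # varargs
    df.write.partitionBy(['year', 'month']).parquet(path)  # a single list or tuple
    df.write.partitionBy().parquet(path)                   # no-op: JVM partitionBy is never called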
@@ -272,11 +276,12 @@ def save(self, path=None, format=None, mode="error", **options):
         * ``overwrite``: Overwrite existing data.
         * ``ignore``: Silently ignore this operation if data already exists.
         * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         if path is None:
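With the new keyword, `save` can configure partitioning in a single call (a sketch; assumes `df` has a 'year' column):

    df.write.save(path, format='parquet', mode='append', partitionBy=['year'])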
@@ -296,7 +301,7 @@ def insertInto(self, tableName, overwrite=False):
         self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
 
     @since(1.4)
-    def saveAsTable(self, name, format=None, mode="error", **options):
+    def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
         """Saves the content of the :class:`DataFrame` as the specified table.
 
         In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ def saveAsTable(self, name, format=None, mode="error", **options):
         :param name: the table name
         :param format: the format used to save
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+        :param partitionBy: names of partitioning columns
         :param options: all other string options
         """
-        self.mode(mode).options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         self._jwrite.saveAsTable(name)
 
     @since(1.4)
-    def json(self, path, mode="error"):
+    def json(self, path, mode=None):
         """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
 
         :param path: the path in any Hadoop supported file system
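`saveAsTable` gains the same keyword (a sketch; 'people' is an illustrative table name, 'year' an assumed column):

    df.write.saveAsTable('people', format='parquet', mode='overwrite', partitionBy=['year'])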
@@ -333,10 +339,10 @@ def json(self, path, mode="error"):
 
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).json(path)
+        self.mode(mode)._jwrite.json(path)
 
     @since(1.4)
-    def parquet(self, path, mode="error"):
+    def parquet(self, path, mode=None, partitionBy=()):
         """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -346,13 +352,15 @@ def parquet(self, path, mode="error"):
         * ``overwrite``: Overwrite existing data.
         * ``ignore``: Silently ignore this operation if data already exists.
         * ``error`` (default case): Throw an exception if data already exists.
+        :param partitionBy: names of partitioning columns
 
         >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite.mode(mode).parquet(path)
+        self.partitionBy(partitionBy).mode(mode)
+        self._jwrite.parquet(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode="error", properties={}):
+    def jdbc(self, url, table, mode=None, properties={}):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
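And the direct `parquet` path accepts the keyword as well (a sketch, same assumptions as above):

    df.write.parquet(path, mode='overwrite', partitionBy=['year'])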

python/pyspark/sql/tests.py

Lines changed: 32 additions & 0 deletions
@@ -524,6 +524,38 @@ def test_save_and_load(self):
 
         shutil.rmtree(tmpPath)
 
+    def test_save_and_load_builder(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.write.json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.read.json(tmpPath, schema)
+        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").json(tmpPath)
+        actual = self.sqlCtx.read.json(tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+                .format("json").save(path=tmpPath)
+        actual =\
+            self.sqlCtx.read.format("json")\
+                .load(path=tmpPath, noUse="this options will not be used in load.")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
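Distilled, the regression this new test guards against is the second write below (a sketch; `df` and `tmpPath` as in the test):

    df.write.json(tmpPath)                     # creates the path
    df.write.mode('overwrite').json(tmpPath)   # before the fix: failed, because json()
                                               # reset the mode to "error"; after: overwrites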
