[SPARK-23765][SQL] Supports custom line separator for json datasource #20877

Closed · wants to merge 3 commits
14 changes: 10 additions & 4 deletions python/pyspark/sql/readwriter.py
@@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None, allowUnquotedControlChars=None):
+             multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.

@@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+            set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.

>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
-            allowUnquotedControlChars=allowUnquotedControlChars)
+            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -746,7 +748,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
self._jwrite.saveAsTable(name)

@since(1.4)
-    def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None):
+    def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
+             lineSep=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
specified path.
@@ -770,12 +773,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
+        :param lineSep: defines the line separator that should be used for writing. If None is
+            set, it uses the default value, ``\\n``.

>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
-            compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat)
+            compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
+            lineSep=lineSep)
self._jwrite.json(path)

@since(1.4)
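The reader and writer changes above are symmetric, so the option round-trips. A minimal
PySpark sketch of the intended usage (assuming an active SparkSession named `spark`; the
separator "||" and the temporary path are illustrative):

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "data")

    # Write records terminated by "||" instead of the default "\n".
    spark.createDataFrame([("a",), ("b",)], ["value"]) \
        .write.json(path, lineSep="||")

    # Read them back; the parser now splits records on "||".
    df = spark.read.json(path, lineSep="||")
    df.show()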
6 changes: 4 additions & 2 deletions python/pyspark/sql/streaming.py
@@ -407,7 +407,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None, allowUnquotedControlChars=None):
+             multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.

@@ -470,6 +470,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+            set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.

>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -484,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
-            allowUnquotedControlChars=allowUnquotedControlChars)
+            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
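The streaming reader exposes the same option; a sketch under the same assumptions (an
active `spark` session; the schema, directory, and separator are illustrative):

    from pyspark.sql.types import StringType, StructField, StructType

    src_dir = "/tmp/stream_src"  # directory of "||"-separated JSON files
    schema = StructType([StructField("value", StringType())])

    json_sdf = spark.readStream.json(src_dir, schema=schema, lineSep="||")
    assert json_sdf.isStreaming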
17 changes: 17 additions & 0 deletions python/pyspark/sql/tests.py
@@ -676,6 +676,23 @@ def test_multiline_json(self):
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())

+    def test_linesep_json(self):
+        df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
+        expected = [Row(_corrupt_record=None, name=u'Michael'),
+                    Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
+                    Row(_corrupt_record=u' "age":19}\n', name=None)]
+        self.assertEqual(df.collect(), expected)
+
+        tpath = tempfile.mkdtemp()
+        shutil.rmtree(tpath)
+        try:
+            df = self.spark.read.json("python/test_support/sql/people.json")
+            df.write.json(tpath, lineSep="!!")
+            readback = self.spark.read.json(tpath, lineSep="!!")
+            self.assertEqual(readback.collect(), df.collect())
+        finally:
+            shutil.rmtree(tpath)

def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.json

+import java.nio.charset.StandardCharsets
import java.util.{Locale, TimeZone}

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
@@ -85,6 +86,16 @@ private[sql] class JSONOptions(

val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)

+  val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
Review comment (Contributor): this can be private?
+    require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+    sep
+  }
+  // Note that the option 'lineSep' uses a different default value in read and write.
+  val lineSeparatorInRead: Option[Array[Byte]] =
+    lineSeparator.map(_.getBytes(StandardCharsets.UTF_8))
+  // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
+  val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n")

/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
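The asymmetry noted in the comments above is user-visible: with no `lineSep` set, reads
accept `\r`, `\r\n` and `\n` within a single file, while writes always emit `\n`. A small
PySpark sketch (assuming `spark`; the paths are illustrative):

    # One file mixing all three default separators.
    with open("/tmp/mixed.json", "w", newline="") as f:
        f.write('{"f": "a"}\r{"f": "b"}\r\n{"f": "c"}\n')

    # All three records are parsed, whichever separator each one used.
    assert spark.read.json("/tmp/mixed.json").count() == 3

    # Written output, by contrast, is "\n"-delimited unless lineSep is set.
    spark.read.json("/tmp/mixed.json").write.json("/tmp/mixed_out")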
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.json

import java.io.Writer
+import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core._

@@ -74,6 +75,8 @@ private[sql] class JacksonGenerator(

private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)

+  private val lineSeparator: String = options.lineSeparatorInWrite

private def makeWriter(dataType: DataType): ValueWriter = dataType match {
case NullType =>
(row: SpecializedGetters, ordinal: Int) =>
@@ -251,5 +254,8 @@
mapType = dataType.asInstanceOf[MapType]))
}

-  def writeLineEnding(): Unit = gen.writeRaw('\n')
+  def writeLineEnding(): Unit = {
+    // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
Review comment (Member, author): I meant here:
def createOutputStreamWriter(
context: JobContext,
file: Path,
charset: Charset = StandardCharsets.UTF_8): OutputStreamWriter = {
new OutputStreamWriter(createOutputStream(context, file), charset)
}

+    gen.writeRaw(lineSeparator)
+  }
}
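Since writeLineEnding is called once per record, every record, including the last, is
followed by the separator. The output shape, sketched with the Python standard library
(the separator "!!" is illustrative):

    import json

    records = [{"value": v} for v in ("a", "b", "c")]
    line_sep = "!!"
    payload = "".join(json.dumps(r, separators=(",", ":")) + line_sep for r in records)
    assert payload == '{"value":"a"}!!{"value":"b"}!!{"value":"c"}!!'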
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -366,6 +366,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines,
* per file</li>
+   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+   * that should be used for parsing.</li>
* </ul>
*
* @since 2.0.0
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -518,6 +518,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+   * <li>`lineSep` (default `\n`): defines the line separator that should
+   * be used for writing.</li>
* </ul>
*
* @since 1.4.0
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -92,7 +92,8 @@ object TextInputJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
-    val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+    val json: Dataset[String] = createBaseDataset(
+      sparkSession, inputPaths, parsedOptions.lineSeparator)
inferFromDataset(json, parsedOptions)
}

@@ -104,13 +105,19 @@

private def createBaseDataset(
sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): Dataset[String] = {
+      inputPaths: Seq[FileStatus],
+      lineSeparator: Option[String]): Dataset[String] = {
+    val textOptions = lineSeparator.map { lineSep =>
+      Map(TextOptions.LINE_SEPARATOR -> lineSep)
+    }.getOrElse(Map.empty[String, String])

val paths = inputPaths.map(_.getPath.toString)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
-        className = classOf[TextFileFormat].getName
+        className = classOf[TextFileFormat].getName,
+        options = textOptions
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}
@@ -120,7 +127,7 @@ object TextInputJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
-    val linesReader = new HadoopFileLinesReader(file, conf)
+    val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
val safeParser = new FailureSafeParser[Text](
input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
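Because schema inference now builds its base Dataset through TextFileFormat with the same
option, the record-splitting behavior can be observed on the text reader directly; a
PySpark sketch (assuming `spark`; the path is illustrative):

    # Each "||"-separated chunk becomes one row of the `value` column; JSON schema
    # inference then parses each such row as one record.
    lines = spark.read.option("lineSep", "||").text("/tmp/data.json")
    lines.show(truncate=False)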
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -52,7 +52,7 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensitiveMap
lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8))
}

-private[text] object TextOptions {
+private[datasources] object TextOptions {
val COMPRESSION = "compression"
val WHOLETEXT = "wholetext"
val LINE_SEPARATOR = "lineSep"
sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -268,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines,
* per file</li>
+   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
Review comment (Member): Add a test case for testing the default covers \r, \r\n and \n?
+   * that should be used for parsing.</li>
* </ul>
*
* @since 2.0.0
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json

import java.io.{File, StringWriter}
import java.nio.charset.StandardCharsets
+import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.util.Locale

@@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec

-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TestUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{functions => F, _}
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
@@ -2063,4 +2064,67 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
}

+  def testLineSeparator(lineSep: String): Unit = {
+    test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") {
+      // Read
+      val data =
+        s"""
+          | {"f":
+          |"a", "f0": 1}$lineSep{"f":
+          |
+          |"c", "f0": 2}$lineSep{"f": "d", "f0": 3}
+        """.stripMargin
+      val dataWithTrailingLineSep = s"$data$lineSep"
+
+      Seq(data, dataWithTrailingLineSep).foreach { lines =>
+        withTempPath { path =>
+          Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8))
+          val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath)
+          val expectedSchema =
+            StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil)
+          checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF())
+          assert(df.schema === expectedSchema)
+        }
+      }
+
+      // Write
+      withTempPath { path =>
+        Seq("a", "b", "c").toDF("value").coalesce(1)
+          .write.option("lineSep", lineSep).json(path.getAbsolutePath)
+        val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+        val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8)
+        assert(
+          readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""")
+      }
+
+      // Roundtrip
+      withTempPath { path =>
+        val df = Seq("a", "b", "c").toDF()
+        df.write.option("lineSep", lineSep).json(path.getAbsolutePath)
+        val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath)
+        checkAnswer(df, readBack)
+      }
+    }
+  }
+
+  // scalastyle:off nonascii
+  Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep =>
+    testLineSeparator(lineSep)
+  }
+  // scalastyle:on nonascii
+
+  test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") {
+    val data =
+      "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n"
+
+    withTempPath { path =>
+      Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8))
+      val df = spark.read.json(path.getAbsolutePath)
+      val expectedSchema =
+        StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil)
+      checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF())
+      assert(df.schema === expectedSchema)
+    }
+  }
}
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -208,9 +208,11 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
}

Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep =>
// scalastyle:off nonascii
Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep =>
Review comment (Member, author): Strictly unrelated, but I just added this. I am fine with reverting it if it bugs anyone.

Review comment (Member, author): BTW, "아" just means "ah", without any particular meaning.
testLineSeparator(lineSep)
}
+  // scalastyle:on nonascii

private def testFile: String = {
Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString