
Commit 8e8d117

MaxGekk authored and HyukjinKwon committed
[SPARK-26108][SQL] Support custom lineSep in CSV datasource
## What changes were proposed in this pull request?

In the PR, I propose a new option for the CSV datasource - `lineSep`, similar to the Text and JSON datasources. The option allows specifying a custom line separator of maximum length of 2 characters (because of a restriction in the `uniVocity` parser). The new option can be used in reading and writing CSV files.

## How was this patch tested?

Added a few tests with custom `lineSep` for enabled/disabled `multiLine` in read, as well as tests in write. Also added roundtrip tests.

Closes #23080 from MaxGekk/csv-line-sep.

Lead-authored-by: Maxim Gekk <max.gekk@gmail.com>
Co-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
1 parent 466d011 commit 8e8d117
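
For orientation, a minimal usage sketch of the new option (not part of the commit; the output path and the `spark` session with its implicits are assumed):

```scala
// Hypothetical example: write records terminated by '^' and read them back.
val df = Seq(("a", 1), ("b", 2)).toDF("key", "value")
df.write.option("lineSep", "^").csv("/tmp/csv-custom-sep")

val readBack = spark.read
  .option("lineSep", "^") // at most 1 character, per the checks in CSVOptions
  .schema("key STRING, value INT")
  .csv("/tmp/csv-custom-sep")
```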

8 files changed: +151, -10 lines

python/pyspark/sql/readwriter.py

Lines changed: 9 additions & 4 deletions
@@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
+            samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
         r"""Loads a CSV file and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -453,6 +453,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
             it uses the default value, ``en-US``. For instance, ``locale`` is used while
             parsing dates and timestamps.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+            set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+            Maximum length is 1 character.

         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
@@ -472,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
-            enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
+            enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -868,7 +871,7 @@ def text(self, path, compression=None, lineSep=None):
     def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
             header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
             timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
-            charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None):
+            charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None):
         r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.

         :param path: the path in any Hadoop supported file system
@@ -922,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
             the default UTF-8 charset will be used.
         :param emptyValue: sets the string representation of an empty value. If None is set, it uses
             the default value, ``""``.
+        :param lineSep: defines the line separator that should be used for writing. If None is
+            set, it uses the default value, ``\\n``. Maximum length is 1 character.

         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
@@ -932,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
             ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
-            encoding=encoding, emptyValue=emptyValue)
+            encoding=encoding, emptyValue=emptyValue, lineSep=lineSep)
         self._jwrite.csv(path)

     @since(1.5)

python/pyspark/sql/streaming.py

Lines changed: 5 additions & 2 deletions
@@ -576,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            enforceSchema=None, emptyValue=None, locale=None):
+            enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
         r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -675,6 +675,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
         :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
             it uses the default value, ``en-US``. For instance, ``locale`` is used while
             parsing dates and timestamps.
+        :param lineSep: defines the line separator that should be used for parsing. If None is
+            set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+            Maximum length is 1 character.

         >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
         >>> csv_sdf.isStreaming
@@ -692,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
-            emptyValue=emptyValue, locale=locale)
+            emptyValue=emptyValue, locale=locale, lineSep=lineSep)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala

Lines changed: 22 additions & 1 deletion
@@ -192,6 +192,20 @@ class CSVOptions(
    */
  val emptyValueInWrite = emptyValue.getOrElse("\"\"")

+  /**
+   * A string between two consecutive CSV records.
+   */
+  val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
+    require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+    require(sep.length == 1, "'lineSep' can contain only 1 character.")
+    sep
+  }
+
+  val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+    lineSep.getBytes(charset)
+  }
+  val lineSeparatorInWrite: Option[String] = lineSeparator
+
  def asWriterSettings: CsvWriterSettings = {
    val writerSettings = new CsvWriterSettings()
    val format = writerSettings.getFormat
@@ -200,6 +214,8 @@ class CSVOptions(
    format.setQuoteEscape(escape)
    charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
    format.setComment(comment)
+    lineSeparatorInWrite.foreach(format.setLineSeparator)
+
    writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
    writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
    writerSettings.setNullValue(nullValue)
@@ -216,8 +232,10 @@ class CSVOptions(
    format.setDelimiter(delimiter)
    format.setQuote(quote)
    format.setQuoteEscape(escape)
+    lineSeparator.foreach(format.setLineSeparator)
    charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
    format.setComment(comment)
+
    settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead)
    settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead)
    settings.setReadInputOnSeparateThread(false)
@@ -227,7 +245,10 @@ class CSVOptions(
    settings.setEmptyValue(emptyValueInRead)
    settings.setMaxCharsPerColumn(maxCharsPerColumn)
    settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
-    settings.setLineSeparatorDetectionEnabled(multiLine == true)
+    settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine)
+    lineSeparatorInRead.foreach { _ =>
+      settings.setNormalizeLineEndingsWithinQuotes(!multiLine)
+    }

    settings
  }
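
The validation above is compact; a standalone sketch of the same pattern (a simplified stand-in for `CSVOptions`, with the charset fixed to UTF-8 rather than taken from the `charset` option):

```scala
import java.nio.charset.StandardCharsets

// Simplified stand-in: parameters is the user-supplied option map.
def lineSeparator(parameters: Map[String, String]): Option[String] =
  parameters.get("lineSep").map { sep =>
    require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
    require(sep.length == 1, "'lineSep' can contain only 1 character.")
    sep
  }

// For reads, the separator is handed to Hadoop as raw bytes in the file's charset.
val sepBytes: Option[Array[Byte]] =
  lineSeparator(Map("lineSep" -> "^")).map(_.getBytes(StandardCharsets.UTF_8))
```

Note the design choice visible in the parser-settings hunk above: uniVocity's automatic separator detection stays enabled only when no explicit `lineSep` is given and `multiLine` is set.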

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 2 additions & 0 deletions
@@ -609,6 +609,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
   * For instance, this is used while parsing dates and timestamps.</li>
+   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+   * that should be used for parsing. Maximum length is 1 character.</li>
   * </ul>
   *
   * @since 2.0.0
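
A hedged illustration of the documented default (file path and contents invented): with no `lineSep` set, all three separators terminate records, so a file mixing them still parses into three rows:

```scala
import java.nio.file.{Files, Paths}

// Hypothetical input mixing \r, \r\n and \n as record terminators.
Files.write(Paths.get("/tmp/mixed.csv"), "a,1\rb,2\r\nc,3\n".getBytes("UTF-8"))

val df = spark.read.csv("/tmp/mixed.csv") // no lineSep option set
assert(df.count() == 3)
```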

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 2 additions & 0 deletions
@@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   * whitespaces from values being written should be skipped.</li>
   * <li>`ignoreTrailingWhiteSpace` (default `true`): a flag indicating whether or not
   * trailing whitespaces from values being written should be skipped.</li>
+   * <li>`lineSep` (default `\n`): defines the line separator that should be used for writing.
+   * Maximum length is 1 character.</li>
   * </ul>
   *
   * @since 2.0.0
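
On the write side, a sketch under the same assumptions (invented output directory); per the tests later in this commit, every record, including the last, is terminated by the configured separator:

```scala
Seq("a", "b", "c").toDF("value")
  .coalesce(1)                    // single part file, easier to inspect
  .write
  .option("lineSep", ";")         // hypothetical 1-character separator
  .csv("/tmp/csv-semicolon-sep")  // the part file should contain "a;b;c;"
```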

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       headerChecker: CSVHeaderChecker,
       requiredSchema: StructType): Iterator[InternalRow] = {
     val lines = {
-      val linesReader = new HadoopFileLinesReader(file, conf)
+      val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
       linesReader.map { line =>
         new String(line.getBytes, 0, line.getLength, parser.options.charset)
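
Conceptually, passing `lineSeparatorInRead` here makes the Hadoop lines reader split the input on those exact bytes instead of applying its default newline handling. A self-contained sketch of that splitting idea (an illustration only, not the `HadoopFileLinesReader` implementation):

```scala
// Illustration: carve a byte buffer into records on an explicit single-byte separator.
def splitRecords(bytes: Array[Byte], sep: Byte): Seq[String] = {
  val records = scala.collection.mutable.ArrayBuffer.empty[String]
  var start = 0
  for (i <- bytes.indices if bytes(i) == sep) {
    records += new String(bytes, start, i - start, "UTF-8")
    start = i + 1
  }
  if (start < bytes.length) {
    records += new String(bytes, start, bytes.length - start, "UTF-8")
  }
  records.toSeq
}

// splitRecords("a,1^b,2".getBytes("UTF-8"), '^'.toByte) == Seq("a,1", "b,2")
```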

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala

Lines changed: 2 additions & 0 deletions
@@ -377,6 +377,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
   * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
   * For instance, this is used while parsing dates and timestamps.</li>
+   * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+   * that should be used for parsing. Maximum length is 1 character.</li>
   * </ul>
   *
   * @since 2.0.0
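
The streaming reader accepts the same option; a brief sketch (directory invented; streaming CSV sources require an explicit schema):

```scala
val streamDf = spark.readStream
  .schema("key STRING, value INT") // streaming sources need a user-defined schema
  .option("lineSep", "^")          // hypothetical separator
  .csv("/tmp/csv-stream-input")
```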

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

Lines changed: 108 additions & 2 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources.csv

 import java.io.File
-import java.nio.charset.{Charset, UnsupportedCharsetException}
+import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException}
 import java.nio.file.Files
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
@@ -33,7 +33,7 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.log4j.{AppenderSkeleton, LogManager}
 import org.apache.log4j.spi.LoggingEvent

-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -1880,4 +1880,110 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       }
     }
   }
+
+  test("""Support line separator - default value \r, \r\n and \n""") {
+    val data = "\"a\",1\r\"c\",2\r\n\"d\",3\n"
+
+    withTempPath { path =>
+      Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8))
+      val df = spark.read.option("inferSchema", true).csv(path.getAbsolutePath)
+      val expectedSchema =
+        StructType(StructField("_c0", StringType) :: StructField("_c1", IntegerType) :: Nil)
+      checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF())
+      assert(df.schema === expectedSchema)
+    }
+  }
+
+  def testLineSeparator(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = {
+    test(s"Support line separator in ${encoding} #${id}") {
+      // Read
+      val data =
+        s""""a",1$lineSep
+           |c,2$lineSep"
+           |d",3""".stripMargin
+      val dataWithTrailingLineSep = s"$data$lineSep"
+
+      Seq(data, dataWithTrailingLineSep).foreach { lines =>
+        withTempPath { path =>
+          Files.write(path.toPath, lines.getBytes(encoding))
+          val schema = StructType(StructField("_c0", StringType)
+            :: StructField("_c1", LongType) :: Nil)
+
+          val expected = Seq(("a", 1), ("\nc", 2), ("\nd", 3))
+            .toDF("_c0", "_c1")
+          Seq(false, true).foreach { multiLine =>
+            val reader = spark
+              .read
+              .option("lineSep", lineSep)
+              .option("multiLine", multiLine)
+              .option("encoding", encoding)
+            val df = if (inferSchema) {
+              reader.option("inferSchema", true).csv(path.getAbsolutePath)
+            } else {
+              reader.schema(schema).csv(path.getAbsolutePath)
+            }
+            checkAnswer(df, expected)
+          }
+        }
+      }

+      // Write
+      withTempPath { path =>
+        Seq("a", "b", "c").toDF("value").coalesce(1)
+          .write
+          .option("lineSep", lineSep)
+          .option("encoding", encoding)
+          .csv(path.getAbsolutePath)
+        val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head
+        val readBack = new String(Files.readAllBytes(partFile.toPath), encoding)
+        assert(
+          readBack === s"a${lineSep}b${lineSep}c${lineSep}")
+      }

+      // Roundtrip
+      withTempPath { path =>
+        val df = Seq("a", "b", "c").toDF()
+        df.write
+          .option("lineSep", lineSep)
+          .option("encoding", encoding)
+          .csv(path.getAbsolutePath)
+        val readBack = spark
+          .read
+          .option("lineSep", lineSep)
+          .option("encoding", encoding)
+          .csv(path.getAbsolutePath)
+        checkAnswer(df, readBack)
+      }
+    }
+  }
+
+  // scalastyle:off nonascii
+  List(
+    (0, "|", "UTF-8", false),
+    (1, "^", "UTF-16BE", true),
+    (2, ":", "ISO-8859-1", true),
+    (3, "!", "UTF-32LE", false),
+    (4, 0x1E.toChar.toString, "UTF-8", true),
+    (5, "아", "UTF-32BE", false),
+    (6, "у", "CP1251", true),
+    (8, "\r", "UTF-16LE", true),
+    (9, "\u000d", "UTF-32BE", false),
+    (10, "=", "US-ASCII", false),
+    (11, "$", "utf-32le", true)
+  ).foreach { case (testNum, sep, encoding, inferSchema) =>
+    testLineSeparator(sep, encoding, inferSchema, testNum)
+  }
+  // scalastyle:on nonascii
+
+  test("lineSep restrictions") {
+    val errMsg1 = intercept[IllegalArgumentException] {
+      spark.read.option("lineSep", "").csv(testFile(carsFile)).collect
+    }.getMessage
+    assert(errMsg1.contains("'lineSep' cannot be an empty string"))
+
+    val errMsg2 = intercept[IllegalArgumentException] {
+      spark.read.option("lineSep", "123").csv(testFile(carsFile)).collect
+    }.getMessage
+    assert(errMsg2.contains("'lineSep' can contain only 1 character"))
+  }
 }
