
Commit e9af946

MaxGekk and HyukjinKwon committed
[SPARK-25393][SQL] Adding new function from_csv()
## What changes were proposed in this pull request?

The PR adds a new function `from_csv()`, similar to `from_json()`, to parse columns with CSV strings. I added the following method:

```Scala
def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column
```

and this signature to call it from Python, R and Java:

```Scala
def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column
```

## How was this patch tested?

Added new test suites `CsvExpressionsSuite`, `CsvFunctionsSuite` and SQL tests.

Closes #22379 from MaxGekk/from_csv.

Lead-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Co-authored-by: Maxim Gekk <max.gekk@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Co-authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
1 parent 9d4dd79 commit e9af946
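For illustration, a minimal Scala sketch of calling the `StructType` overload quoted above. The sample data, the column name `value`, and an active `spark` session with its implicits in scope are assumptions for this example, not part of the commit:

```Scala
// Hypothetical usage in spark-shell on Spark 3.0+, where `spark` and its implicits are available.
import spark.implicits._
import org.apache.spark.sql.functions.from_csv
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Illustrative data: a single string column holding CSV text.
val df = Seq("Amsterdam,2018").toDF("value")

val schema = StructType(Seq(
  StructField("city", StringType),
  StructField("year", IntegerType)))

// from_csv(e, schema, options) parses the CSV string into a struct column;
// the options map accepts the same keys as the CSV data source (left empty here).
df.select(from_csv($"value", schema, Map.empty[String, String]).as("csv")).show()
```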

File tree

30 files changed, +714 -113 lines changed


R/pkg/NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ exportMethods("%<=>%",
               "floor",
               "format_number",
               "format_string",
+              "from_csv",
               "from_json",
               "from_unixtime",
               "from_utc_timestamp",

R/pkg/R/functions.R

Lines changed: 38 additions & 2 deletions
@@ -188,6 +188,7 @@ NULL
 #'          \item \code{to_json}: it is the column containing the struct, array of the structs,
 #'              the map or array of maps.
 #'          \item \code{from_json}: it is the column containing the JSON string.
+#'          \item \code{from_csv}: it is the column containing the CSV string.
 #'          }
 #' @param y Column to compute on.
 #' @param value A value to compute on.
@@ -196,6 +197,13 @@ NULL
 #'          \item \code{array_position}: a value to locate in the given array.
 #'          \item \code{array_remove}: a value to remove in the given array.
 #'          }
+#' @param schema
+#'          \itemize{
+#'          \item \code{from_json}: a structType object to use as the schema to use
+#'              when parsing the JSON string. Since Spark 2.3, the DDL-formatted string is
+#'              also supported for the schema.
+#'          \item \code{from_csv}: a DDL-formatted string
+#'          }
 #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
 #'            additional named properties to control how it is converted, accepts the same
 #'            options as the JSON data source. Additionally \code{to_json} supports the "pretty"
@@ -2165,8 +2173,6 @@ setMethod("date_format", signature(y = "Column", x = "character"),
 #' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA.
 #'
 #' @rdname column_collection_functions
-#' @param schema a structType object to use as the schema to use when parsing the JSON string.
-#'               Since Spark 2.3, the DDL-formatted string is also supported for the schema.
 #' @param as.json.array indicating if input string is JSON array of objects or a single object.
 #' @aliases from_json from_json,Column,characterOrstructType-method
 #' @examples
@@ -2203,6 +2209,36 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType")
           column(jc)
         })
 
+#' @details
+#' \code{from_csv}: Parses a column containing a CSV string into a Column of \code{structType}
+#' with the specified \code{schema}.
+#' If the string is unparseable, the Column will contain the value NA.
+#'
+#' @rdname column_collection_functions
+#' @aliases from_csv from_csv,Column,character-method
+#' @examples
+#'
+#' \dontrun{
+#' df <- sql("SELECT 'Amsterdam,2018' as csv")
+#' schema <- "city STRING, year INT"
+#' head(select(df, from_csv(df$csv, schema)))}
+#' @note from_csv since 3.0.0
+setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"),
+          function(x, schema, ...) {
+            if (class(schema) == "Column") {
+              jschema <- schema@jc
+            } else if (is.character(schema)) {
+              jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema)
+            } else {
+              stop("schema argument should be a column or character")
+            }
+            options <- varargsToStrEnv(...)
+            jc <- callJStatic("org.apache.spark.sql.functions",
+                              "from_csv",
+                              x@jc, jschema, options)
+            column(jc)
+          })
+
 #' @details
 #' \code{from_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT
 #' TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
@@ -984,6 +984,10 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s
 #' @name NULL
 setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("from_csv", function(x, schema, ...) { standardGeneric("from_csv") })
+
 #' @rdname column_datetime_functions
 #' @name NULL
 setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") })

R/pkg/tests/fulltests/test_sparkSQL.R

Lines changed: 7 additions & 0 deletions
@@ -1647,6 +1647,13 @@ test_that("column functions", {
   expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2)
   expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4)
 
+  # Test from_csv()
+  df <- as.DataFrame(list(list("col" = "1")))
+  c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv")))
+  expect_equal(c[[1]][[1]]$a, 1)
+  c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv")))
+  expect_equal(c[[1]][[1]]$a, 1)
+
   # Test to_json(), from_json()
   df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people")
   j <- collect(select(df, alias(to_json(df$people), "json")))

python/pyspark/sql/functions.py

Lines changed: 36 additions & 1 deletion
@@ -25,9 +25,12 @@
 if sys.version < "3":
     from itertools import imap as map
 
+if sys.version >= '3':
+    basestring = str
+
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
@@ -2678,6 +2681,38 @@ def sequence(start, stop, step=None):
         _to_java_column(start), _to_java_column(stop), _to_java_column(step)))
 
 
+@ignore_unicode_prefix
+@since(3.0)
+def from_csv(col, schema, options={}):
+    """
+    Parses a column containing a CSV string to a row with the specified schema.
+    Returns `null`, in the case of an unparseable string.
+
+    :param col: string column in CSV format
+    :param schema: a string with schema in DDL format to use when parsing the CSV column.
+    :param options: options to control parsing. accepts the same options as the CSV datasource
+
+    >>> data = [(1, '1')]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
+    [Row(csv=Row(a=1))]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
+    [Row(csv=Row(a=1))]
+    """
+
+    sc = SparkContext._active_spark_context
+    if isinstance(schema, basestring):
+        schema = _create_column_from_literal(schema)
+    elif isinstance(schema, Column):
+        schema = _to_java_column(schema)
+    else:
+        raise TypeError("schema argument should be a column or string")
+
+    jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
+    return Column(jc)
+
+
 # ---------------------------- User Defined Function ----------------------------------
 
 class PandasUDFType(object):

sql/catalyst/pom.xml

Lines changed: 6 additions & 0 deletions
@@ -103,6 +103,12 @@
       <groupId>commons-codec</groupId>
       <artifactId>commons-codec</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.univocity</groupId>
+      <artifactId>univocity-parsers</artifactId>
+      <version>2.7.3</version>
+      <type>jar</type>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 4 additions & 1 deletion
@@ -520,7 +520,10 @@ object FunctionRegistry {
     castAlias("date", DateType),
     castAlias("timestamp", TimestampType),
     castAlias("binary", BinaryType),
-    castAlias("string", StringType)
+    castAlias("string", StringType),
+
+    // csv
+    expression[CsvToStructs]("from_csv")
   )
 
   val builtin: SimpleFunctionRegistry = {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.csv
+
+object CSVExprUtils {
+  /**
+   * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
+   * This is currently being used in CSV reading path and CSV schema inference.
+   */
+  def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    iter.filter { line =>
+      line.trim.nonEmpty && !line.startsWith(options.comment.toString)
+    }
+  }
+
+  def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    if (options.isCommentSet) {
+      val commentPrefix = options.comment.toString
+      iter.dropWhile { line =>
+        line.trim.isEmpty || line.trim.startsWith(commentPrefix)
+      }
+    } else {
+      iter.dropWhile(_.trim.isEmpty)
+    }
+  }
+
+  /**
+   * Extracts header and moves iterator forward so that only data remains in it
+   */
+  def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = {
+    val nonEmptyLines = skipComments(iter, options)
+    if (nonEmptyLines.hasNext) {
+      Some(nonEmptyLines.next())
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Helper method that converts string representation of a character to actual character.
+   * It handles some Java escaped strings and throws exception if given string is longer than one
+   * character.
+   */
+  @throws[IllegalArgumentException]
+  def toChar(str: String): Char = {
+    (str: Seq[Char]) match {
+      case Seq() => throw new IllegalArgumentException("Delimiter cannot be empty string")
+      case Seq('\\') => throw new IllegalArgumentException("Single backslash is prohibited." +
+        " It has special meaning as beginning of an escape sequence." +
+        " To get the backslash character, pass a string with two backslashes as the delimiter.")
+      case Seq(c) => c
+      case Seq('\\', 't') => '\t'
+      case Seq('\\', 'r') => '\r'
+      case Seq('\\', 'b') => '\b'
+      case Seq('\\', 'f') => '\f'
+      // In case user changes quote char and uses \" as delimiter in options
+      case Seq('\\', '\"') => '\"'
+      case Seq('\\', '\'') => '\''
+      case Seq('\\', '\\') => '\\'
+      case _ if str == """\u0000""" => '\u0000'
+      case Seq('\\', _) =>
+        throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
+      case _ =>
+        throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
+    }
+  }
+}
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import com.univocity.parsers.csv.CsvParser
 
@@ -123,7 +123,7 @@ class CSVHeaderChecker(
     // Note: if there are only comments in the first block, the header would probably
     // be not extracted.
     if (options.headerFlag && isStartOfFile) {
-      CSVUtils.extractHeader(lines, options).foreach { header =>
+      CSVExprUtils.extractHeader(lines, options).foreach { header =>
         checkHeaderColumnNames(tokenizer.parseLine(header))
       }
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import java.nio.charset.StandardCharsets
 import java.util.{Locale, TimeZone}
@@ -83,7 +83,7 @@ class CSVOptions(
     }
   }
 
-  val delimiter = CSVUtils.toChar(
+  val delimiter = CSVExprUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala

Lines changed: 4 additions & 5 deletions

@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.csv
+package org.apache.spark.sql.catalyst.csv
 
 import java.io.InputStream
 import java.math.BigDecimal
@@ -28,8 +28,7 @@ import com.univocity.parsers.csv.CsvParser
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils}
-import org.apache.spark.sql.execution.datasources.FailureSafeParser
+import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -264,7 +263,7 @@ class UnivocityParser(
   }
 }
 
-private[csv] object UnivocityParser {
+private[sql] object UnivocityParser {
 
   /**
   * Parses a stream that contains CSV strings and turns it into an iterator of tokens.
@@ -339,7 +338,7 @@ private[csv] object UnivocityParser {
 
     val options = parser.options
 
-    val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
+    val filteredLines: Iterator[String] = CSVExprUtils.filterCommentAndEmpty(lines, options)
 
     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.types.{MapType, StringType, StructType}
+
+object ExprUtils {
+
+  def evalSchemaExpr(exp: Expression): StructType = exp match {
+    case Literal(s, StringType) => StructType.fromDDL(s.toString)
+    case e => throw new AnalysisException(
+      s"Schema should be specified in DDL format as a string literal instead of ${e.sql}")
+  }
+
+  def convertToMapData(exp: Expression): Map[String, String] = exp match {
+    case m: CreateMap
+        if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) =>
+      val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
+      ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
+        key.toString -> value.toString
+      }
+    case m: CreateMap =>
+      throw new AnalysisException(
+        s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}")
+    case _ =>
+      throw new AnalysisException("Must use a map() function for options")
+  }
+}
