
Commit b26e49e

Merge pull request #10 from HyukjinKwon/address-from_csv
Address from csv
2 parents: 88e3b10 + a32bbcb

12 files changed (+50, -25 lines)

R/pkg/R/functions.R
Lines changed: 9 additions & 2 deletions

@@ -2223,12 +2223,19 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType")
 #' schema <- "city STRING, year INT"
 #' head(select(df, from_csv(df$csv, schema)))}
 #' @note from_csv since 3.0.0
-setMethod("from_csv", signature(x = "Column", schema = "character"),
+setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"),
           function(x, schema, ...) {
+            if (class(schema) == "Column") {
+              jschema <- schema@jc
+            } else if (is.character(schema)) {
+              jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema)
+            } else {
+              stop("schema argument should be a column or character")
+            }
             options <- varargsToStrEnv(...)
             jc <- callJStatic("org.apache.spark.sql.functions",
                               "from_csv",
-                              x@jc, schema, options)
+                              x@jc, jschema, options)
             column(jc)
           })

R/pkg/tests/fulltests/test_sparkSQL.R
Lines changed: 2 additions & 0 deletions

@@ -1651,6 +1651,8 @@ test_that("column functions", {
   df <- as.DataFrame(list(list("col" = "1")))
   c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv")))
   expect_equal(c[[1]][[1]]$a, 1)
+  c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv")))
+  expect_equal(c[[1]][[1]]$a, 1)

   # Test to_json(), from_json()
   df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people")

python/pyspark/sql/functions.py
Lines changed: 14 additions & 1 deletion

@@ -25,9 +25,12 @@
 if sys.version < "3":
     from itertools import imap as map

+if sys.version >= '3':
+    basestring = str
+
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
@@ -2693,9 +2696,19 @@ def from_csv(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
     [Row(csv=Row(a=1))]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
+    [Row(csv=Row(a=1))]
     """

     sc = SparkContext._active_spark_context
+    if isinstance(schema, basestring):
+        schema = _create_column_from_literal(schema)
+    elif isinstance(schema, Column):
+        schema = _to_java_column(schema)
+    else:
+        raise TypeError("schema argument should be a column or string")
+
     jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
     return Column(jc)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExpressionUtils.scala
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.csv

-object CSVUtils {
+object CSVExpressionUtils {
   /**
    * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
    * This is currently being used in CSV reading path and CSV schema inference.
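
Only the object name changes in this file; the helpers move over unchanged. For orientation, here is a minimal Scala sketch of the filterCommentAndEmpty behavior the doc comment above describes; it assumes CSVOptions exposes an isCommentSet flag and a comment character, and the real body may differ:

// Sketch only: drops blank lines and lines starting with the configured
// comment character, per the doc comment above. `isCommentSet` and
// `comment` are assumed CSVOptions members, not confirmed by this diff.
def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
  if (options.isCommentSet) {
    val commentPrefix = options.comment.toString
    iter.filter { line =>
      line.trim.nonEmpty && !line.startsWith(commentPrefix)
    }
  } else {
    iter.filter(_.trim.nonEmpty)
  }
}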

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
Lines changed: 1 addition & 1 deletion

@@ -123,7 +123,7 @@ class CSVHeaderChecker(
     // Note: if there are only comments in the first block, the header would probably
     // be not extracted.
     if (options.headerFlag && isStartOfFile) {
-      CSVUtils.extractHeader(lines, options).foreach { header =>
+      CSVExpressionUtils.extractHeader(lines, options).foreach { header =>
         checkHeaderColumnNames(tokenizer.parseLine(header))
       }
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ class CSVOptions(
     }
   }

-  val delimiter = CSVUtils.toChar(
+  val delimiter = CSVExpressionUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
Lines changed: 1 addition & 1 deletion

@@ -338,7 +338,7 @@ private[sql] object UnivocityParser {

     val options = parser.options

-    val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
+    val filteredLines: Iterator[String] = CSVExpressionUtils.filterCommentAndEmpty(lines, options)

     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala
Lines changed: 13 additions & 13 deletions

@@ -19,42 +19,42 @@ package org.apache.spark.sql.catalyst.csv

 import org.apache.spark.SparkFunSuite

-class CSVUtilsSuite extends SparkFunSuite {
+class CSVExpressionUtilsSuite extends SparkFunSuite {
   test("Can parse escaped characters") {
-    assert(CSVUtils.toChar("""\t""") === '\t')
-    assert(CSVUtils.toChar("""\r""") === '\r')
-    assert(CSVUtils.toChar("""\b""") === '\b')
-    assert(CSVUtils.toChar("""\f""") === '\f')
-    assert(CSVUtils.toChar("""\"""") === '\"')
-    assert(CSVUtils.toChar("""\'""") === '\'')
-    assert(CSVUtils.toChar("""\u0000""") === '\u0000')
-    assert(CSVUtils.toChar("""\\""") === '\\')
+    assert(CSVExpressionUtils.toChar("""\t""") === '\t')
+    assert(CSVExpressionUtils.toChar("""\r""") === '\r')
+    assert(CSVExpressionUtils.toChar("""\b""") === '\b')
+    assert(CSVExpressionUtils.toChar("""\f""") === '\f')
+    assert(CSVExpressionUtils.toChar("""\"""") === '\"')
+    assert(CSVExpressionUtils.toChar("""\'""") === '\'')
+    assert(CSVExpressionUtils.toChar("""\u0000""") === '\u0000')
+    assert(CSVExpressionUtils.toChar("""\\""") === '\\')
   }

   test("Does not accept delimiter larger than one character") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("ab")
+      CSVExpressionUtils.toChar("ab")
     }
     assert(exception.getMessage.contains("cannot be more than one character"))
   }

   test("Throws exception for unsupported escaped characters") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("""\1""")
+      CSVExpressionUtils.toChar("""\1""")
     }
     assert(exception.getMessage.contains("Unsupported special character for delimiter"))
   }

   test("string with one backward slash is prohibited") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("""\""")
+      CSVExpressionUtils.toChar("""\""")
     }
     assert(exception.getMessage.contains("Single backslash is prohibited"))
   }

   test("output proper error message for empty string") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("")
+      CSVExpressionUtils.toChar("")
     }
     assert(exception.getMessage.contains("Delimiter cannot be empty string"))
   }
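
The renamed suite pins down toChar's contract: a single character passes through, a small set of backslash escapes is translated, and everything else is rejected with a descriptive message. A hypothetical reconstruction consistent with these assertions (the actual CSVExpressionUtils.toChar may be structured differently):

// Reconstructed from the test expectations above, not copied from the source.
def toChar(str: String): Char = (str: Seq[Char]) match {
  case Seq() =>
    throw new IllegalArgumentException("Delimiter cannot be empty string")
  case Seq('\\') =>
    throw new IllegalArgumentException("Single backslash is prohibited")
  case Seq(c) => c // plain single-character delimiter
  case Seq('\\', 't') => '\t'
  case Seq('\\', 'r') => '\r'
  case Seq('\\', 'b') => '\b'
  case Seq('\\', 'f') => '\f'
  case Seq('\\', '"') => '"'
  case Seq('\\', '\'') => '\''
  case Seq('\\', '\\') => '\\'
  case Seq('\\', 'u', '0', '0', '0', '0') => '\u0000'
  case Seq('\\', _) =>
    throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
  case _ =>
    throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
}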

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
Lines changed: 1 addition & 2 deletions

@@ -35,7 +35,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
-import org.apache.spark.sql.catalyst.csv.CSVUtils.filterCommentAndEmpty
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -130,7 +129,7 @@ object TextInputCSVDataSource extends CSVDataSource {
     val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
     val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
     val tokenRDD = sampled.rdd.mapPartitions { iter =>
-      val filteredLines = filterCommentAndEmpty(iter, parsedOptions)
+      val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
       val linesWithoutHeader =
         CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions)
       val parser = new CsvParser(parsedOptions.asParserSettings)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.csv.CSVExpressionUtils
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.functions._

@@ -125,4 +126,7 @@ object CSVUtils {
       csv.sample(withReplacement = false, options.samplingRatio, 1)
     }
   }
+
+  def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] =
+    CSVExpressionUtils.filterCommentAndEmpty(iter, options)
 }

sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Lines changed: 2 additions & 2 deletions

@@ -3882,8 +3882,8 @@ object functions {
   * @group collection_funcs
   * @since 3.0.0
   */
-  def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
-    withExpr(new CsvToStructs(e.expr, lit(schema).expr, options.asScala.toMap))
+  def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
+    withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap))
   }

   // scalastyle:off line.size.limit
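
With the JVM signature now taking the schema as a Column, Scala callers wrap literal schema strings in lit(...) themselves, as the updated suite below does. A minimal spark-shell style sketch (the DataFrame contents and the `spark` session name are illustrative, not part of this commit):

import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.{from_csv, lit}

// Assumes a live SparkSession named `spark`, e.g. in spark-shell.
import spark.implicits._

val df = Seq("1").toDF("value")
// Parses the CSV string "1" into a struct with field a = 1,
// matching the expectation in CsvFunctionsSuite below.
df.select(from_csv($"value", lit("a INT"), Map.empty[String, String].asJava)).show()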

sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
     val schema = "a int"

     checkAnswer(
-      df.select(from_csv($"value", schema, Map[String, String]().asJava)),
+      df.select(from_csv($"value", lit(schema), Map[String, String]().asJava)),
       Row(Row(1)) :: Nil)
   }