[SPARK-25672][SQL] schema_of_csv() - schema inference from an example #22666

Closed
wants to merge 26 commits into from
41 changes: 34 additions & 7 deletions python/pyspark/sql/functions.py
@@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}):
return Column(jc)


@ignore_unicode_prefix
@since(3.0)
def schema_of_csv(csv, options={}):
"""
Parses a CSV string and infers its schema in DDL format.

    :param csv: a CSV string or a string literal containing a CSV string.
:param options: options to control parsing. accepts the same options as the CSV datasource

>>> df = spark.range(1)
>>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
[Row(csv=u'struct<_c0:int,_c1:string>')]
>>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
[Row(csv=u'struct<_c0:int,_c1:string>')]
"""
if isinstance(csv, basestring):
col = _create_column_from_literal(csv)
elif isinstance(csv, Column):
col = _to_java_column(csv)
else:
raise TypeError("schema argument should be a column or string")

sc = SparkContext._active_spark_context
jc = sc._jvm.functions.schema_of_csv(col, options)
return Column(jc)


@since(1.5)
def size(col):
"""
@@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}):
:param schema: a string with schema in DDL format to use when parsing the CSV column.
:param options: options to control parsing. accepts the same options as the CSV datasource

>>> data = [(1, '1')]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
[Row(csv=Row(a=1))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
[Row(csv=Row(a=1))]
>>> data = [("1,2,3",)]
>>> df = spark.createDataFrame(data, ("value",))
>>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
[Row(csv=Row(a=1, b=2, c=3))]
>>> value = data[0][0]
>>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
[Row(csv=Row(_c0=1, _c1=2, _c2=3))]
"""

sc = SparkContext._active_spark_context
@@ -523,7 +523,8 @@ object FunctionRegistry {
castAlias("string", StringType),

// csv
expression[CsvToStructs]("from_csv")
expression[CsvToStructs]("from_csv"),
expression[SchemaOfCsv]("schema_of_csv")
)

val builtin: SimpleFunctionRegistry = {
@@ -15,19 +15,18 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.csv
package org.apache.spark.sql.catalyst.csv

import java.math.BigDecimal

import scala.util.control.Exception._
import scala.util.control.Exception.allCatch

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._

private[csv] object CSVInferSchema {
object CSVInferSchema {

/**
* Similar to the JSON schema inference
@@ -44,13 +43,7 @@ private[csv] object CSVInferSchema {
val rootTypes: Array[DataType] =
tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)

header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}
toStructFields(rootTypes, header, options)
} else {
// By default fields are assumed to be StringType
header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -59,7 +52,20 @@
StructType(fields)
}

private def inferRowType(options: CSVOptions)
def toStructFields(
fieldTypes: Array[DataType],
header: Array[String],
options: CSVOptions): Array[StructField] = {
header.zip(fieldTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}
}

def inferRowType(options: CSVOptions)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
@@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{MapType, StringType, StructType}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

object ExprUtils {

def evalSchemaExpr(exp: Expression): StructType = exp match {
case Literal(s, StringType) => StructType.fromDDL(s.toString)
def evalSchemaExpr(exp: Expression): StructType = {
// Use `DataType.fromDDL` since the type string can be struct<...>.
val dataType = exp match {
case Literal(s, StringType) =>
DataType.fromDDL(s.toString)
case e @ SchemaOfCsv(_: Literal, _) =>
val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
DataType.fromDDL(ddlSchema.toString)
case e => throw new AnalysisException(
"Schema should be specified in DDL format as a string literal or output of " +
s"the schema_of_csv function instead of ${e.sql}")
}

if (!dataType.isInstanceOf[StructType]) {
throw new AnalysisException(
s"Schema should be struct type but got ${dataType.sql}.")
}
dataType.asInstanceOf[StructType]
}

def evalTypeExpr(exp: Expression): DataType = exp match {
case Literal(s, StringType) => DataType.fromDDL(s.toString)
Contributor:
how about

if (expr.isFoldable && expr.dataType == StringType) {
  DataType.fromDDL(expr.eval().asInstanceOf[UTF8String].toString)
}

Contributor:
We also need to update https://github.com/apache/spark/pull/22666/files#diff-5321c01e95bffc4413c5f3457696213eR157 in case the constant folding rule is disabled.
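
A hedged sketch of how SchemaOfCsv.checkInputDataTypes could be relaxed accordingly (editor's illustration only, assuming the class shape introduced in this PR; not code from this PR):

  override def checkInputDataTypes(): TypeCheckResult = {
    // Accept any foldable, non-null string expression instead of only Literal.
    if (child.foldable && child.dataType == StringType && child.eval() != null) {
      super.checkInputDataTypes()
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"The input csv should be a foldable string expression and not null; however, got ${child.sql}.")
    }
  }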

Member:
Yup, that's what I initially thought: we should allow constant-foldable expressions as well, but I decided to follow the initial intent of literal-only support. I also wasn't sure when we would need constant folding to construct a JSON example, because I suspected such examples are usually copied and pasted from, for instance, a file.

Member Author:
For example, a column with a CSV string may be the result of string functions, so you could just invoke those functions with particular inputs. Currently, we force people to materialize an example and copy-paste it into schema_of_csv(). That can cause maintainability issues, since users have to keep the example in schema_of_csv() in sync with the code that forms the CSV column.

I prepared PR #27777 to remove this restriction, which is not necessary from my point of view.
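
A hedged sketch of that maintenance burden (editor's illustration only; assumes a SparkSession named `spark` is in scope):

  import org.apache.spark.sql.functions._
  import spark.implicits._

  // The CSV column is produced programmatically by string functions...
  val df = Seq((1, "abc")).toDF("id", "name")
    .withColumn("csv", concat_ws(",", $"id", $"name"))

  // ...but schema_of_csv only accepts a literal example, so the example below has to
  // be kept in sync with the concat_ws call above by hand whenever that call changes.
  val inferred = df.select(schema_of_csv(lit("1,abc")).alias("schema"))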

case e @ SchemaOfJson(_: Literal, _) =>
val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
DataType.fromDDL(ddlSchema.toString)
case e => throw new AnalysisException(
s"Schema should be specified in DDL format as a string literal instead of ${e.sql}")
"Schema should be specified in DDL format as a string literal or output of " +
s"the schema_of_json function instead of ${e.sql}")
}

def convertToMapData(exp: Expression): Map[String, String] = exp match {
@@ -17,8 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import com.univocity.parsers.csv.CsvParser

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.csv._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util._
@@ -120,3 +123,54 @@ case class CsvToStructs(

override def prettyName: String = "from_csv"
}

/**
* A function infers schema of CSV string.
*/
@ExpressionDescription(
usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of CSV string.",
examples = """
Examples:
> SELECT _FUNC_('1,abc');
struct<_c0:int,_c1:string>
""",
since = "3.0.0")
case class SchemaOfCsv(
child: Expression,
options: Map[String, String])
extends UnaryExpression with CodegenFallback {

def this(child: Expression) = this(child, Map.empty[String, String])

def this(child: Expression, options: Expression) = this(
child = child,
options = ExprUtils.convertToMapData(options))

override def dataType: DataType = StringType

override def nullable: Boolean = false

@transient
private lazy val csv = child.eval().asInstanceOf[UTF8String]

override def checkInputDataTypes(): TypeCheckResult = child match {
case Literal(s, StringType) if s != null => super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"The input csv should be a string literal and not null; however, got ${child.sql}.")
}

override def eval(v: InternalRow): Any = {
val parsedOptions = new CSVOptions(options, true, "UTC")
val parser = new CsvParser(parsedOptions.asParserSettings)
val row = parser.parseLine(csv.toString)
assert(row != null, "Parsed CSV record should not be null.")

val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
UTF8String.fromString(st.catalogString)
}

override def prettyName: String = "schema_of_csv"
}
@@ -529,7 +529,7 @@ case class JsonToStructs(
// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression, options: Map[String, String]) =
this(
schema = JsonExprUtils.evalSchemaExpr(schema),
schema = ExprUtils.evalTypeExpr(schema),
options = options,
child = child,
timeZoneId = None)
@@ -538,7 +538,7 @@

def this(child: Expression, schema: Expression, options: Expression) =
this(
schema = JsonExprUtils.evalSchemaExpr(schema),
schema = ExprUtils.evalTypeExpr(schema),
options = ExprUtils.convertToMapData(options),
child = child,
timeZoneId = None)
@@ -784,15 +784,3 @@ case class SchemaOfJson(

override def prettyName: String = "schema_of_json"
}

object JsonExprUtils {
def evalSchemaExpr(exp: Expression): DataType = exp match {
case Literal(s, StringType) => DataType.fromDDL(s.toString)
case e @ SchemaOfJson(_: Literal, _) =>
val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
DataType.fromDDL(ddlSchema.toString)
case e => throw new AnalysisException(
"Schema should be specified in DDL format as a string literal" +
s" or output of the schema_of_json function instead of ${e.sql}")
}
}
@@ -15,10 +15,9 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.csv
package org.apache.spark.sql.catalyst.csv

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.types._

class CSVInferSchemaSuite extends SparkFunSuite {
@@ -15,12 +15,11 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.csv
package org.apache.spark.sql.catalyst.csv

import java.math.BigDecimal

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
}.getCause
assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode"))
}

test("infer schema of CSV strings") {
checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>")
}

test("infer schema of CSV strings by using options") {
checkEvaluation(
new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")),
"struct<_c0:int,_c1:string>")
}
}
@@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
35 changes: 35 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3896,6 +3896,41 @@ object functions {
withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap))
}

/**
* Parses a CSV string and infers its schema in DDL format.
*
* @param csv a CSV string.
*
* @group collection_funcs
* @since 3.0.0
*/
def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv))

/**
* Parses a CSV string and infers its schema in DDL format.
*
* @param csv a string literal containing a CSV string.
*
* @group collection_funcs
* @since 3.0.0
*/
def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr))

/**
* Parses a CSV string and infers its schema in DDL format using options.
*
* @param csv a string literal containing a CSV string.
* @param options options to control how the CSV is parsed. accepts the same options as the
*                CSV data source. See [[DataFrameReader#csv]].
* @return a column with string literal containing schema in DDL format.
*
* @group collection_funcs
* @since 3.0.0
*/
def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = {
withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap))
}

// scalastyle:off line.size.limit
// scalastyle:off parameter.number
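
A hedged usage sketch for the overloads added above (editor's illustration only; assumes a SparkSession named `spark` and the `from_csv` overload that takes the schema as a Column):

  import scala.collection.JavaConverters._
  import org.apache.spark.sql.functions._
  import spark.implicits._

  val csvOptions = Map("sep" -> "|").asJava
  val df = Seq("1|abc").toDF("value")

  // schema_of_csv infers struct<_c0:int,_c1:string> from the literal example,
  // and from_csv reuses the inferred schema to parse the column.
  df.select(
    from_csv($"value", schema_of_csv(lit("1|abc"), csvOptions), csvOptions).alias("csv")
  ).show(truncate = false)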

@@ -7,3 +7,11 @@ select from_csv('1', 'a InvalidType');
select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE'));
select from_csv('1', 'a INT', map('mode', 1));
select from_csv();
-- infer schema of csv literal
select from_csv('1,abc', schema_of_csv('1,abc'));
select schema_of_csv('1|abc', map('delimiter', '|'));
select schema_of_csv(null);
CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a');
SELECT schema_of_csv(csvField) FROM csvTable;
-- Clean up
DROP VIEW IF EXISTS csvTable;
Contributor:
actually we don't need to clean up temp views. The golden file test is run with a fresh session.

Member:
I see, but isn't it still better to explicitly clean tables up?

Contributor:
Yeah, we need to clean up tables, as they are permanent.

Actually, I'm fine with it, as we clean up temp views in a lot of golden files. We can have another PR to remove these temp view clean-ups.
