Commit 693a323

scwf authored and marmbrus committed
[SPARK-4574][SQL] Adding support for defining schema in foreign DDL commands.
Adding support for defining a schema in foreign DDL commands. Foreign DDL now supports commands like:

```
CREATE TEMPORARY TABLE avroTable
USING org.apache.spark.sql.avro
OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
```

With this PR the user can define the schema instead of having it inferred from the file, so DDL commands like the following are also supported:

```
CREATE TEMPORARY TABLE avroTable(a int, b string)
USING org.apache.spark.sql.avro
OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
```

Author: scwf <wangfei1@huawei.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Fei Wang <wangfei1@huawei.com>
Author: wangfei <wangfei1@huawei.com>

Closes #3431 from scwf/ddl and squashes the following commits:

7e79ce5 [Fei Wang] Merge pull request #22 from yhuai/pr3431yin
38f634e [Yin Huai] Remove Option from createRelation.
65e9c73 [Yin Huai] Revert all changes since applying a given schema has not been testd.
a852b10 [scwf] remove cleanIdentifier
f336a16 [Fei Wang] Merge pull request #21 from yhuai/pr3431yin
baf79b5 [Yin Huai] Test special characters quoted by backticks.
50a03b0 [Yin Huai] Use JsonRDD.nullTypeToStringType to convert NullType to StringType.
1eeb769 [Fei Wang] Merge pull request #20 from yhuai/pr3431yin
f5c22b0 [Yin Huai] Refactor code and update test cases.
f1cffe4 [Yin Huai] Revert "minor refactory"
b621c8f [scwf] minor refactory
d02547f [scwf] fix HiveCompatibilitySuite test failure
8dfbf7a [scwf] more tests for complex data type
ddab984 [Fei Wang] Merge pull request #19 from yhuai/pr3431yin
91ad91b [Yin Huai] Parse data types in DDLParser.
cf982d2 [scwf] fixed test failure
445b57b [scwf] address comments
02a662c [scwf] style issue
44eb70c [scwf] fix decimal parser issue
83b6fc3 [scwf] minor fix
9bf12f8 [wangfei] adding test case
7787ec7 [wangfei] added SchemaRelationProvider
0ba70df [wangfei] draft version
1 parent 4b39fd1 commit 693a323
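
For orientation, a minimal usage sketch of the new DDL path (not part of the commit; it assumes a local SparkContext and that the Avro data source named in the commit message is on the classpath):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Assumed setup; any existing SQLContext works the same way.
val sc = new SparkContext(new SparkConf().setAppName("ddl-schema-sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)

// New form: the column list defines the schema, so nothing is inferred from the file.
sqlContext.sql(
  """CREATE TEMPORARY TABLE avroTable(a int, b string)
    |USING org.apache.spark.sql.avro
    |OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")""".stripMargin)

// The temporary table is registered in the catalog and can be queried directly.
sqlContext.sql("SELECT a, b FROM avroTable").collect().foreach(println)
```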

File tree

6 files changed: +400 additions, -113 deletions


sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 26 additions & 9 deletions
@@ -18,31 +18,48 @@
 package org.apache.spark.sql.json

 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._

-private[sql] class DefaultSource extends RelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

-    JSONRelation(fileName, samplingRatio)(sqlContext)
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }

-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+    fileName: String,
+    samplingRatio: Double,
+    userSpecifiedSchema: Option[StructType])(
     @transient val sqlContext: SQLContext)
   extends TableScan {

   private def baseRDD = sqlContext.sparkContext.textFile(fileName)

-  override val schema =
-    JsonRDD.inferSchema(
-      baseRDD,
-      samplingRatio,
-      sqlContext.columnNameOfCorruptRecord)
+  override val schema = userSpecifiedSchema.getOrElse(
+    JsonRDD.nullTypeToStringType(
+      JsonRDD.inferSchema(
+        baseRDD,
+        samplingRatio,
+        sqlContext.columnNameOfCorruptRecord)))

   override def buildScan() =
     JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
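
The practical effect of this change is that a declared column list bypasses schema inference for the JSON source (the NullType-to-StringType widening only applies on the inference path). A hedged sketch, reusing the `sqlContext` from the sketch above and a made-up JSON path:

```scala
// Inference path: JsonRDD.inferSchema samples the file (here, all of it).
sqlContext.sql(
  """CREATE TEMPORARY TABLE peopleInferred
    |USING org.apache.spark.sql.json
    |OPTIONS (path "/tmp/people.json", samplingRatio "1.0")""".stripMargin)

// User-specified path: the StructType built from the column list is used as-is.
sqlContext.sql(
  """CREATE TEMPORARY TABLE peopleDeclared(name string, age int)
    |USING org.apache.spark.sql.json
    |OPTIONS (path "/tmp/people.json")""".stripMargin)

// Compare the resulting schemas of the two registrations.
sqlContext.sql("SELECT * FROM peopleInferred").printSchema()
sqlContext.sql("SELECT * FROM peopleDeclared").printSchema()
```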

sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala

Lines changed: 125 additions & 13 deletions
@@ -17,16 +17,15 @@

 package org.apache.spark.sql.sources

-import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.util.Utils
-
 import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
 import scala.util.parsing.combinator.syntactical.StandardTokenParsers
 import scala.util.parsing.combinator.PackratParsers

+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.SqlLexical

@@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
     }
   }

+  def parseType(input: String): DataType = {
+    phrase(dataType)(new lexical.Scanner(input)) match {
+      case Success(r, x) => r
+      case x =>
+        sys.error(s"Unsupported dataType: $x")
+    }
+  }
+
   protected case class Keyword(str: String)

   protected implicit def asParser(k: Keyword): Parser[String] =
@@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected val USING = Keyword("USING")
   protected val OPTIONS = Keyword("OPTIONS")

+  // Data types.
+  protected val STRING = Keyword("STRING")
+  protected val BINARY = Keyword("BINARY")
+  protected val BOOLEAN = Keyword("BOOLEAN")
+  protected val TINYINT = Keyword("TINYINT")
+  protected val SMALLINT = Keyword("SMALLINT")
+  protected val INT = Keyword("INT")
+  protected val BIGINT = Keyword("BIGINT")
+  protected val FLOAT = Keyword("FLOAT")
+  protected val DOUBLE = Keyword("DOUBLE")
+  protected val DECIMAL = Keyword("DECIMAL")
+  protected val DATE = Keyword("DATE")
+  protected val TIMESTAMP = Keyword("TIMESTAMP")
+  protected val VARCHAR = Keyword("VARCHAR")
+  protected val ARRAY = Keyword("ARRAY")
+  protected val MAP = Keyword("MAP")
+  protected val STRUCT = Keyword("STRUCT")
+
   // Use reflection to find the reserved words defined in this class.
   protected val reservedWords =
     this.getClass
@@ -67,26 +92,92 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val ddl: Parser[LogicalPlan] = createTable

   /**
-   * CREATE TEMPORARY TABLE avroTable
+   * `CREATE TEMPORARY TABLE avroTable
    * USING org.apache.spark.sql.avro
-   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+   * or
+   * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+   * USING org.apache.spark.sql.avro
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
    */
   protected lazy val createTable: Parser[LogicalPlan] =
-    CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
-      case tableName ~ provider ~ opts =>
-        CreateTableUsing(tableName, provider, opts)
+  (
+    CREATE ~ TEMPORARY ~ TABLE ~> ident
+      ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+      case tableName ~ columns ~ provider ~ opts =>
+        val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+        CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
     }
+  )
+
+  protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"

   protected lazy val options: Parser[Map[String, String]] =
     "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }

   protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}

   protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+
+  protected lazy val column: Parser[StructField] =
+    ident ~ dataType ^^ { case columnName ~ typ =>
+      StructField(columnName, typ)
+    }
+
+  protected lazy val primitiveType: Parser[DataType] =
+    STRING ^^^ StringType |
+    BINARY ^^^ BinaryType |
+    BOOLEAN ^^^ BooleanType |
+    TINYINT ^^^ ByteType |
+    SMALLINT ^^^ ShortType |
+    INT ^^^ IntegerType |
+    BIGINT ^^^ LongType |
+    FLOAT ^^^ FloatType |
+    DOUBLE ^^^ DoubleType |
+    fixedDecimalType |                   // decimal with precision/scale
+    DECIMAL ^^^ DecimalType.Unlimited |  // decimal with no precision/scale
+    DATE ^^^ DateType |
+    TIMESTAMP ^^^ TimestampType |
+    VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
+
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+    }
+
+  protected lazy val arrayType: Parser[DataType] =
+    ARRAY ~> "<" ~> dataType <~ ">" ^^ {
+      case tpe => ArrayType(tpe)
+    }
+
+  protected lazy val mapType: Parser[DataType] =
+    MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    }
+
+  protected lazy val structField: Parser[StructField] =
+    ident ~ ":" ~ dataType ^^ {
+      case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
+    }
+
+  protected lazy val structType: Parser[DataType] =
+    (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+      case fields => new StructType(fields)
+    }) |
+    (STRUCT ~> "<>" ^^ {
+      case fields => new StructType(Nil)
+    })
+
+  private[sql] lazy val dataType: Parser[DataType] =
+    arrayType |
+    mapType |
+    structType |
+    primitiveType
 }

 private[sql] case class CreateTableUsing(
     tableName: String,
+    userSpecifiedSchema: Option[StructType],
     provider: String,
     options: Map[String, String]) extends RunnableCommand {

@@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing(
             sys.error(s"Failed to load class for data source: $provider")
         }
       }
-    val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-    val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
+    }

     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
     Seq.empty
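
The new `dataType` production mirrors Hive-style type syntax, including nested complex types. As an illustrative sketch only (table name, columns, and path are invented), a statement this parser should now accept, again via the `sqlContext` from the earlier sketch:

```scala
// Column list exercising the new grammar: primitives, DECIMAL(p, s),
// ARRAY<...>, MAP<...>, and STRUCT<name: type, ...>.
sqlContext.sql(
  """CREATE TEMPORARY TABLE events(
    |  id bigint,
    |  amount decimal(10, 2),
    |  tags array<string>,
    |  counters map<string, int>,
    |  location struct<lat: double, lon: double>
    |)
    |USING org.apache.spark.sql.json
    |OPTIONS (path "/tmp/events.json")""".stripMargin)
```

At runtime, `CreateTableUsing` then checks whether the loaded provider implements `SchemaRelationProvider` (when columns are declared) or plain `RelationProvider` (when they are not) and fails with a descriptive error otherwise.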

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 28 additions & 1 deletion
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources

 import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
+import org.apache.spark.sql.{Row, SQLContext, StructType}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}

 /**
@@ -44,6 +44,33 @@ trait RelationProvider {
   def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
 }

+/**
+ * ::DeveloperApi::
+ * Implemented by objects that produce relations for a specific kind of data source. When
+ * Spark SQL is given a DDL operation with
+ * 1. USING clause: to specify the implemented SchemaRelationProvider
+ * 2. User defined schema: users can define schema optionally when create table
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'
+ *
+ * A new instance of this class with be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait SchemaRelationProvider {
+  /**
+   * Returns a new base relation with the given parameters and user defined schema.
+   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
+   * by the Map that is passed to the function.
+   */
+  def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation
+}
+
 /**
  * ::DeveloperApi::
  * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
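
For third-party sources, opting into user-defined schemas means mixing in the new trait alongside `RelationProvider`. Below is a hypothetical sketch (the package, class names, and line-parsing logic are invented; only the trait signatures come from this commit):

```scala
package com.example.lines

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.{StringType, StructField, StructType}
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider, TableScan}

// Hypothetical data source that supports both inferred and user-declared schemas.
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  // Called for `CREATE TEMPORARY TABLE t USING com.example.lines OPTIONS (...)`.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    LineRelation(path(parameters), userSpecifiedSchema = None)(sqlContext)

  // Called when the DDL statement also declares a column list.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    LineRelation(path(parameters), userSpecifiedSchema = Some(schema))(sqlContext)

  private def path(parameters: Map[String, String]): String =
    parameters.getOrElse("path", sys.error("Option 'path' not specified"))
}

// Minimal relation: one row per text line, schema either declared or defaulted.
case class LineRelation(
    path: String,
    userSpecifiedSchema: Option[StructType])(
    @transient val sqlContext: SQLContext)
  extends TableScan {

  override val schema: StructType =
    userSpecifiedSchema.getOrElse(
      StructType(Seq(StructField("value", StringType, nullable = true))))

  // A real source would split each line into the declared columns; this keeps a single one.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.textFile(path).map(Row(_))
}
```

With such a class on the classpath, a statement like `CREATE TEMPORARY TABLE t(value string) USING com.example.lines OPTIONS (path "/tmp/data.txt")` would resolve to `com.example.lines.DefaultSource` per the lookup rule documented in the trait above.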
