
Commit 7e79ce5

Merge pull request #22 from yhuai/pr3431yin

Remove Option from createRelation.

2 parents: a852b10 + 38f634e

5 files changed: +56 −35 lines changed

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 15 additions & 4 deletions
```diff
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
-private[sql] class DefaultSource extends SchemaRelationProvider {
-  /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+  /** Returns a new base relation with the parameters. */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+    JSONRelation(fileName, samplingRatio, None)(sqlContext)
+  }
+
+  /** Returns a new base relation with the given schema and parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
+    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
   }
 }
```
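Because `DefaultSource` now extends both `RelationProvider` and `SchemaRelationProvider`, the JSON source can be registered with or without a user-supplied schema. A minimal usage sketch, assuming the DDL path through `CreateTableUsing` (table names, file path, and columns are hypothetical):

```scala
// Without a column list: dispatch goes to RelationProvider and the
// schema is inferred by sampling the JSON records.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonTable
    |USING org.apache.spark.sql.json
    |OPTIONS (path '/tmp/people.json', samplingRatio '0.5')
  """.stripMargin)

// With a column list: dispatch goes to SchemaRelationProvider, so the
// sampling pass is skipped and the given schema is used as-is.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonTableWithSchema (name STRING, age INT)
    |USING org.apache.spark.sql.json
    |OPTIONS (path '/tmp/people.json')
  """.stripMargin)
```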

sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 15 additions & 18 deletions
```diff
@@ -22,37 +22,37 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+
 import parquet.hadoop.ParquetInputFormat
 import parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.{Partition => SparkPartition, Logging}
 import org.apache.spark.rdd.{NewHadoopPartition, RDD}
-import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+
+import org.apache.spark.sql.{SQLConf, Row, SQLContext}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.{SQLConf, SQLContext}
 
 import scala.collection.JavaConversions._
 
-
 /**
  * Allows creation of parquet based tables using the syntax
  * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
  * required is `path`, which should be the location of a collection of, optionally partitioned,
  * parquet files.
  */
-class DefaultSource extends SchemaRelationProvider {
+class DefaultSource extends RelationProvider {
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      parameters: Map[String, String]): BaseRelation = {
     val path =
       parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
 
-    ParquetRelation2(path, schema)(sqlContext)
+    ParquetRelation2(path)(sqlContext)
   }
 }
@@ -82,9 +82,7 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
  * discovery.
  */
 @DeveloperApi
-case class ParquetRelation2(
-    path: String,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
   extends CatalystScan with Logging {
 
   def sparkContext = sqlContext.sparkContext
@@ -135,13 +133,12 @@ case class ParquetRelation2(
 
   override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum
 
-  val dataSchema = userSpecifiedSchema.getOrElse(
-    StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
-      ParquetTypesConverter.readSchemaFromFile(
-        partitions.head.files.head.getPath,
-        Some(sparkContext.hadoopConfiguration),
-        sqlContext.isParquetBinaryAsString))
-  )
+  val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
+    ParquetTypesConverter.readSchemaFromFile(
+      partitions.head.files.head.getPath,
+      Some(sparkContext.hadoopConfiguration),
+      sqlContext.isParquetBinaryAsString))
+
   val dataIncludesKey =
     partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)
```
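With the `userSpecifiedSchema` parameter removed, `ParquetRelation2` always derives `dataSchema` from the footer of the first file in the first partition. A hedged sketch of registering such a table (the path is hypothetical); no column list may be given, since the source no longer extends `SchemaRelationProvider`:

```scala
// Parquet tables are registered without a column list; the schema is
// always read from the Parquet files themselves.
sqlContext.sql(
  """CREATE TEMPORARY TABLE parquetTable
    |USING org.apache.spark.sql.parquet
    |OPTIONS (path '/tmp/partitioned-table')
  """.stripMargin)
```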

sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala

Lines changed: 22 additions & 9 deletions
```diff
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
           sys.error(s"Failed to load class for data source: $provider")
       }
     }
-    val relation = clazz.newInstance match {
-      case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options))
-      case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        dataSource
-          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
+
+    val relation = userSpecifiedSchema match {
+      case Some(schema: StructType) => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+        }
+      }
+      case None => {
+        clazz.newInstance match {
+          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+            dataSource
+              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+              .createRelation(sqlContext, new CaseInsensitiveMap(options))
+          case _ =>
+            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+        }
+      }
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
```
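The rewritten dispatch fails fast when the DDL and the provider disagree: instead of matching on whichever trait the class happens to implement, it branches on whether the user supplied a schema and errors out on a mismatch. A sketch of the failing case (table name, column, and path are hypothetical; the exact message interpolates the provider's canonical class name):

```scala
// Expected to fail with
//   "<provider class> should extend SchemaRelationProvider."
// because the Parquet DefaultSource only extends RelationProvider,
// so a user-specified column list can no longer be silently dropped.
sqlContext.sql(
  """CREATE TEMPORARY TABLE shouldFail (a INT)
    |USING org.apache.spark.sql.parquet
    |OPTIONS (path '/tmp/partitioned-table')
  """.stripMargin)
```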

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation
+      schema: StructType): BaseRelation
 }
 
 /**
```
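Under the new signature, an implementation receives a plain `StructType` and never unwraps an `Option`; the dispatch in `ddl.scala` guarantees a schema is present whenever this method is called. A minimal sketch of a conforming source (all names here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

// Illustrative relation: returns no rows but honors the given schema.
case class EmptyRelation(userSchema: StructType)(@transient val sqlContext: SQLContext)
  extends TableScan {

  // The schema arrives as a plain StructType; nothing to unwrap.
  override def schema = userSchema

  override def buildScan() = sqlContext.sparkContext.parallelize(Seq.empty[Row])
}

class EmptySource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    EmptyRelation(schema)(sqlContext)
}
```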

sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -45,18 +45,18 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType]): BaseRelation = {
+      schema: StructType): BaseRelation = {
     AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
   }
 }
 
 case class AllDataTypesScan(
     from: Int,
     to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+    userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
   extends TableScan {
 
-  override def schema = userSpecifiedSchema.get
+  override def schema = userSpecifiedSchema
 
   override def buildScan() = {
     sqlContext.sparkContext.parallelize(from to to).map { i =>
```
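Since `AllDataTypesScanSource` only extends `SchemaRelationProvider`, registering it now requires an explicit column list in the DDL. A hedged sketch of what such a test declaration might look like (table name and column are illustrative, and the declared columns must match what `buildScan` actually produces; the mixed-case `To` option exercises the `CaseInsensitiveMap` lookup of `parameters("TO")`):

```scala
sqlContext.sql(
  """CREATE TEMPORARY TABLE tableWithSchema (i INT)
    |USING org.apache.spark.sql.sources.AllDataTypesScanSource
    |OPTIONS (from '1', To '10')
  """.stripMargin)
```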
