[SPARK-5447][SQL] Replaced reference to SchemaRDD with DataFrame.
and

[SPARK-5448][SQL] Make CacheManager a concrete class and field in SQLContext

Author: Reynold Xin <rxin@databricks.com>

Closes apache#4242 from rxin/sqlCleanup and squashes the following commits:

e351cb2 [Reynold Xin] Fixed toDataFrame.
6545c42 [Reynold Xin] More changes.
728c017 [Reynold Xin] [SPARK-5447][SQL] Replaced reference to SchemaRDD with DataFrame.
rxin committed Jan 28, 2015
1 parent 453d799 commit c8e934e
Showing 33 changed files with 217 additions and 203 deletions.
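
In caller-facing terms, the two changes above amount to using the DataFrame-flavored entry points in place of the SchemaRDD ones, and reaching the cache through the new cacheManager field on SQLContext. The following is a minimal sketch, not taken from the diff; it assumes an existing SparkContext named sc, a hypothetical case class Record, and a spark-shell style session:

import org.apache.spark.sql.SQLContext

case class Record(id: Int, value: String)

val sqlContext = new SQLContext(sc)   // sc: an existing SparkContext (assumed)
import sqlContext.createDataFrame     // was: import sqlContext.createSchemaRDD

// The implicit conversion now yields a DataFrame rather than a SchemaRDD.
val df = sc.parallelize(Seq(Record(1, "a"), Record(2, "b"))).toDataFrame   // was: .toSchemaRDD

df.registerTempTable("records")
sqlContext.cacheTable("records")      // public API unchanged; it now delegates to sqlContext.cacheManager
assert(sqlContext.isCached("records"))
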
@@ -81,7 +81,7 @@ object DatasetExample {
println(s"Loaded ${origData.count()} instances from file: ${params.input}")

// Convert input data to DataFrame explicitly.
val df: DataFrame = origData.toDF
val df: DataFrame = origData.toDataFrame
println(s"Inferred schema:\n${df.schema.prettyJson}")
println(s"Converted to DataFrame with ${df.count()} records")

@@ -148,7 +148,7 @@ class ALSModel private[ml] (
}

private object ALSModel {
/** Case class to convert factors to SchemaRDDs */
/** Case class to convert factors to [[DataFrame]]s */
private case class Factor(id: Int, features: Seq[Float])
}

@@ -110,7 +110,7 @@ sealed trait Vector extends Serializable {

/**
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
* via [[org.apache.spark.sql.DataFrame]].
*/
private[spark] class VectorUDT extends UserDefinedType[Vector] {

@@ -31,7 +31,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createSchemaRDD(
dataset = sqlContext.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
}

@@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
val sqlContext = this.sqlContext
import sqlContext.createSchemaRDD
import sqlContext.createDataFrame
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
@@ -32,7 +32,7 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
override def beforeAll(): Unit = {
super.beforeAll()
val sqlContext = new SQLContext(sc)
dataset = sqlContext.createSchemaRDD(
dataset = sqlContext.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
}

@@ -255,14 +255,14 @@ class ReplSuite extends FunSuite {
assertDoesNotContain("Exception", output)
}

test("SPARK-2576 importing SQLContext.createSchemaRDD.") {
test("SPARK-2576 importing SQLContext.createDataFrame.") {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,512]",
"""
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.createSchemaRDD
|import sqlContext.createDataFrame
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
2 changes: 1 addition & 1 deletion sql/README.md
@@ -44,7 +44,7 @@ Type in expressions to have them evaluated.
Type :help for more information.

scala> val query = sql("SELECT * FROM (SELECT * FROM src) a")
query: org.apache.spark.sql.SchemaRDD =
query: org.apache.spark.sql.DataFrame =
== Query Plan ==
== Physical Plan ==
HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None
@@ -930,13 +930,13 @@ case class MapType(
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
* a SchemaRDD which has class X in the schema.
* a `DataFrame` which has class X in the schema.
*
* For SparkSQL to recognize UDTs, the UDT must be annotated with
* [[SQLUserDefinedType]].
*
* The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
* The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
* The conversion via `deserialize` occurs when reading from a `DataFrame`.
*/
@DeveloperApi
abstract class UserDefinedType[UserType] extends DataType with Serializable {
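
The UserDefinedType doc comment above describes the serialize/deserialize round trip in prose. As a purely illustrative sketch, not part of this commit, a UDT for a hypothetical Point class could look like the following; the import is an assumption, since these classes were being repackaged during the 1.3 development cycle, and the representation chosen here (a Seq of doubles for an ArrayType) is just one workable option:

import org.apache.spark.sql.types._   // assumed package for UserDefinedType and friends

// Hypothetical user class, annotated with the UDT that maps it to SQL types.
@SQLUserDefinedType(udt = classOf[PointUDT])
class Point(val x: Double, val y: Double)

class PointUDT extends UserDefinedType[Point] {
  // Catalyst-facing representation of a Point: an array of doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Runs when a DataFrame is instantiated from an RDD[Point].
  override def serialize(obj: Any): Any = obj match {
    case p: Point => Seq(p.x, p.y)
  }

  // Runs when Point values are read back out of a DataFrame.
  override def deserialize(datum: Any): Point = datum match {
    case Seq(x: Double, y: Double) => new Point(x, y)
  }

  override def userClass: Class[Point] = classOf[Point]
}
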
22 changes: 12 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import java.util.concurrent.locks.ReentrantReadWriteLock

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.storage.StorageLevel
@@ -32,9 +33,10 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR
* results when subsequent queries are executed. Data is cached using byte buffers stored in an
* InMemoryRelation. This relation is automatically substituted into query plans that return the
* `sameResult` as the originally cached query.
*
* Internal to Spark SQL.
*/
private[sql] trait CacheManager {
self: SQLContext =>
private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {

@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
@@ -43,13 +45,13 @@ private[sql] trait CacheManager {
private val cacheLock = new ReentrantReadWriteLock

/** Returns true if the table is currently cached in-memory. */
def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty
def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty

/** Caches the specified table in-memory. */
def cacheTable(tableName: String): Unit = cacheQuery(table(tableName), Some(tableName))
def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName))

/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName))
def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName))

/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
@@ -91,15 +93,15 @@ private[sql] trait CacheManager {
CachedData(
planToCache,
InMemoryRelation(
conf.useCompression,
conf.columnBatchSize,
sqlContext.conf.useCompression,
sqlContext.conf.columnBatchSize,
storageLevel,
query.queryExecution.executedPlan,
tableName))
}
}

/** Removes the data for the given SchemaRDD from the cache */
/** Removes the data for the given [[DataFrame]] from the cache */
private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -108,7 +110,7 @@
cachedData.remove(dataIndex)
}

/** Tries to remove the data for the given SchemaRDD from the cache if it's cached */
/** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
private[sql] def tryUncacheQuery(
query: DataFrame,
blocking: Boolean = true): Boolean = writeLock {
@@ -122,7 +124,7 @@
found
}

/** Optionally returns cached data for the given SchemaRDD */
/** Optionally returns cached data for the given [[DataFrame]] */
private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}
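
The CacheManager rewrite above is the SPARK-5448 half of the commit: a self-typed trait mixed into SQLContext becomes a concrete class that receives the context through its constructor, and SQLContext holds it as a field while keeping thin delegating methods (visible in the SQLContext.scala diff further down). A toy sketch of the pattern, using hypothetical stand-in names rather than the real Spark classes:

// A concrete collaborator that is handed its context explicitly, replacing a
// `trait ... { self: ToyContext => ... }` that was previously mixed into the context.
class ToyCacheManager(ctx: ToyContext) {
  private val cached = scala.collection.mutable.Set.empty[String]
  def cacheTable(name: String): Unit = cached += ctx.table(name)
  def isCached(name: String): Boolean = cached.contains(ctx.table(name))
}

class ToyContext {
  def table(name: String): String = s"plan-for:$name"

  // The manager becomes a field rather than a mixin...
  val cacheManager = new ToyCacheManager(this)

  // ...and the context keeps thin wrappers so existing callers are unaffected.
  def cacheTable(name: String): Unit = cacheManager.cacheTable(name)
  def isCached(name: String): Boolean = cacheManager.isCached(name)
}

Dropping the self-type removes the requirement that the caching logic can only ever live inside a SQLContext, which is what lets it become an ordinary constructor-injected field here.
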
10 changes: 5 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -106,7 +106,7 @@ class DataFrame protected[sql](
* An implicit conversion function internal to this class for us to avoid doing
* "new DataFrame(...)" everywhere.
*/
private[this] implicit def toDataFrame(logicalPlan: LogicalPlan): DataFrame = {
private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
new DataFrame(sqlContext, logicalPlan, true)
}

Expand All @@ -130,7 +130,7 @@ class DataFrame protected[sql](
/**
* Return the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
*/
def toDF: DataFrame = this
def toDataFrame: DataFrame = this

/** Return the schema of this [[DataFrame]]. */
override def schema: StructType = queryExecution.analyzed.schema
@@ -496,17 +496,17 @@
}

override def persist(): this.type = {
sqlContext.cacheQuery(this)
sqlContext.cacheManager.cacheQuery(this)
this
}

override def persist(newLevel: StorageLevel): this.type = {
sqlContext.cacheQuery(this, None, newLevel)
sqlContext.cacheManager.cacheQuery(this, None, newLevel)
this
}

override def unpersist(blocking: Boolean): this.type = {
sqlContext.tryUncacheQuery(this, blocking)
sqlContext.cacheManager.tryUncacheQuery(this, blocking)
this
}

102 changes: 56 additions & 46 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -51,7 +51,6 @@ import org.apache.spark.util.Utils
@AlphaComponent
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with CacheManager
with Serializable {

self =>
@@ -117,12 +116,57 @@ class SQLContext(@transient val sparkContext: SparkContext)
case _ =>
}

protected[sql] val cacheManager = new CacheManager(this)

/**
* A collection of methods that are considered experimental, but can be used to hook into
* the query planner for advanced functionalities.
*/
val experimental: ExperimentalMethods = new ExperimentalMethods(this)

/**
* A collection of methods for registering user-defined functions (UDF).
*
* The following example registers a Scala closure as UDF:
* {{{
* sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1)
* }}}
*
* The following example registers a UDF in Java:
* {{{
* sqlContext.udf().register("myUDF",
* new UDF2<Integer, String, String>() {
* @Override
* public String call(Integer arg1, String arg2) {
* return arg2 + arg1;
* }
* }, DataTypes.StringType);
* }}}
*
* Or, to use Java 8 lambda syntax:
* {{{
* sqlContext.udf().register("myUDF",
* (Integer arg1, String arg2) -> arg2 + arg1),
* DataTypes.StringType);
* }}}
*/
val udf: UDFRegistration = new UDFRegistration(this)

/** Returns true if the table is currently cached in-memory. */
def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)

/** Caches the specified table in-memory. */
def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)

/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)

/**
* Creates a SchemaRDD from an RDD of case classes.
* Creates a DataFrame from an RDD of case classes.
*
* @group userf
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val attributeSeq = ScalaReflection.attributesFor[A]
val schema = StructType.fromAttributes(attributeSeq)
@@ -133,7 +177,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
*/
def baseRelationToSchemaRDD(baseRelation: BaseRelation): DataFrame = {
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
new DataFrame(this, LogicalRelation(baseRelation))
}

@@ -155,21 +199,21 @@
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
* val peopleSchemaRDD = sqlContext. applySchema(people, schema)
* peopleSchemaRDD.printSchema
* val dataFrame = sqlContext. applySchema(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
* // |-- age: integer (nullable = true)
*
* peopleSchemaRDD.registerTempTable("people")
* dataFrame.registerTempTable("people")
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
*
* @group userf
*/
@DeveloperApi
def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
new DataFrame(this, logicalPlan)
@@ -309,12 +353,12 @@
* @group userf
*/
def dropTempTable(tableName: String): Unit = {
tryUncacheQuery(table(tableName))
cacheManager.tryUncacheQuery(table(tableName))
catalog.unregisterTable(Seq(tableName))
}

/**
* Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
*
* @group userf
@@ -327,44 +371,10 @@
}
}

/** Returns the specified table as a SchemaRDD */
/** Returns the specified table as a [[DataFrame]]. */
def table(tableName: String): DataFrame =
new DataFrame(this, catalog.lookupRelation(Seq(tableName)))

/**
* A collection of methods that are considered experimental, but can be used to hook into
* the query planner for advanced functionalities.
*/
val experimental: ExperimentalMethods = new ExperimentalMethods(this)

/**
* A collection of methods for registering user-defined functions (UDF).
*
* The following example registers a Scala closure as UDF:
* {{{
* sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1)
* }}}
*
* The following example registers a UDF in Java:
* {{{
* sqlContext.udf().register("myUDF",
* new UDF2<Integer, String, String>() {
* @Override
* public String call(Integer arg1, String arg2) {
* return arg2 + arg1;
* }
* }, DataTypes.StringType);
* }}}
*
* Or, to use Java 8 lambda syntax:
* {{{
* sqlContext.udf().register("myUDF",
* (Integer arg1, String arg2) -> arg2 + arg1),
* DataTypes.StringType);
* }}}
*/
val udf: UDFRegistration = new UDFRegistration(this)

protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext

@@ -455,7 +465,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected class QueryExecution(val logical: LogicalPlan) {

lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical))
lazy val withCachedData: LogicalPlan = useCachedData(analyzed)
lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)

// TODO: Don't just pick the first one...