
Commit 797b2fd

Pass SQLContext instead of SparkContext into physical operators.
1 parent: 171ebb3 · commit: 797b2fd

7 files changed: +51 −44 lines

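Every operator touched below follows the same pattern: the physical node now takes the SQLContext as a transient, curried constructor argument in place of a SparkContext, returns it from otherCopyArgs so Catalyst can re-supply it when the node is copied during tree transformations, and goes through sqlContext.sparkContext wherever it actually needs RDD operations. A condensed sketch of the resulting shape, assembled from the Limit hunks in basicOperators.scala below (body abridged):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}

    case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
      extends UnaryNode {

      // The curried @transient argument is not part of the case class product,
      // so otherCopyArgs hands it back to Catalyst when the plan tree is copied.
      override def otherCopyArgs = sqlContext :: Nil

      override def output = child.output

      // RDD-level work reaches the SparkContext through the SQLContext.
      override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
    }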

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 1 deletion

@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   }

   protected[sql] class SparkPlanner extends SparkStrategies {
-    val sparkContext = self.sparkContext
+    val sparkContext: SparkContext = self.sparkContext
+
+    val sqlContext: SQLContext = self

     def numPartitions = self.numShufflePartitions

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 3 additions & 2 deletions

@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.SQLContext

 /**
  * :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sc: SparkContext)
+    child: SparkPlan)(@transient sqlContext: SQLContext)
   extends UnaryNode with NoBind {

   override def requiredChildDistribution =
@@ -55,7 +56,7 @@ case class Aggregate(
     }
   }

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   // HACK: Generators don't correctly preserve their output through serializations so we grab
   // out child's output attributes statically here.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 11 additions & 11 deletions

@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              partial = true,
              groupingExpressions,
              partialComputation,
-             planLater(child))(sparkContext))(sparkContext) :: Nil
+             planLater(child))(sqlContext))(sqlContext) :: Nil
         } else {
           Nil
         }
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
+        execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val relation =
           ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
         // Note: overwrite=false because otherwise the metadata we just created will be deleted
-        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
+        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-        InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
+        InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
         val prunePushedDownFilters =
           if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           projectList,
           filters,
           prunePushedDownFilters,
-          ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+          ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil

       case _ => Nil
     }
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
-          partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+          partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
       case logical.Sort(sortExprs, child) =>
         // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
         execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+        execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
+        execution.Limit(limit, planLater(child))(sqlContext) :: Nil
       case Unions(unionChildren) =>
-        execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+        execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>
         execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
       case logical.NoRelation =>

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 11 additions & 9 deletions

@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
 import scala.reflect.runtime.universe.TypeTag

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
   override def output = children.head.output
-  override def execute() = sc.union(children.map(_.execute()))
+  override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 }

 /**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 @DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+  extends UnaryNode {
   // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
   // partition local limit -> exchange into one partition -> partition local limit again

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   override def output = child.output

@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
  */
 @DeveloperApi
 case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-                      (@transient sc: SparkContext) extends UnaryNode {
-  override def otherCopyArgs = sc :: Nil
+                      (@transient sqlContext: SQLContext) extends UnaryNode {
+  override def otherCopyArgs = sqlContext :: Nil

   override def output = child.output

@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)

   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
 }

 /**

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 11 additions & 10 deletions

@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution

 import scala.collection.mutable.{ArrayBuffer, BitSet}

-import org.apache.spark.SparkContext
-
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
 @DeveloperApi
 case class LeftSemiJoinBNL(
     streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.

   override def outputPartitioning: Partitioning = streamed.outputPartitioning

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   def output = left.output

@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(


   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

     streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
 @DeveloperApi
 case class BroadcastNestedLoopJoin(
     streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.

   override def outputPartitioning: Partitioning = streamed.outputPartitioning

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   def output = left.output ++ right.output

@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(


   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
     }

     // TODO: Breaks lineage.
-    sc.union(
-      streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+    sqlContext.sparkContext.union(
+      streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 11 additions & 10 deletions

@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType

-import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
-import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}

 /**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
     output: Seq[Attribute],
     relation: ParquetRelation,
     columnPruningPred: Seq[Expression])(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends LeafNode {

   override def execute(): RDD[Row] = {
+    val sc = sqlContext.sparkContext
     val job = new Job(sc.hadoopConfiguration)
     ParquetInputFormat.setReadSupportClass(
       job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
       .filter(_ != null) // Parquet's record filters may produce null values
   }

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   /**
    * Applies a (candidate) projection.
@@ -104,7 +105,7 @@ case class ParquetTableScan(
   def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
     val success = validateProjection(prunedAttributes)
     if (success) {
-      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
     } else {
       sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
       this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
     relation: ParquetRelation,
     child: SparkPlan,
     overwrite: Boolean = false)(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends UnaryNode with SparkHadoopMapReduceUtil {

   /**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
     val childRdd = child.execute()
     assert(childRdd != null)

-    val job = new Job(sc.hadoopConfiguration)
+    val job = new Job(sqlContext.sparkContext.hadoopConfiguration)

     val writeSupport =
       if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(

   override def output = child.output

-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil

   /**
    * Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
     val wrappedConf = new SerializableWritable(job.getConfiguration)
     val formatter = new SimpleDateFormat("yyyyMMddHHmm")
     val jobtrackerID = formatter.format(new Date())
-    val stageId = sc.newRddId()
+    val stageId = sqlContext.sparkContext.newRddId()

     val taskIdOffset =
       if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
     val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
     val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
     jobCommitter.setupJob(jobTaskContext)
-    sc.runJob(rdd, writeShard _)
+    sqlContext.sparkContext.runJob(rdd, writeShard _)
     jobCommitter.commitJob(jobTaskContext)
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val scanner = new ParquetTableScan(
       ParquetTestData.testData.output,
       ParquetTestData.testData,
-      Seq())(TestSQLContext.sparkContext)
+      Seq())(TestSQLContext)
     val projected = scanner.pruneColumns(ParquetTypesConverter
       .convertToAttributes(MessageTypeParser
       .parseMessageType(ParquetTestData.subTestSchema)))
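For code that constructs these operators directly, the migration is mechanical: pass the SQLContext where a SparkContext used to go, as the test above shows.

    // Before: the operator was handed a bare SparkContext.
    val scanner = new ParquetTableScan(
      ParquetTestData.testData.output,
      ParquetTestData.testData,
      Seq())(TestSQLContext.sparkContext)

    // After: it takes the SQLContext and derives the SparkContext internally.
    val scanner = new ParquetTableScan(
      ParquetTestData.testData.output,
      ParquetTestData.testData,
      Seq())(TestSQLContext)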
