
Commit 2225431

Merge pull request apache#48 from marmbrus/minorFixes
Several minor fixes for bugs found during benchmarking.
2 parents 9990ec7 + d393d2a commit 2225431

12 files changed: +102 -31 lines changed

catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala

Lines changed: 9 additions & 3 deletions
@@ -2,7 +2,7 @@ package org.apache.spark.sql
 package catalyst
 package analysis

-import plans.logical.LogicalPlan
+import plans.logical.{LogicalPlan, Subquery}
 import scala.collection.mutable

 /**
@@ -31,8 +31,14 @@ trait OverrideCatalog extends Catalog {
       tableName: String,
       alias: Option[String] = None): LogicalPlan = {

-    overrides.get((databaseName, tableName))
-      .getOrElse(super.lookupRelation(databaseName, tableName, alias))
+    val overriddenTable = overrides.get((databaseName, tableName))
+
+    // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
+    // properly qualified with this alias.
+    val withAlias =
+      overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
+
+    withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
   }

   def overrideTable(databaseName: Option[String], tableName: String, plan: LogicalPlan) =
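For context, a minimal sketch of what this hunk changes (the catalog instance, table name, and plan below are illustrative, not part of the patch): an overridden table looked up with an alias is now wrapped in a Subquery so its attributes can be qualified by that alias.

    // Hypothetical usage of the new lookup path:
    catalog.overrideTable(None, "people", cachedPeoplePlan)
    val plan = catalog.lookupRelation(None, "people", alias = Some("P"))
    // plan == Subquery("p", cachedPeoplePlan); with no alias the raw override is returned,
    // and tables without an override still fall through to super.lookupRelation as before.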

catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 3 additions & 0 deletions
@@ -138,6 +138,9 @@ package object dsl {
         alias: Option[String] = None) =
       Generate(generator, join, outer, None, plan)

+    def insertInto(tableName: String, overwrite: Boolean = false) =
+      InsertIntoTable(analysis.UnresolvedRelation(None, tableName), Map.empty, plan, overwrite)
+
     def analyze = analysis.SimpleAnalyzer(plan)
   }
 }
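A hedged usage sketch of the new DSL helper (the testData plan comes from the test suite and is assumed here, not part of this hunk):

    // With the logical-plan DSL in scope, a plan can be routed into a named table:
    val append    = testData.insertInto("targetTable")                   // overwrite = false
    val overwrite = testData.insertInto("targetTable", overwrite = true)
    // Both build InsertIntoTable(UnresolvedRelation(None, "targetTable"), Map.empty, plan, overwrite),
    // leaving the target table to be resolved later by the analyzer.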

catalyst/src/main/scala/org/apache/spark/sql/package.scala

Lines changed: 2 additions & 0 deletions
@@ -12,4 +12,6 @@ package object sql {
     com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name))

   protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging
+
+  type Row = catalyst.expressions.Row
 }
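Exporting the alias from the sql package object lets callers write against org.apache.spark.sql.Row without importing catalyst.expressions directly; a small hedged sketch (positional access on Row is assumed):

    import org.apache.spark.sql.Row  // resolves to catalyst.expressions.Row through the alias

    def firstField(row: Row): Any = row(0)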

core/src/main/scala/org/apache/spark/rdd/SharkPairRDDFunctions.scala

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ class SharkPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
   def groupByKeyLocally(): RDD[(K, Seq[V])] = {
     def createCombiner(v: V) = ArrayBuffer(v)
     def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
-    val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner _, mergeValue _, null)
+    val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner _, mergeValue _, _ ++ _)
     val bufs = self.mapPartitionsWithContext((context, iter) => {
       new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
     }, preservesPartitioning = true)
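The change above is a real bug fix: passing null for mergeCombiners compiles, but fails with a NullPointerException if two partial combiners for the same key ever need to be merged. A standalone sketch of the three functions an Aggregator needs (plain Scala, no Spark dependency; the values are made up):

    import scala.collection.mutable.ArrayBuffer

    def createCombiner[V](v: V): ArrayBuffer[V] = ArrayBuffer(v)              // start a group
    def mergeValue[V](buf: ArrayBuffer[V], v: V): ArrayBuffer[V] = buf += v   // grow a group
    def mergeCombiners[V](a: ArrayBuffer[V], b: ArrayBuffer[V]): ArrayBuffer[V] = a ++ b  // what `_ ++ _` supplies

    // Combining two partial groupings of the same key:
    val merged = mergeCombiners(ArrayBuffer(1, 2), ArrayBuffer(3))  // ArrayBuffer(1, 2, 3)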

core/src/main/scala/org/apache/spark/sql/SparkSqlContext.scala

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ class SparkSqlContext(val sparkContext: SparkContext) extends Logging {
     val sparkContext = self.sparkContext

     val strategies: Seq[Strategy] =
+      TopK ::
       PartialAggregation ::
       SparkEquiInnerJoin ::
       BasicOperators ::

core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 21 additions & 1 deletion
@@ -137,8 +137,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(IndexedSeq()): Row), 1)

+  def convertToCatalyst(a: Any): Any = a match {
+    case s: Seq[Any] => s.map(convertToCatalyst)
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toSeq)
+    case other => other
+  }
+
+  object TopK extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.StopAfter(limit, logical.Sort(order, child)) =>
+        execution.TopK(
+          Evaluate(limit, Nil).asInstanceOf[Int], order, planLater(child))(sparkContext) :: Nil
+      case _ => Nil
+    }
+  }
+
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
+    // TOOD: Set
+    val numPartitions = 200
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
@@ -160,7 +177,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         val dataAsRdd =
-          sparkContext.parallelize(data.map(r => new GenericRow(r.productIterator.toVector): Row))
+          sparkContext.parallelize(data.map(r =>
+            new GenericRow(r.productIterator.map(convertToCatalyst).toVector): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
       case logical.StopAfter(limit, child) =>
         execution.StopAfter(
@@ -172,6 +190,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
       case logical.NoRelation =>
         execution.ExistingRdd(Nil, singleRowRdd) :: Nil
+      case logical.Repartition(expressions, child) =>
+        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
       case _ => Nil
     }
   }
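A toy model of the planning pattern the new TopK strategy relies on (the types below are invented purely for illustration and are not catalyst's): a limit sitting directly on top of a sort is collapsed into a single top-k operator instead of a full sort followed by a separate limit.

    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Sort(order: String, child: Plan) extends Plan
    case class StopAfter(limit: Int, child: Plan) extends Plan
    case class TopK(limit: Int, order: String, child: Plan) extends Plan

    // The nested match mirrors `case logical.StopAfter(limit, logical.Sort(order, child))` above.
    def plan(p: Plan): Plan = p match {
      case StopAfter(limit, Sort(order, child)) => TopK(limit, order, plan(child))
      case Sort(order, child)                   => Sort(order, plan(child))
      case StopAfter(limit, child)              => StopAfter(limit, plan(child))
      case leaf                                 => leaf
    }

    // plan(StopAfter(10, Sort("a ASC", Scan("t")))) == TopK(10, "a ASC", Scan("t"))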

core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 17 additions & 0 deletions
@@ -53,6 +53,23 @@ case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext)
   def execute() = sc.makeRDD(executeCollect(), 1)
 }

+case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+               (@transient sc: SparkContext) extends UnaryNode {
+  override def otherCopyArgs = sc :: Nil
+
+  def output = child.output
+
+  @transient
+  lazy val ordering = new RowOrdering(sortOrder)
+
+  override def executeCollect() = child.execute().takeOrdered(limit)(ordering)
+
+  // TODO: Terminal split should be implemented differently from non-terminal split.
+  // TODO: Pick num splits based on |limit|.
+  def execute() = sc.makeRDD(executeCollect(), 1)
+}
+
+
 case class Sort(
     sortOrder: Seq[SortOrder],
     global: Boolean,
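The operator's executeCollect leans on RDD.takeOrdered, which keeps at most `limit` rows per partition before merging them on the driver, rather than sorting the full dataset. A minimal illustration of the semantics with plain Scala collections (the data is made up):

    val rows = Seq(5, 1, 4, 2, 3)
    val limit = 2

    // Same result as child.execute().takeOrdered(limit)(ordering) for an ascending ordering:
    val topK = rows.sorted.take(limit)  // Seq(1, 2)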

core/src/main/scala/org/apache/spark/sql/execution/package.scala

Lines changed: 0 additions & 1 deletion
@@ -4,5 +4,4 @@ package org.apache.spark.sql
  * An execution engine for relational query plans that runs on top Spark and returns RDDs.
  */
 package object execution {
-  type Row = catalyst.expressions.Row
 }

core/src/test/scala/org/apache/spark/sql/DslQueryTests.scala

Lines changed: 7 additions & 0 deletions
@@ -96,6 +96,13 @@ class BasicQuerySuite extends DslQueryTest {
       testData.data)
   }

+  test("agg") {
+    checkAnswer(
+      testData2.groupBy('a)('a, Sum('b)),
+      Seq((1,3),(2,3),(3,3))
+    )
+  }
+
   test("select *") {
     checkAnswer(
       testData.select(Star(None)),
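A plain-Scala equivalent of what the new "agg" test asserts (the testData2 fixture values are assumed here; only the expected answer Seq((1,3),(2,3),(3,3)) comes from the test itself):

    // Grouping (a, b) pairs by a and summing b yields one row per distinct key:
    val data = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    val sums = data.groupBy(_._1).map { case (a, bs) => (a, bs.map(_._2).sum) }.toSeq.sortBy(_._1)
    // sums == Seq((1, 3), (2, 3), (3, 3))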

shark/src/main/scala/org/apache/spark/sql/shark/HiveMetastoreCatalog.scala

Lines changed: 26 additions & 22 deletions
@@ -20,6 +20,8 @@ import catalyst.types._
 import scala.collection.JavaConversions._

 class HiveMetastoreCatalog(shark: SharkContext) extends Catalog with Logging {
+  import HiveMetastoreTypes._
+
   val client = Hive.get(shark.hiveconf)

   def lookupRelation(
@@ -42,37 +44,39 @@ class HiveMetastoreCatalog(shark: SharkContext) extends Catalog with Logging {
       alias)(table.getTTable, partitions.map(part => part.getTPartition))
   }

+  def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) {
+    val table = new Table(databaseName, tableName)
+    val hiveSchema =
+      schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
+    table.setFields(hiveSchema)
+
+    val sd = new StorageDescriptor()
+    table.getTTable.setSd(sd)
+    sd.setCols(hiveSchema)
+
+    // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs.
+    sd.setCompressed(false)
+    sd.setParameters(Map[String, String]())
+    sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
+    sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
+    val serDeInfo = new SerDeInfo()
+    serDeInfo.setName(tableName)
+    serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+    serDeInfo.setParameters(Map[String, String]())
+    sd.setSerdeInfo(serDeInfo)
+    client.createTable(table)
+  }
+
   /**
    * Creates any tables required for query execution.
    * For example, because of a CREATE TABLE X AS statement.
    */
   object CreateTables extends Rule[LogicalPlan] {
-    import HiveMetastoreTypes._
-
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case InsertIntoCreatedTable(db, tableName, child) =>
         val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase())

-        val table = new Table(databaseName, tableName)
-        val schema =
-          child.output.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
-        table.setFields(schema)
-
-        val sd = new StorageDescriptor()
-        table.getTTable.setSd(sd)
-        sd.setCols(schema)
-
-        // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs.
-        sd.setCompressed(false)
-        sd.setParameters(Map[String, String]())
-        sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
-        sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
-        val serDeInfo = new SerDeInfo()
-        serDeInfo.setName(tableName)
-        serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
-        serDeInfo.setParameters(Map[String, String]())
-        sd.setSerdeInfo(serDeInfo)
-        client.createTable(table)
+        createTable(databaseName, tableName, child.output)

         InsertIntoTable(
           lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation],
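The refactor above only moves code: table creation is factored out of the CreateTables rule into a reusable createTable method, and the rule now delegates to it with the child plan's output. A hedged sketch of calling the extracted helper directly (the catalog instance, database, and table name are assumed):

    val catalog: HiveMetastoreCatalog = ???   // e.g. the catalog of a running Shark session
    val queryPlan: LogicalPlan = ???
    catalog.createTable("default", "benchmark_results", queryPlan.output)
    // Registers a metastore table with the same TextInputFormat / LazySimpleSerDe defaults
    // the CreateTables rule uses for CREATE TABLE ... AS SELECT plans.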
