Skip to content

Commit 77e845c

Browse files
marmbrus authored and rxin committed
[SPARK-4394][SQL] Data Sources API Improvements
This PR adds two features to the data sources API:
- Support for pushing down `IN` filters
- The ability for relations to optionally provide information about their `sizeInBytes`.

Author: Michael Armbrust <michael@databricks.com>

Closes #3260 from marmbrus/sourcesImprovements and squashes the following commits:

9a5e171 [Michael Armbrust] Use method instead of configuration directly
99c0e6b [Michael Armbrust] Add support for sizeInBytes.
416f167 [Michael Armbrust] Support for IN in data sources API.
2a04ab3 [Michael Armbrust] Simplify implementation of InSet.
1 parent e421072 commit 77e845c

File tree

9 files changed

+32
-15
lines changed

9 files changed

+32
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
9999
* Optimized version of In clause, when all filter values of In clause are
100100
* static.
101101
*/
102-
case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
102+
case class InSet(value: Expression, hset: Set[Any])
103103
extends Predicate {
104104

105-
def children = child
105+
def children = value :: Nil
106106

107107
def nullable = true // TODO: Figure out correct nullability semantics of IN.
108108
override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
289289
case q: LogicalPlan => q transformExpressionsDown {
290290
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
291291
val hSet = list.map(e => e.eval(null))
292-
InSet(v, HashSet() ++ hSet, v +: list)
292+
InSet(v, HashSet() ++ hSet)
293293
}
294294
}
295295
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite {
158158
val nl = Literal(null)
159159
val s = Seq(one, two)
160160
val nullS = Seq(one, two, null)
161-
checkEvaluation(InSet(one, hS, one +: s), true)
162-
checkEvaluation(InSet(two, hS, two +: s), true)
163-
checkEvaluation(InSet(two, nS, two +: nullS), true)
164-
checkEvaluation(InSet(nl, nS, nl +: nullS), true)
165-
checkEvaluation(InSet(three, hS, three +: s), false)
166-
checkEvaluation(InSet(three, nS, three +: nullS), false)
167-
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
161+
checkEvaluation(InSet(one, hS), true)
162+
checkEvaluation(InSet(two, hS), true)
163+
checkEvaluation(InSet(two, nS), true)
164+
checkEvaluation(InSet(nl, nS), true)
165+
checkEvaluation(InSet(three, hS), false)
166+
checkEvaluation(InSet(three, nS), false)
167+
checkEvaluation(InSet(one, hS) && InSet(two, hS), true)
168168
}
169169

170170
test("MaxOf") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest {
5252
val optimized = Optimize(originalQuery.analyze)
5353
val correctAnswer =
5454
testRelation
55-
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2,
56-
UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
55+
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
5756
.analyze
5857

5958
comparePlans(optimized, correctAnswer)

sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy {
108108

109109
case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
110110
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
111+
112+
case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
111113
}
112114
}

sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation)
4141
}
4242

4343
@transient override lazy val statistics = Statistics(
44-
// TODO: Allow datasources to provide statistics as well.
45-
sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
44+
sizeInBytes = BigInt(relation.sizeInBytes)
4645
)
4746

4847
/** Used to lookup original attribute capitalization */

sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter
2424
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
2525
case class LessThan(attribute: String, value: Any) extends Filter
2626
case class LessThanOrEqual(attribute: String, value: Any) extends Filter
27+
case class In(attribute: String, values: Array[Any]) extends Filter

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
1818

1919
import org.apache.spark.annotation.DeveloperApi
2020
import org.apache.spark.rdd.RDD
21-
import org.apache.spark.sql.{Row, SQLContext, StructType}
21+
import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
2222
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
2323

2424
/**
@@ -53,6 +53,15 @@ trait RelationProvider {
5353
abstract class BaseRelation {
5454
def sqlContext: SQLContext
5555
def schema: StructType
56+
57+
/**
58+
* Returns an estimated size of this relation in bytes. This information is used by the planner
59+
* to decide when it is safe to broadcast a relation and can be overridden by sources that
60+
* know the size ahead of time. By default, the system will assume that tables are too
61+
* large to broadcast. This method will be called multiple times during query planning
62+
* and thus should not perform expensive operations for each invocation.
63+
*/
64+
def sizeInBytes = sqlContext.defaultSizeInBytes
5665
}
5766

5867
/**

sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
5151
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
5252
case GreaterThan("a", v: Int) => (a: Int) => a > v
5353
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
54+
case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
5455
}
5556

5657
def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
@@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest {
121122
"SELECT * FROM oneToTenFiltered WHERE a = 1",
122123
Seq(1).map(i => Row(i, i * 2)).toSeq)
123124

125+
sqlTest(
126+
"SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
127+
Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
128+
124129
sqlTest(
125130
"SELECT * FROM oneToTenFiltered WHERE A = 1",
126131
Seq(1).map(i => Row(i, i * 2)).toSeq)
@@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest {
150155

151156
testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
152157

158+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
159+
153160
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
154161
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
155162

0 commit comments

Comments
 (0)