
Commit 1d35c3c

Merge branch 'master' of https://github.com/apache/spark into SPARK-1712_new

2 parents: 062c182 + c33b8dc

File tree: 10 files changed, +225 -12 lines

core/src/main/scala/org/apache/spark/Partitioner.scala
Lines changed: 61 additions & 0 deletions

@@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     false
   }
 }
+
+/**
+ * A [[org.apache.spark.Partitioner]] that fills each partition with up to a bounded number of
+ * keys (the boundary, 1000 by default) before moving on to the next partition. Once every
+ * partition has reached the boundary, the partitioner allocates one element per partition in
+ * round-robin order, so the smaller partitions end up at most one key behind the larger ones.
+ */
+class BoundaryPartitioner[K : Ordering : ClassTag, V](
+    partitions: Int,
+    @transient rdd: RDD[_ <: Product2[K,V]],
+    private val boundary: Int = 1000)
+  extends Partitioner {
+
+  // This array tracks how many keys have been assigned to each partition:
+  // counts(0) is the number of keys in partition 0, and so on.
+  private val counts: Array[Int] = {
+    new Array[Int](numPartitions)
+  }
+
+  def numPartitions = math.abs(partitions)
+
+  /*
+   * Ideally this would be derived from the number of partitions and the total number of keys,
+   * but we do not call count on the RDD here, to avoid triggering an action.
+   * Users can call count themselves and pass in whatever boundary is appropriate.
+   */
+  def keysPerPartition = boundary
+
+  var currPartition = 0
+
+  /*
+   * Keep assigning keys to the current partition until it reaches the per-partition bound,
+   * then start allocating to the next partition.
+   *
+   * NOTE: For example, with 2000 keys, 3 partitions and a boundary of 500, the first 500 keys
+   * go to P1, keys 501-1000 go to P2 and keys 1001-1500 go to P3. After that, the remaining
+   * keys are assigned one partition at a time: 1501 goes to P1, 1502 to P2, 1503 to P3,
+   * and so on.
+   */
+  def getPartition(key: Any): Int = {
+    val partition = currPartition
+    counts(partition) = counts(partition) + 1
+    /*
+     * Since we fill up a partition before moving to the next one (this helps preserve key
+     * order), in certain cases it is possible to end up with empty partitions: with
+     * 3 partitions, 500 keys per partition and an RDD of only 700 keys, one partition will
+     * be entirely empty.
+     */
+    if (counts(currPartition) >= keysPerPartition)
+      currPartition = (currPartition + 1) % numPartitions
+    partition
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case r: BoundaryPartitioner[_,_] =>
+      (r.counts.sameElements(counts) && r.boundary == boundary
+        && r.currPartition == currPartition)
+    case _ =>
+      false
+  }
+}
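
For orientation, here is a minimal sketch (not part of the commit) of how the new partitioner assigns keys. It calls getPartition directly, the same way the PartitioningSuite test further down does; the choice of 4 partitions and a boundary of 500 is purely illustrative.

import org.apache.spark.{SparkConf, SparkContext}

object BoundaryPartitionerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("boundary-sketch").setMaster("local[2]"))
    val pairs = sc.parallelize(1 to 2000).map(x => (x, x))

    // Four partitions, each filled with up to 500 keys before the next one is used.
    val partitioner = new BoundaryPartitioner(4, pairs, 500)

    // Keys 1-500 map to partition 0, 501-1000 to partition 1,
    // 1001-1500 to partition 2 and 1501-2000 to partition 3.
    val assignments = (1 to 2000).map(partitioner.getPartition)
    assert(assignments.take(500).forall(_ == 0))
    assert(assignments.slice(500, 1000).forall(_ == 1))
    assert(assignments.drop(1500).forall(_ == 3))

    sc.stop()
  }
}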

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
Lines changed: 3 additions & 3 deletions

@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {

@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */

@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
    */
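
The touched docs describe the relativeSD trade-off for countApproxDistinctByKey. A quick sketch of calling it (not part of the commit), assuming an existing SparkContext named sc; the data, the 0.01 relativeSD and the 4 output partitions are illustrative:

// Spark's implicit RDD-to-PairRDDFunctions conversion.
import org.apache.spark.SparkContext._

val byKey = sc.parallelize(1 to 10000).map(x => (x % 10, x))

// Lower relativeSD gives more accurate counts at the cost of memory; 0.05 is the default.
val approxCounts = byKey.countApproxDistinctByKey(0.01, 4)
approxCounts.collect().foreach { case (k, n) => println(s"key $k ~ $n distinct values") }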

core/src/test/scala/org/apache/spark/PartitioningSuite.scala
Lines changed: 34 additions & 0 deletions

@@ -66,6 +66,40 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(descendingP4 != p4)
   }

+  test("BoundaryPartitioner equality") {
+    // Make an RDD where all the elements are the same so that the partition range bounds
+    // are deterministically all the same.
+    val rdd = sc.parallelize(1.to(4000)).map(x => (x, x))
+
+    val p2 = new BoundaryPartitioner(2, rdd, 1000)
+    val p4 = new BoundaryPartitioner(4, rdd, 1000)
+    val anotherP4 = new BoundaryPartitioner(4, rdd)
+
+    assert(p2 === p2)
+    assert(p4 === p4)
+    assert(p2 != p4)
+    assert(p4 != p2)
+    assert(p4 === anotherP4)
+    assert(anotherP4 === p4)
+  }
+
+  test("BoundaryPartitioner getPartition") {
+    val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+    val partitioner = new BoundaryPartitioner(4, rdd, 500)
+    1.to(2000).map { element => {
+      val partition = partitioner.getPartition(element)
+      if (element <= 500) {
+        assert(partition === 0)
+      } else if (element > 500 && element <= 1000) {
+        assert(partition === 1)
+      } else if (element > 1000 && element <= 1500) {
+        assert(partition === 2)
+      } else if (element > 1500 && element <= 2000) {
+        assert(partition === 3)
+      }
+    }}
+  }
+
   test("RangePartitioner getPartition") {
     val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
     // We have different behaviour of getPartition for partitions with less than 1000 and more than

core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
     rootDir0.deleteOnExit()
     rootDir1 = Files.createTempDir()
     rootDir1.deleteOnExit()
-    rootDirs = rootDir0.getName + "," + rootDir1.getName
+    rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
     println("Created root dirs: " + rootDirs)
   }

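
The one-line fix above swaps File.getName, which returns only the last path segment, for File.getAbsolutePath. A small standalone illustration of the difference (not from the commit; the directory name is made up):

import java.io.File

object PathSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical temp directory, standing in for Files.createTempDir() in the test.
    val dir = new File(System.getProperty("java.io.tmpdir"), "spark-test-abc123")

    // Only the final segment, e.g. "spark-test-abc123" -- not usable as a root-dir setting.
    println(dir.getName)

    // The full path, e.g. "/tmp/spark-test-abc123" -- what the block manager needs.
    println(dir.getAbsolutePath)
  }
}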

python/pyspark/sql.py
Lines changed: 5 additions & 2 deletions

@@ -28,7 +28,7 @@ class SQLContext:
     register L{SchemaRDD}s as tables, execute sql over tables, cache tables, and read parquet files.
     """

-    def __init__(self, sparkContext):
+    def __init__(self, sparkContext, sqlContext = None):
         """
         Create a new SQLContext.

@@ -58,10 +58,13 @@ def __init__(self, sparkContext):
         self._jvm = self._sc._jvm
         self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

+        if sqlContext:
+            self._scala_SQLContext = sqlContext
+
     @property
     def _ssql_ctx(self):
         """
-        Accessor for the JVM SparkSQL context. Subclasses can overrite this property to provide
+        Accessor for the JVM SparkSQL context. Subclasses can override this property to provide
         their own JVM Contexts.
         """
         if not hasattr(self, '_scala_SQLContext'):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
Lines changed: 7 additions & 0 deletions

@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val AND = Keyword("AND")
   protected val AS = Keyword("AS")
   protected val ASC = Keyword("ASC")
+  protected val APPROXIMATE = Keyword("APPROXIMATE")
   protected val AVG = Keyword("AVG")
   protected val BY = Keyword("BY")
   protected val CAST = Keyword("CAST")

@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
     COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+    APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
+      case exp => ApproxCountDistinct(exp)
+    } |
+    APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
+      case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
+    } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
     MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
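
The two new parser alternatives accept the query forms exercised in SQLQuerySuite further below. A short sketch of the syntax (not part of the commit), assuming a SQLContext named sqlContext with a table testData2 registered, as in that suite:

// Default relative standard deviation (0.05, from ApproxCountDistinct's default argument).
sqlContext.sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2")

// Caller-supplied relative standard deviation via the APPROXIMATE(<float>) form.
sqlContext.sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2")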

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
Lines changed: 76 additions & 2 deletions

@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.expressions

+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException

@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: Row): Any = currentMax
 }

-
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false

@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
-  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
   override def newInstance() = new CountDistinctFunction(expressions, this)
 }

+case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
+  extends PartialAggregate with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialCount =
+      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
+
+    SplitEvaluation(
+      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
+      partialCount :: Nil)
+  }
+
+  override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+}
+
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false

@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: Row): Any = count
 }

+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: Row): Any = hyperLogLog
+}
+
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: Row): Any = hyperLogLog.cardinality()
+}
+
 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.

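
To make the partial/merge split concrete: each partition builds its own HyperLogLog sketch via offer, the merge step combines the sketches with addAll, and cardinality yields the approximate distinct count. A standalone sketch (not from the commit) using the same clearspring calls, with made-up data:

import com.clearspring.analytics.stream.cardinality.HyperLogLog

object ApproxDistinctSketch {
  def main(args: Array[String]): Unit = {
    val relativeSD = 0.05

    // "Partition" phase: each partition offers its values into its own sketch,
    // as ApproxCountDistinctPartitionFunction.update does.
    val partitionA = new HyperLogLog(relativeSD)
    val partitionB = new HyperLogLog(relativeSD)
    (1 to 5000).foreach(v => partitionA.offer(v))
    (2501 to 7500).foreach(v => partitionB.offer(v))

    // "Merge" phase: combine the per-partition sketches,
    // as ApproxCountDistinctMergeFunction.update does with addAll.
    val merged = new HyperLogLog(relativeSD)
    merged.addAll(partitionA)
    merged.addAll(partitionB)

    // Roughly 7500 distinct values, within the relative standard deviation.
    println(merged.cardinality())
  }
}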

sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
Lines changed: 2 additions & 2 deletions

@@ -33,9 +33,9 @@ import org.apache.spark.util.Utils
 /**
  * The entry point for executing Spark SQL queries from a Java program.
  */
-class JavaSQLContext(sparkContext: JavaSparkContext) {
+class JavaSQLContext(val sqlContext: SQLContext) {

-  val sqlContext = new SQLContext(sparkContext.sc)
+  def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc))

  /**
   * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD
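
The reworked constructor lets callers wrap an existing SQLContext instead of always creating a fresh one, similar in spirit to the pyspark change above that accepts a pre-existing JVM SQLContext. A brief sketch of both constructors (not part of the commit), assuming an existing SparkContext named sc:

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.api.java.JavaSQLContext

// Sketch only: sc is an assumed, already-running org.apache.spark.SparkContext.
val scalaSqlContext = new SQLContext(sc)

// New primary constructor: wrap an existing Scala SQLContext so both APIs share state.
val sharedJavaCtx = new JavaSQLContext(scalaSqlContext)

// Auxiliary constructor: the previous behaviour, building a fresh SQLContext
// from a JavaSparkContext.
val freshJavaCtx = new JavaSQLContext(new JavaSparkContext(sc))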

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
Lines changed: 17 additions & 0 deletions

@@ -21,6 +21,7 @@ import java.nio.ByteBuffer

 import scala.reflect.ClassTag

+import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}

@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
+      new HyperLogLogSerializer)
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)

@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
   }
 }

+private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
+  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
+    val bytes = hyperLogLog.getBytes()
+    output.writeInt(bytes.length)
+    output.writeBytes(bytes)
+  }
+
+  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
+    val length = input.readInt()
+    val bytes = input.readBytes(length)
+    HyperLogLog.Builder.build(bytes)
+  }
+}
+
 /**
  * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
  * them as `Array[(k,v)]`.
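
The serializer above is a length-prefixed wrapper around HyperLogLog's own byte representation. A minimal round-trip sketch (not from the commit) using the same getBytes / Builder.build calls it relies on:

import com.clearspring.analytics.stream.cardinality.HyperLogLog

object HyperLogLogRoundTrip {
  def main(args: Array[String]): Unit = {
    val hll = new HyperLogLog(0.05)
    (1 to 1000).foreach(v => hll.offer(v))

    // Serialize: the same bytes HyperLogLogSerializer.write prefixes with their length.
    val bytes: Array[Byte] = hll.getBytes()

    // Deserialize: mirrors HyperLogLogSerializer.read.
    val restored = HyperLogLog.Builder.build(bytes)

    // The restored sketch reports the same approximate cardinality.
    println(s"${hll.cardinality()} == ${restored.cardinality()}")
  }
}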

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Lines changed: 19 additions & 2 deletions

@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest {
   test("count") {
     checkAnswer(
       sql("SELECT COUNT(*) FROM testData2"),
-      testData2.count()
-    )
+      testData2.count())
+  }
+
+  test("count distinct") {
+    checkAnswer(
+      sql("SELECT COUNT(DISTINCT b) FROM testData2"),
+      2)
+  }
+
+  test("approximate count distinct") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+      3)
+  }
+
+  test("approximate count distinct with user provided standard deviation") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+      3)
   }

   // No support for primitive nulls yet.
