[SPARK-19800][SS][WIP] Implement one kind of streaming sampling - reservoir sampling #17141
Changes from all commits: 3c7dc19, 23738cf, 288c124, c4008cd, 1ddb82e, 02d44aa
@@ -17,9 +17,12 @@
 package org.apache.spark.sql.execution.streaming

+import scala.util.Random
+
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
@@ -32,6 +35,7 @@ import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}


 /** Used to identify the state store for a given operator. */
@@ -127,8 +131,8 @@ case class StateStoreRestoreExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeVersion = getStateId.batchId,
+      getStateId.operatorId,
+      getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
       sqlContext.sessionState,
@@ -322,3 +326,110 @@ object StreamingDeduplicateExec {
   private val EMPTY_ROW =
     UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
 }
/**
 * Physical operator for executing streaming sampling, implemented as a
 * distributed reservoir sample: each partition maintains a reservoir in its
 * state store, and a final global pass merges the per-partition reservoirs.
 *
 * @param reservoirSize the number of elements to keep in the random sample.
 */
case class StreamingReservoirSampleExec(
    keyExpressions: Seq[Attribute],
    child: SparkPlan,
    reservoirSize: Int,
    stateId: Option[OperatorStateId] = None,
    eventTimeWatermark: Option[Long] = None,
    outputMode: Option[OutputMode] = None)
  extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(keyExpressions) :: Nil

  // Reserved state-store key under which the partition's running record
  // count is kept alongside the sampled rows.
  private val enc = Encoders.STRING.asInstanceOf[ExpressionEncoder[String]]
  private val NUM_RECORDS_IN_PARTITION = enc.toRow("NUM_RECORDS_IN_PARTITION")
    .asInstanceOf[UnsafeRow]
  override protected def doExecute(): RDD[InternalRow] = {
    metrics // force lazy initialization of the SQL metrics

    // Both layouts are the key columns plus a trailing LongType column that
    // carries the partition's record count through the global merge.
    val fieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray
    val withSumFieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray

    child.execute().mapPartitionsWithStateStore(
      getStateId.checkpointLocation,
      getStateId.operatorId,
      getStateId.batchId,
      keyExpressions.toStructType,
      child.output.toStructType,
      sqlContext.sessionState,
      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>

      // Number of records this partition has seen in earlier batches.
      val numRecordsInPart = store.get(NUM_RECORDS_IN_PARTITION).map(value => {
        value.get(0, LongType).asInstanceOf[Long]
      }).getOrElse(0L)

      val seed = Random.nextLong()
      val rand = new XORShiftRandom(seed)
      var numSamples = numRecordsInPart
      var count = 0

      // Drop rows that fall behind the event-time watermark, if one is set.
      val baseIterator = watermarkPredicate match {
        case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
        case None => iter
      }

      baseIterator.foreach { r =>
        count += 1
        if (numSamples < reservoirSize) {
          // Fill phase: the first reservoirSize rows go straight into the
          // store under the 1-based keys "1".."reservoirSize".
          numSamples += 1
          store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow],
            r.asInstanceOf[UnsafeRow])
        } else {
          // Replacement phase: replace a random slot with probability
          // reservoirSize / (records seen so far). randomIdx is uniform over
          // [0, total), so a slot is hit iff randomIdx < reservoirSize, and
          // the matching 1-based store key is randomIdx + 1.
          val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong
          if (randomIdx < reservoirSize) {
            val replacementIdx = enc.toRow((randomIdx + 1).toString).asInstanceOf[UnsafeRow]
            store.put(replacementIdx, r.asInstanceOf[UnsafeRow])
          }
        }
      }
Review comment: Within a partition we only need a single pass of plain (unweighted) reservoir sampling.
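For readers unfamiliar with the technique, this per-partition step is the classic Algorithm R. A minimal, self-contained sketch of the same idea, stripped of the state-store plumbing (names here are illustrative, not part of the PR):

  import scala.collection.mutable.ArrayBuffer
  import scala.util.Random

  // Algorithm R: one pass over the input; every element ends up in the
  // reservoir with equal probability k / n.
  def reservoirSample[T](items: Iterator[T], k: Int, rand: Random = new Random()): Seq[T] = {
    val reservoir = new ArrayBuffer[T](k)
    var seen = 0L
    items.foreach { item =>
      seen += 1
      if (reservoir.size < k) {
        reservoir += item // fill phase: keep the first k elements outright
      } else {
        // replacement phase: overwrite a random slot with probability k / seen
        val j = (rand.nextDouble() * seen).toLong
        if (j < k) reservoir(j.toInt) = item
      }
    }
    reservoir.toSeq
  }

  // e.g. reservoirSample(Iterator.range(0, 1000000), 10)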
      // Persist the updated record count and commit this batch's state.
      val numRecordsTillNow = UnsafeProjection.create(Array[DataType](LongType))
        .apply(InternalRow.apply(numRecordsInPart + count))
      store.put(NUM_RECORDS_IN_PARTITION, numRecordsTillNow)
      store.commit()

      outputMode match {
        case Some(Complete) =>
          // Complete mode: emit every stored sample, widened with the
          // partition's record count.
          CompletionIterator[InternalRow, Iterator[InternalRow]](
            store.iterator().filter(kv => {
              !kv._1.asInstanceOf[UnsafeRow].equals(NUM_RECORDS_IN_PARTITION)
            }).map(kv => {
              UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq(
                new JoinedRow(kv._2, numRecordsTillNow)
                  .toSeq(withSumFieldTypes)))
            }), {})
Review comment: Here we widen each row to (row, numRecordsTillNow); the global pass below uses that count as the sampling weight.
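In isolation, that widening step looks like the following sketch, using the same JoinedRow + UnsafeProjection pattern as both output-mode branches (the column types and values are illustrative; the operator itself uses the key columns plus the count):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
  import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

  // A sampled row with a single (illustrative) int column...
  val sampleTypes: Array[DataType] = Array(IntegerType)
  val sample = UnsafeProjection.create(sampleTypes).apply(InternalRow(42))

  // ...and the partition's running record count as a one-column row.
  val count = UnsafeProjection.create(Array[DataType](LongType)).apply(InternalRow(1000L))

  // JoinedRow concatenates the two rows; the projection then materializes
  // the widened (sample, count) layout as a single UnsafeRow.
  val withCountTypes: Array[DataType] = sampleTypes :+ LongType
  val widened = UnsafeProjection.create(withCountTypes)
    .apply(InternalRow.fromSeq(new JoinedRow(sample, count).toSeq(withCountTypes)))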
        case Some(Update) =>
          // Update mode: emit only the samples replaced in this batch,
          // likewise widened with the partition's record count.
          CompletionIterator[InternalRow, Iterator[InternalRow]](
            store.updates()
              .filter(update => !update.key.equals(NUM_RECORDS_IN_PARTITION))
              .map(update => {
                UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq(
                  new JoinedRow(update.value, numRecordsTillNow)
                    .toSeq(withSumFieldTypes)))
              }), {})

Review comment: Same as in the Complete branch above.
        case _ =>
          throw new UnsupportedOperationException(
            s"Invalid output mode: $outputMode for streaming sampling.")
      }
    }.repartition(1).mapPartitions(it => {
      // Merge the per-partition samples with a single global weighted
      // reservoir sampling pass, weighting each sampled row by the number
      // of records its partition has seen (the appended LongType column).
      SamplingUtils.reservoirSampleWithWeight(
        it.map(item => (item, item.getLong(keyExpressions.size))), reservoirSize)
        .map(row =>
          UnsafeProjection.create(fieldTypes)
            .apply(InternalRow.fromSeq(row.toSeq(fieldTypes)))
        ).iterator
    })
  }
Review comment: Here we do a single global weighted reservoir sampling pass over the per-partition samples.
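Note that reservoirSampleWithWeight is a helper introduced elsewhere in this PR, not an existing Spark utility. One standard way to implement weighted reservoir sampling, shown here only as an illustrative sketch, is the A-Res algorithm of Efraimidis and Spirakis: give each item the key u^(1/w) for u ~ Uniform(0, 1) and keep the k items with the largest keys, so that a sample standing in for more records is proportionally more likely to survive the merge.

  import scala.collection.mutable
  import scala.util.Random

  // A-Res weighted reservoir sampling: key each item by u^(1/w) and keep
  // the k largest keys. Heavier items survive with higher probability.
  def weightedReservoirSample[T](
      items: Iterator[(T, Long)],
      k: Int,
      rand: Random = new Random()): Seq[T] = {
    // min-heap on the key, so the smallest key is the eviction candidate
    val heap = mutable.PriorityQueue.empty[(Double, T)](
      Ordering.by((p: (Double, T)) => -p._1))
    items.foreach { case (item, weight) =>
      val key = math.pow(rand.nextDouble(), 1.0 / weight)
      if (heap.size < k) {
        heap.enqueue((key, item))
      } else if (key > heap.head._1) {
        heap.dequeue() // evict the current smallest key
        heap.enqueue((key, item))
      }
    }
    heap.toSeq.map(_._2)
  }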
  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning
}
Review comment: NUM_RECORDS_IN_PARTITION tracks the total number of records seen so far in the current partition, and is updated once at the end of the sampling pass.
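For reference, a small sketch of how such a single-column Long counter round-trips through an UnsafeRow, mirroring the pattern the operator uses for NUM_RECORDS_IN_PARTITION (the value is illustrative):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
  import org.apache.spark.sql.types.{DataType, LongType}

  // Encode the running count as a single LongType column...
  val counterProjection = UnsafeProjection.create(Array[DataType](LongType))
  val counterRow = counterProjection.apply(InternalRow(12345L))

  // ...and decode it again, defaulting to 0L when the state store has no
  // entry yet (the first batch of a partition).
  val restored = Option(counterRow)
    .map(row => row.get(0, LongType).asInstanceOf[Long])
    .getOrElse(0L)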