
[SPARK-19800][SS][WIP] Implement one kind of streaming sampling - reservoir sampling #17141

Closed
wants to merge 6 commits into from
@@ -69,6 +69,47 @@ private[spark] object SamplingUtils {
}
}

  /**
   * Weighted reservoir sampling implementation.
   *
   * @param input iterator of (item, weight) pairs
   * @param k reservoir size
   * @param seed random seed
   * @return array of sampled items
   */
  def reservoirSampleWithWeight[T: ClassTag](
      input: Iterator[(T, Long)],
      k: Int,
      seed: Long = Random.nextLong())
    : Array[T] = {
    val reservoir = new Array[T](k)
    // Put the first k elements in the reservoir.
    var i = 0
    while (i < k && input.hasNext) {
      val item = input.next()
      reservoir(i) = item._1
      i += 1
    }

    // If the input is exhausted before the reservoir is full, trim the array to the
    // number of elements actually seen.
    if (i < k) {
      val trimReservoir = new Array[T](i)
      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
      trimReservoir
    } else {
      var l = i.toLong
      val rand = new XORShiftRandom(seed)
      while (input.hasNext) {
        val item = input.next()
        l += 1
        // Use 1.0 to force floating-point division; with a Long weight, 1 / item._2
        // is integer division and truncates to 0 for any weight greater than 1.
        val replacementIndex = Math.pow(rand.nextDouble(), 1.0 / item._2).toInt
        if (replacementIndex < k) {
          reservoir(replacementIndex) = item._1
        }
      }
      reservoir
    }
  }
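
The helper mirrors the existing `reservoirSampleAndCount` API in the same object. A minimal usage sketch with illustrative values (note that `SamplingUtils` is `private[spark]`, so this is only callable from Spark-internal code):

```scala
import org.apache.spark.util.random.SamplingUtils

// Each element is paired with a Long weight, e.g. the number of records it represents.
val weighted: Iterator[(String, Long)] =
  Iterator(("a", 10L), ("b", 1L), ("c", 5L), ("d", 20L))

// Draw a reservoir of at most two elements, biased by the supplied weights.
val sample: Array[String] =
  SamplingUtils.reservoirSampleWithWeight(weighted, k = 2, seed = 42L)
```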

/**
* Returns a sampling rate that guarantees a sample of size greater than or equal to
* sampleSizeLowerBound 99.99% of the time.
@@ -842,6 +842,19 @@ case class Sample(
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
}

/**
 * A logical plan for `reservoir`.
 */
case class ReservoirSample(
    keys: Seq[Attribute],
    child: LogicalPlan,
    reservoirSize: Int,
    streaming: Boolean = false)
  extends UnaryNode {
  override def maxRows: Option[Long] = child.maxRows
  override def output: Seq[Attribute] = child.output
}

/**
* Returns a new logical plan that dedups input rows.
*/
@@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2019,6 +2019,21 @@ class Dataset[T] private[sql](
Deduplicate(groupCols, logicalPlan, isStreaming)
}

  /**
   * :: Experimental ::
   * (Scala-specific) Reservoir sampling implementation.
   *
   * @todo move this into the sample operator.
   * @group typedrel
   * @since 2.0.0
   */
  @Experimental
  @InterfaceStability.Evolving
  def reservoir(reservoirSize: Int): Dataset[T] = withTypedPlan {
    val allColumns = queryExecution.analyzed.output
    ReservoirSample(allColumns, logicalPlan, reservoirSize, isStreaming)
  }
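
A minimal usage sketch of the new API on a batch Dataset (assuming a `SparkSession` named `spark`; `range` is only for illustration):

```scala
// Keep a fixed-size uniform random sample of 10 rows, regardless of the input size.
val ds = spark.range(0, 1000)
val sampled = ds.reservoir(10)
sampled.show()
```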

/**
* Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
@@ -18,16 +18,14 @@
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.{execution, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -256,6 +254,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

  /**
   * Used to plan the streaming reservoir sample operator.
   */
  object ReservoirSampleStrategy extends Strategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case ReservoirSample(keys, child, reservoirSize, true) =>
        StreamingReservoirSampleExec(keys, PlanLater(child), reservoirSize) :: Nil

      case _ => Nil
    }
  }

/**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
@@ -411,6 +421,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.ReservoirSample(keys, child, reservoirSize, false) =>
execution.ReservoirSampleExec(reservoirSize, PlanLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScanExec(output, data) :: Nil
case logical.LocalLimit(IntegerLiteral(limit), child) =>
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext}
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, SamplingUtils}

/** Physical plan for Project. */
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
@@ -657,3 +657,20 @@ object SubqueryExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}

case class ReservoirSampleExec(reservoirSize: Int, child: SparkPlan) extends UnaryExecNode {
  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning

  protected override def doExecute(): RDD[InternalRow] = {
    child.execute()
      .mapPartitions { it =>
        // Sample each partition independently and remember its record count.
        val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, reservoirSize)
        sample.map((_, count)).toIterator
      }
      .repartition(1)
      .mapPartitions { it =>
        // Merge the per-partition reservoirs, weighting each element by its partition size.
        SamplingUtils.reservoirSampleWithWeight(it, reservoirSize).iterator
      }
  }
}
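
For intuition, a small self-contained sketch of the same two-phase scheme outside Spark (plain Scala; `sampleOnePartition` is a hypothetical helper, not part of this patch): each partition contributes a uniform sample plus its record count, and the merge step uses that count as the weight so every input row keeps a comparable chance of surviving.

```scala
import scala.util.Random

// Phase 1 (per partition): plain reservoir sampling, plus the partition's record count.
def sampleOnePartition[T](part: Seq[T], k: Int, rng: Random): (Seq[T], Long) = {
  val reservoir = part.take(k).toBuffer
  var seen = reservoir.size.toLong
  part.drop(k).foreach { x =>
    seen += 1
    // Element number `seen` replaces a random slot with probability k / seen.
    val j = rng.nextInt(seen.toInt)
    if (j < k) reservoir(j) = x
  }
  (reservoir.toSeq, seen)
}

// Phase 2 (single partition): pair every sampled element with its partition size,
// which is exactly the (item, weight) shape that reservoirSampleWithWeight consumes.
val rng = new Random(42)
val partitions = Seq((0 until 100).toSeq, (100 until 1000).toSeq)
val weighted: Seq[(Int, Long)] = partitions.flatMap { p =>
  val (sample, count) = sampleOnePartition(p, 5, rng)
  sample.map(x => (x, count))
}
```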
@@ -46,6 +46,7 @@ class IncrementalExecution(
sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.planner.StreamingDeduplicationStrategy +:
sparkSession.sessionState.planner.ReservoirSampleStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies

// Modified planner with stateful operations.
@@ -83,7 +84,6 @@ class IncrementalExecution(
StateStoreRestoreExec(keys2, None, child))) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)

StateStoreSaveExec(
keys,
Some(stateId),
@@ -98,13 +98,23 @@
case StreamingDeduplicateExec(keys, child, None, None) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)

StreamingDeduplicateExec(
keys,
child,
Some(stateId),
Some(offsetSeqMetadata.batchWatermarkMs))

      case StreamingReservoirSampleExec(keys, child, reservoirSize, None, None, None) =>
        val stateId =
          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
        StreamingReservoirSampleExec(
          keys,
          child,
          reservoirSize,
          Some(stateId),
          Some(offsetSeqMetadata.batchWatermarkMs),
          Some(outputMode))

case m: FlatMapGroupsWithStateExec =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
@@ -17,9 +17,12 @@

package org.apache.spark.sql.execution.streaming

import scala.util.Random

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
@@ -32,6 +35,7 @@ import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}


/** Used to identify the state store for a given operator. */
@@ -127,8 +131,8 @@ case class StateStoreRestoreExec(

child.execute().mapPartitionsWithStateStore(
getStateId.checkpointLocation,
operatorId = getStateId.operatorId,
storeVersion = getStateId.batchId,
getStateId.operatorId,
getStateId.batchId,
keyExpressions.toStructType,
child.output.toStructType,
sqlContext.sessionState,
@@ -322,3 +326,110 @@ object StreamingDeduplicateExec {
private val EMPTY_ROW =
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
}

/**
 * Physical operator for executing streaming sampling.
 *
 * @param reservoirSize number of random sample elements.
 */
case class StreamingReservoirSampleExec(
    keyExpressions: Seq[Attribute],
    child: SparkPlan,
    reservoirSize: Int,
    stateId: Option[OperatorStateId] = None,
    eventTimeWatermark: Option[Long] = None,
    outputMode: Option[OutputMode] = None)
  extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(keyExpressions) :: Nil

  private val enc = Encoders.STRING.asInstanceOf[ExpressionEncoder[String]]
  private val NUM_RECORDS_IN_PARTITION = enc.toRow("NUM_RECORDS_IN_PARTITION")
    .asInstanceOf[UnsafeRow]

Contributor (author):
NUM_RECORDS_IN_PARTITION tracks the total number of records seen in the current partition, and is updated at the end of sampling.

  override protected def doExecute(): RDD[InternalRow] = {
    metrics // force lazy init at driver
    val fieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray
    val withSumFieldTypes = (keyExpressions.map(_.dataType) ++ Seq(LongType)).toArray

    child.execute().mapPartitionsWithStateStore(
      getStateId.checkpointLocation,
      getStateId.operatorId,
      getStateId.batchId,
      keyExpressions.toStructType,
      child.output.toStructType,
      sqlContext.sessionState,
      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>

      // Number of records this partition has already seen in earlier batches.
      val numRecordsInPart = store.get(NUM_RECORDS_IN_PARTITION).map(value => {
        value.get(0, LongType).asInstanceOf[Long]
      }).getOrElse(0L)

      val seed = Random.nextLong()
      val rand = new XORShiftRandom(seed)
      var numSamples = numRecordsInPart
      var count = 0

      // Drop late rows if a watermark is defined.
      val baseIterator = watermarkPredicate match {
        case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
        case None => iter
      }

      baseIterator.foreach { r =>
        count += 1
        if (numSamples < reservoirSize) {
          // The reservoir is not yet full: store the row under the next slot key.
          numSamples += 1
          store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow],
            r.asInstanceOf[UnsafeRow])
        } else {
          // Otherwise replace a random slot, with probability that shrinks as the
          // total number of records seen by this partition grows.
          val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong
          if (randomIdx <= reservoirSize) {
            val replacementIdx = enc.toRow(randomIdx.toString).asInstanceOf[UnsafeRow]
            store.put(replacementIdx, r.asInstanceOf[UnsafeRow])
          }
        }
      }
Contributor (author):
Within a partition, we only need to do one round of plain (unweighted) reservoir sampling.


      // Persist the updated record count for this partition and commit the state store.
      val numRecordsTillNow = UnsafeProjection.create(Array[DataType](LongType))
        .apply(InternalRow.apply(numRecordsInPart + count))
      store.put(NUM_RECORDS_IN_PARTITION, numRecordsTillNow)
      store.commit()

      outputMode match {
        case Some(Complete) =>
          // Emit every sampled slot, joined with the partition's record count so the
          // final merge step can use it as a weight.
          CompletionIterator[InternalRow, Iterator[InternalRow]](
            store.iterator().filter(kv => {
              !kv._1.asInstanceOf[UnsafeRow].equals(NUM_RECORDS_IN_PARTITION)
            }).map(kv => {
              UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq(
                new JoinedRow(kv._2, numRecordsTillNow)
                  .toSeq(withSumFieldTypes)))
            }), {})
Contributor (author):
Here we turn each row into (row, numRecordsTillNow); numRecordsTillNow is later used to compute the item's weight.

        case Some(Update) =>
          // Emit only the slots updated in this batch, again joined with the record count.
          CompletionIterator[InternalRow, Iterator[InternalRow]](
            store.updates()
              .filter(update => !update.key.equals(NUM_RECORDS_IN_PARTITION))
              .map(update => {
                UnsafeProjection.create(withSumFieldTypes).apply(InternalRow.fromSeq(
                  new JoinedRow(update.value, numRecordsTillNow)
                    .toSeq(withSumFieldTypes)))
Contributor (author):
Same as above.

              }), {})
        case _ =>
          throw new UnsupportedOperationException(s"Invalid output mode: $outputMode " +
            "for streaming sampling.")
      }
    }.repartition(1).mapPartitions(it => {
      // Gather all per-partition samples on a single partition and run one global
      // weighted reservoir sampling pass, using the appended record count as each row's weight.
      SamplingUtils.reservoirSampleWithWeight(
        it.map(item => (item, item.getLong(keyExpressions.size))), reservoirSize)
        .map(row =>
          UnsafeProjection.create(fieldTypes)
            .apply(InternalRow.fromSeq(row.toSeq(fieldTypes)))
        ).iterator
    })
  }
Contributor (author):
Here we do one global weighted reservoir sampling pass.


  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning
}
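
End to end, the streaming path is meant to be driven like any other stateful streaming operator. A minimal sketch of a streaming query using it (assuming a `SparkSession` named `spark`; the socket source, console sink, and checkpoint path are illustrative, and the output mode has to be complete or update, since anything else is rejected above):

```scala
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

// Maintain a running random sample of at most 100 rows across micro-batches.
val sampled = lines.reservoir(100)

val query = sampled.writeStream
  .outputMode("complete")                                     // or "update"
  .format("console")
  .option("checkpointLocation", "/tmp/reservoir-checkpoint")  // illustrative path
  .start()

query.awaitTermination()
```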