
Commit 02d44aa

bug fix
1 parent 1ddb82e commit 02d44aa

6 files changed: 24 additions & 19 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 1 deletion

@@ -848,7 +848,7 @@ case class Sample(
 case class ReservoirSample(
     keys: Seq[Attribute],
     child: LogicalPlan,
-    k: Int,
+    reservoirSize: Int,
     streaming: Boolean = false)
   extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 2 deletions

@@ -2029,9 +2029,9 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def reservoir(k: Int): Dataset[T] = withTypedPlan {
+  def reservoir(reservoirSize: Int): Dataset[T] = withTypedPlan {
     val allColumns = queryExecution.analyzed.output
-    ReservoirSample(allColumns, logicalPlan, k, isStreaming)
+    ReservoirSample(allColumns, logicalPlan, reservoirSize, isStreaming)
   }

   /**
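Since `reservoir` is part of the public (experimental) Dataset API, the rename is visible to callers. A minimal usage sketch, assuming a local build that carries this experimental method; the object name and sample sizes are illustrative only:

import org.apache.spark.sql.SparkSession

object ReservoirExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("reservoir-example")
      .master("local[*]")
      .getOrCreate()

    // Draw a uniform random sample of at most 100 rows out of a million.
    val sample = spark.range(0, 1000000).reservoir(100)
    sample.show()

    spark.stop()
  }
}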

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 4 deletions

@@ -259,8 +259,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object ReservoirSampleStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case ReservoirSample(keys, child, k, true) =>
-        StreamingReservoirSampleExec(keys, PlanLater(child), k) :: Nil
+      case ReservoirSample(keys, child, reservoirSize, true) =>
+        StreamingReservoirSampleExec(keys, PlanLater(child), reservoirSize) :: Nil

       case _ => Nil
     }
@@ -421,8 +421,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
-      case logical.ReservoirSample(keys, child, k, false) =>
-        execution.ReservoirSampleExec(k, PlanLater(child)) :: Nil
+      case logical.ReservoirSample(keys, child, reservoirSize, false) =>
+        execution.ReservoirSampleExec(reservoirSize, PlanLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScanExec(output, data) :: Nil
       case logical.LocalLimit(IntegerLiteral(limit), child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 3 additions & 3 deletions

@@ -658,19 +658,19 @@ object SubqueryExec {
     ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
 }

-case class ReservoirSampleExec(k: Int, child: SparkPlan) extends UnaryExecNode {
+case class ReservoirSampleExec(reservoirSize: Int, child: SparkPlan) extends UnaryExecNode {
   override def output: Seq[Attribute] = child.output

   override def outputPartitioning: Partitioning = child.outputPartitioning

   protected override def doExecute(): RDD[InternalRow] = {
     child.execute()
       .mapPartitions(it => {
-        val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, k)
+        val (sample, count) = SamplingUtils.reservoirSampleAndCount(it, reservoirSize)
         sample.map((_, count)).toIterator
       })
       .repartition(1)
       .mapPartitions(it => {
-        SamplingUtils.reservoirSampleWithWeight(it, k).iterator})
+        SamplingUtils.reservoirSampleWithWeight(it, reservoirSize).iterator})
   }
 }
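For context, `ReservoirSampleExec` samples in two passes: each partition draws a local reservoir while counting its rows, then all local samples are pulled into one partition and merged by `reservoirSampleWithWeight`, with each sample weighted by its partition's count. A self-contained sketch of the per-partition step, classic Algorithm R; Spark's `SamplingUtils.reservoirSampleAndCount` follows the same scheme, though the code below is illustrative rather than the library implementation:

import scala.reflect.ClassTag
import scala.util.Random

object ReservoirSketch {
  // Returns a uniform sample of up to `reservoirSize` elements plus the
  // total number of elements seen.
  def sampleAndCount[T: ClassTag](
      input: Iterator[T],
      reservoirSize: Int,
      seed: Long = Random.nextLong()): (Array[T], Long) = {
    val reservoir = new Array[T](reservoirSize)
    var count = 0L
    val rand = new Random(seed)
    while (input.hasNext) {
      val item = input.next()
      if (count < reservoirSize) {
        reservoir(count.toInt) = item // fill the reservoir first
      } else {
        // Keep the new item with probability reservoirSize / (count + 1).
        val idx = (rand.nextDouble() * (count + 1)).toLong
        if (idx < reservoirSize) {
          reservoir(idx.toInt) = item
        }
      }
      count += 1
    }
    // If the input was smaller than the reservoir, trim the unused slots.
    if (count < reservoirSize) (reservoir.take(count.toInt), count)
    else (reservoir, count)
  }
}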

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 9 additions & 4 deletions

@@ -103,13 +103,18 @@ class IncrementalExecution(
           child,
           Some(stateId),
           Some(offsetSeqMetadata.batchWatermarkMs))
-
-      case StreamingReservoirSampleExec(k, keys, child, None, None, None) =>
+
+      case StreamingReservoirSampleExec(keys, child, reservoirSize, None, None, None) =>
        val stateId =
          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
        StreamingReservoirSampleExec(
-          k, keys, child, Some(stateId), Some(currentEventTimeWatermark), Some(outputMode))
-
+          keys,
+          child,
+          reservoirSize,
+          Some(stateId),
+          Some(offsetSeqMetadata.batchWatermarkMs),
+          Some(outputMode))
+
      case m: FlatMapGroupsWithStateExec =>
        val stateId =
          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
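This hunk carries the substance of the "bug fix": the old pattern listed its variables as `(k, keys, child, ...)`, out of step with the case class's declared order (`keyExpressions, child, reservoirSize`, see statefulOperators.scala below), and the rewritten rule also sources the watermark from `offsetSeqMetadata.batchWatermarkMs` instead of `currentEventTimeWatermark`, consistent with the stateful-aggregation case directly above.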

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 5 additions & 5 deletions

@@ -330,12 +330,12 @@ object StreamingDeduplicateExec {
 /**
  * Physical operator for executing streaming Sampling.
  *
- * @param k random sample k elements.
+ * @param reservoirSize number of random sample elements.
  */
 case class StreamingReservoirSampleExec(
     keyExpressions: Seq[Attribute],
     child: SparkPlan,
-    k: Int,
+    reservoirSize: Int,
     stateId: Option[OperatorStateId] = None,
     eventTimeWatermark: Option[Long] = None,
     outputMode: Option[OutputMode] = None)
@@ -378,13 +378,13 @@ case class StreamingReservoirSampleExec(

     baseIterator.foreach { r =>
       count += 1
-      if (numSamples < k) {
+      if (numSamples < reservoirSize) {
         numSamples += 1
         store.put(enc.toRow(numSamples.toString).asInstanceOf[UnsafeRow],
           r.asInstanceOf[UnsafeRow])
       } else {
         val randomIdx = (rand.nextDouble() * (numRecordsInPart + count)).toLong
-        if (randomIdx <= k) {
+        if (randomIdx <= reservoirSize) {
           val replacementIdx = enc.toRow(randomIdx.toString).asInstanceOf[UnsafeRow]
           store.put(replacementIdx, r.asInstanceOf[UnsafeRow])
         }
@@ -421,7 +421,7 @@ case class StreamingReservoirSampleExec(
       }
     }.repartition(1).mapPartitions(it => {
       SamplingUtils.reservoirSampleWithWeight(
-        it.map(item => (item, item.getLong(keyExpressions.size))), k)
+        it.map(item => (item, item.getLong(keyExpressions.size))), reservoirSize)
         .map(row =>
           UnsafeProjection.create(fieldTypes)
             .apply(InternalRow.fromSeq(row.toSeq(fieldTypes)))
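Unlike the batch operator, the streaming version keeps its reservoir in a state store keyed by slot index and carries the running record count (`numRecordsInPart`) across micro-batches, so later batches replace slots with a correspondingly lower probability. A simplified single-partition sketch of that update loop, with a plain mutable Map standing in for Spark's StateStore; names and types here are illustrative, since the real operator works on UnsafeRows and checkpointed state:

import scala.collection.mutable
import scala.util.Random

object StreamingReservoirSketch {
  // State carried between micro-batches: slot index -> sampled element,
  // plus the count of elements seen so far across all batches.
  final case class State[T](slots: mutable.Map[Long, T], var seen: Long)

  def updateBatch[T](
      state: State[T],
      batch: Iterator[T],
      reservoirSize: Int,
      rand: Random = new Random()): State[T] = {
    batch.foreach { elem =>
      if (state.seen < reservoirSize) {
        state.slots.put(state.seen, elem) // still filling the reservoir
      } else {
        // Keep elem with probability reservoirSize / (seen + 1).
        val idx = (rand.nextDouble() * (state.seen + 1)).toLong
        if (idx < reservoirSize) {
          state.slots.put(idx, elem) // evict whichever slot it lands on
        }
      }
      state.seen += 1
    }
    state
  }
}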
