[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs #20279

Closed · wants to merge 7 commits
@@ -17,10 +17,12 @@

package org.apache.spark.sql.execution.streaming

+import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2}

/**
* A simple offset for sources that produce a single linear stream of data.
*/
-case class LongOffset(offset: Long) extends Offset {
+case class LongOffset(offset: Long) extends OffsetV2 {

override val json = offset.toString

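For a feel of the new contract, here is a minimal round-trip sketch built only from what this diff shows: json above serializes the offset as its decimal string, and MemoryStream's deserializeOffset below parses it back.

    val off = LongOffset(5)
    assert(off.json == "5")                    // serialized form is the decimal string
    val restored = LongOffset(off.json.toLong) // what deserializeOffset(json) does below
    assert(restored == off)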
@@ -273,7 +273,7 @@ class MicroBatchExecution(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())

-(s, Some(s.getEndOffset))
+(s, Option(s.getEndOffset))
}
}.toMap
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
@@ -396,10 +396,14 @@ class MicroBatchExecution(
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+val availableV2: OffsetV2 = available match {
+case v1: SerializedOffset => reader.deserializeOffset(v1.json)
+case v2: OffsetV2 => v2
+}
reader.setOffsetRange(
toJava(current),
-Optional.of(available.asInstanceOf[OffsetV2]))
-logDebug(s"Retrieving data from $reader: $current -> $available")
+Optional.of(availableV2))
+logDebug(s"Retrieving data from $reader: $current -> $availableV2")
Some(reader ->
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
case _ => None
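Two details in the hunks above are worth spelling out, with a small standalone Scala sketch. First, getEndOffset may return null before a reader has seen any data, which is why Some(...) becomes Option(...). Second, offsets replayed from the offset log arrive as SerializedOffset (raw JSON), so the new match round-trips them through the reader's own deserializeOffset rather than the old blind asInstanceOf cast.

    // Option(...) collapses null to None; Some(...) does not, so Some(null)
    // would survive the .filter(_.nonEmpty) applied to latestOffsets above.
    assert(Option(null) == None)
    assert(Some(null).nonEmpty)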
@@ -17,11 +17,12 @@

package org.apache.spark.sql.execution.streaming

+import java.{util => ju}
+import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal

@@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.sources.v2.reader.{DataReader, ReadTask}
+import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -51,9 +53,10 @@ object MemoryStream {
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-extends Source with Logging {
+extends MicroBatchReader with Logging {
protected val encoder = encoderFor[A]
-protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
+private val attributes = encoder.schema.toAttributes
+protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output

/**
@@ -66,15 +69,16 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)

+private var startOffset = new LongOffset(-1)
+private var endOffset = new LongOffset(-1)

/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
*/
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)

-def schema: StructType = encoder.schema

def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}
@@ -89,7 +93,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

def addData(data: TraversableOnce[A]): Offset = {
val encoded = data.toVector.map(d => encoder.toRow(d).copy())
-val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
+val plan = new LocalRelation(attributes, encoded, isStreaming = false)
val ds = Dataset[A](sqlContext.sparkSession, plan)
logDebug(s"Adding ds: $ds")
this.synchronized {
@@ -101,19 +105,25 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"

-override def getOffset: Option[Offset] = synchronized {
-if (currentOffset.offset == -1) {
-None
-} else {
-Some(currentOffset)
+override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+if (start.isPresent) {
+startOffset = start.get().asInstanceOf[LongOffset]
}
+endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
}

-override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+override def readSchema(): StructType = encoder.schema

+override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)

+override def getStartOffset: OffsetV2 = if (startOffset.offset == -1) null else startOffset

+override def getEndOffset: OffsetV2 = if (endOffset.offset == -1) null else endOffset

+override def createReadTasks(): ju.List[ReadTask[Row]] = {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
-val startOrdinal =
-start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
+val startOrdinal = startOffset.offset.toInt + 1
+val endOrdinal = endOffset.offset.toInt + 1

// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
@@ -123,19 +133,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
batches.slice(sliceStart, sliceEnd)
}

-if (newBlocks.isEmpty) {
-return sqlContext.internalCreateDataFrame(
-sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
-}

logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))

-newBlocks
-.map(_.toDF())
-.reduceOption(_ union _)
-.getOrElse {
-sys.error("No data selected!")
-}
+newBlocks.map { ds =>
+val items = ds.toDF().collect()
+new MemoryStreamReadTask(items).asInstanceOf[ReadTask[Row]]
+}.asJava
}

private def generateDebugString(
@@ -153,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}

-override def commit(end: Offset): Unit = synchronized {
+override def commit(end: OffsetV2): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

@@ -181,6 +184,24 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}

+class MemoryStreamReadTask(records: Array[Row]) extends ReadTask[Row] {
+override def createDataReader(): DataReader[Row] = new MemoryStreamDataReader(records)
+}

+class MemoryStreamDataReader(records: Array[Row]) extends DataReader[Row] {
+private var currentIndex = -1

+override def next(): Boolean = {
+// Return true as long as the new index is in the array.
+currentIndex += 1
+currentIndex < records.length
+}

+override def get(): Row = records(currentIndex)

+override def close(): Unit = {}
+}

/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
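Putting the pieces together, a driver-side walkthrough of one micro-batch against the new API, sketched only from the methods in this diff (sqlContext is an assumed test fixture; java.util.Optional is imported as in the file header):

    val stream = MemoryStream[Int](0, sqlContext)
    stream.addData(Seq(1, 2, 3))           // buffers one batch; currentOffset becomes 0
    stream.setOffsetRange(Optional.empty(), Optional.empty()) // end defaults to currentOffset
    val tasks = stream.createReadTasks()   // startOrdinal = 0, endOrdinal = 1, so one task
    stream.commit(stream.getEndOffset)     // discards the now-committed batch from the buffer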
@@ -152,7 +152,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] {
}

class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
-var currentIndex = -1
+private var currentIndex = -1

override def next(): Boolean = {
// Return true as long as the new index is in the seq.
@@ -492,16 +492,16 @@

val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithoutExtended.contains("StateStoreRestore"))

val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithExtended.contains("StateStoreRestore"))
} finally {
@@ -116,7 +116,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"

-override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
(source, source.addData(data))
}
}
@@ -33,6 +33,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.StreamingQueryListener._
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.util.JsonProtocol
@@ -273,9 +274,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
try {
val input = new MemoryStream[Int](0, sqlContext) {
@volatile var numTriggers = 0
-override def getOffset: Option[Offset] = {
+override def getEndOffset: OffsetV2 = {
numTriggers += 1
-super.getOffset
+super.getEndOffset
}
}
val clock = new StreamManualClock()
@@ -17,6 +17,7 @@

package org.apache.spark.sql.streaming

+import java.{util => ju}
import java.util.concurrent.CountDownLatch

import org.apache.commons.lang3.RandomStringUtils
@@ -29,10 +30,12 @@ import org.scalatest.mockito.MockitoSugar

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.reader.ReadTask
+import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ManualClock
@@ -207,18 +210,18 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
/** Custom MemoryStream that waits for manual clock to reach a time */
val inputData = new MemoryStream[Int](0, sqlContext) {
// getEndOffset should take 50 ms the first time it is called
-override def getOffset: Option[Offset] = {
-val offset = super.getOffset
-if (offset.nonEmpty) {
+override def getEndOffset: OffsetV2 = {
+val offset = super.getEndOffset
+if (offset != null) {
clock.waitTillTime(1050)
}
offset
}

// createReadTasks should take 100 ms the first time it is called
-override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-if (start.isEmpty) clock.waitTillTime(1150)
-super.getBatch(start, end)
+override def createReadTasks(): ju.List[ReadTask[Row]] = {
+if (getStartOffset.asInstanceOf[LongOffset].offset == -1L) clock.waitTillTime(1150)
+super.createReadTasks()
}
}

@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}

@@ -101,6 +102,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
override def strategies: Seq[Strategy] = {
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ Seq(
+DataSourceV2Strategy,
FileSourceStrategy,
DataSourceStrategy(conf),
SpecialLimits,