[SPARK-23014][SS] Fully remove V1 memory sink. #24403

Closed
wants to merge 9 commits
14 changes: 14 additions & 0 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -192,6 +192,20 @@ private[spark] object TestUtils {
assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}

+  /**
+   * Asserts that the exception message contains the expected message. Note that this
+   * checks every exception in the cause tree.
+   */
+  def assertExceptionMsg(exception: Throwable, msg: String): Unit = {
+    var e = exception
+    var contains = e.getMessage.contains(msg)
+    while (e.getCause != null && !contains) {
+      e = e.getCause
+      contains = e.getMessage.contains(msg)
+    }
+    assert(contains, s"Exception tree doesn't contain the expected message: $msg")
+  }

/**
* Test if a command is available.
*/
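
For reference, here is a minimal sketch of how the new `assertExceptionMsg` helper might be exercised against a nested cause chain. This is hypothetical test code: the suite name and exception messages are made up, and it assumes a test source file under the `org.apache.spark` package, since `TestUtils` is `private[spark]`.

```scala
package org.apache.spark

// Hypothetical suite; SparkFunSuite is the base class used by Spark's own tests.
class ExceptionMsgExampleSuite extends SparkFunSuite {

  test("assertExceptionMsg walks the cause chain") {
    // The expected text sits two levels down the cause chain; the helper still finds it.
    val root = new ArithmeticException("ZeroDivisionError: division by zero")
    val wrapped = new RuntimeException("Job aborted",
      new RuntimeException("Task failed", root))
    TestUtils.assertExceptionMsg(wrapped, "ZeroDivisionError")
  }
}
```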
@@ -30,7 +30,6 @@ import org.apache.spark.sql.test.TestSparkSession
// Trait to configure StreamTest for kafka continuous execution tests.
trait KafkaContinuousTest extends KafkaSourceTest {
override val defaultTrigger = Trigger.Continuous(1000)
-  override val defaultUseV2Sink = true

// We need more than the default local[2] to be able to schedule all partitions simultaneously.
override protected def createSparkSession = new TestSparkSession(
@@ -696,12 +696,14 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df)
def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
-        assert(intercept[SparkException] {
+        val e1 = intercept[SparkException] {
          model.transform(dataFrame).first
-        }.getMessage.contains(msg))
-        assert(intercept[StreamingQueryException] {
+        }
+        TestUtils.assertExceptionMsg(e1, msg)
+        val e2 = intercept[StreamingQueryException] {
          testTransformer[A](dataFrame, model, "prediction") { _ => }
-        }.getMessage.contains(msg))
+        }
+        TestUtils.assertExceptionMsg(e2, msg)
}
testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
df("item")))
10 changes: 3 additions & 7 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -21,7 +21,7 @@ import java.io.File

import org.scalatest.Suite

-import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext}
+import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils}
import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
import org.apache.spark.ml.{Model, PredictionModel, Transformer}
import org.apache.spark.ml.linalg.Vector
@@ -129,21 +129,17 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
expectedMessagePart : String,
firstResultCol: String) {

-    def hasExpectedMessage(exception: Throwable): Boolean =
-      exception.getMessage.contains(expectedMessagePart) ||
-        (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart))

withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") {
val exceptionOnDf = intercept[Throwable] {
testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit)
}
-      assert(hasExpectedMessage(exceptionOnDf))
+      TestUtils.assertExceptionMsg(exceptionOnDf, expectedMessagePart)
}
withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") {
val exceptionOnStreamData = intercept[Throwable] {
testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit)
}
-      assert(hasExpectedMessage(exceptionOnStreamData))
+      TestUtils.assertExceptionMsg(exceptionOnStreamData, expectedMessagePart)
}
}

2 changes: 1 addition & 1 deletion python/pyspark/sql/streaming.py
@@ -186,7 +186,7 @@ def exception(self):
je = self._jsq.exception().get()
msg = je.toString().split(': ', 1)[1] # Drop the Java StreamingQueryException type info
stackTrace = '\n\t at '.join(map(lambda x: x.toString(), je.getStackTrace()))
-            return StreamingQueryException(msg, stackTrace)
+            return StreamingQueryException(msg, stackTrace, je.getCause())
else:
return None

12 changes: 10 additions & 2 deletions python/pyspark/sql/tests/test_streaming.py
@@ -225,11 +225,19 @@ def test_stream_exception(self):
self.fail("bad udf should fail the query")
except StreamingQueryException as e:
# This is expected
self.assertTrue("ZeroDivisionError" in e.desc)
self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
finally:
sq.stop()
self.assertTrue(type(sq.exception()) is StreamingQueryException)
self.assertTrue("ZeroDivisionError" in sq.exception().desc)
self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError")

+    def _assert_exception_tree_contains_msg(self, exception, msg):
+        e = exception
+        contains = msg in e.desc
+        while e.cause is not None and not contains:
+            e = e.cause
+            contains = msg in e.desc
+        self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)

def test_query_manager_await_termination(self):
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
49 changes: 32 additions & 17 deletions python/pyspark/sql/utils.py
@@ -19,9 +19,10 @@


class CapturedException(Exception):
-    def __init__(self, desc, stackTrace):
+    def __init__(self, desc, stackTrace, cause=None):
        self.desc = desc
        self.stackTrace = stackTrace
+        self.cause = convert_exception(cause) if cause is not None else None

def __str__(self):
return repr(self.desc)
@@ -57,27 +58,41 @@ class QueryExecutionException(CapturedException):
"""


+class UnknownException(CapturedException):
+    """
+    None of the above exceptions.
+    """
+
+
+def convert_exception(e):
+    s = e.toString()
+    stackTrace = '\n\t at '.join(map(lambda x: x.toString(), e.getStackTrace()))
+    c = e.getCause()
+    if s.startswith('org.apache.spark.sql.AnalysisException: '):
+        return AnalysisException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.catalyst.analysis'):
+        return AnalysisException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
+        return ParseException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '):
+        return StreamingQueryException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '):
+        return QueryExecutionException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('java.lang.IllegalArgumentException: '):
+        return IllegalArgumentException(s.split(': ', 1)[1], stackTrace, c)
+    return UnknownException(s, stackTrace, c)


def capture_sql_exception(f):
def deco(*a, **kw):
try:
return f(*a, **kw)
except py4j.protocol.Py4JJavaError as e:
-            s = e.java_exception.toString()
-            stackTrace = '\n\t at '.join(map(lambda x: x.toString(),
-                                             e.java_exception.getStackTrace()))
-            if s.startswith('org.apache.spark.sql.AnalysisException: '):
-                raise AnalysisException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.catalyst.analysis'):
-                raise AnalysisException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
-                raise ParseException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '):
-                raise StreamingQueryException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '):
-                raise QueryExecutionException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('java.lang.IllegalArgumentException: '):
-                raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
-            raise
+            converted = convert_exception(e.java_exception)
+            if not isinstance(converted, UnknownException):
+                raise converted
+            else:
+                raise
return deco


@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
+import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.sql.types.StructType
@@ -624,9 +624,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil

case MemoryPlan(sink, output) =>
-        val encoder = RowEncoder(sink.schema)
-        LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
-      case MemoryPlanV2(sink, output) =>
val encoder = RowEncoder(StructType.fromAttributes(output))
LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil

@@ -22,23 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.util.control.NonFatal
+import scala.collection.mutable.ListBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -276,85 +272,3 @@ trait MemorySinkBase extends BaseStreamingSink {
def dataSinceBatch(sinceBatchId: Long): Seq[Row]
def latestBatchId: Option[Long]
}

-/**
- * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
- * tests and does not provide durability.
- */
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
-  with MemorySinkBase with Logging {
-
-  private case class AddedData(batchId: Long, data: Array[Row])
-
-  /** An order list of batches that have been written to this [[Sink]]. */
-  @GuardedBy("this")
-  private val batches = new ArrayBuffer[AddedData]()
-
-  /** Returns all rows that are stored in this [[Sink]]. */
-  def allData: Seq[Row] = synchronized {
-    batches.flatMap(_.data)
-  }
-
-  def latestBatchId: Option[Long] = synchronized {
-    batches.lastOption.map(_.batchId)
-  }
-
-  def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) }
-
-  def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
-    batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
-  }
-
-  def toDebugString: String = synchronized {
-    batches.map { case AddedData(batchId, data) =>
-      val dataStr = try data.mkString(" ") catch {
-        case NonFatal(e) => "[Error converting to string]"
-      }
-      s"$batchId: $dataStr"
-    }.mkString("\n")
-  }
-
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    val notCommitted = synchronized {
-      latestBatchId.isEmpty || batchId > latestBatchId.get
-    }
-    if (notCommitted) {
-      logDebug(s"Committing batch $batchId to $this")
-      outputMode match {
-        case Append | Update =>
-          val rows = AddedData(batchId, data.collect())
-          synchronized { batches += rows }
-
-        case Complete =>
-          val rows = AddedData(batchId, data.collect())
-          synchronized {
-            batches.clear()
-            batches += rows
-          }
-
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
-      }
-    } else {
-      logDebug(s"Skipping already committed batch: $batchId")
-    }
-  }
-
-  def clear(): Unit = synchronized {
-    batches.clear()
-  }
-
-  override def toString(): String = "MemorySink"
-}
-
-/**
- * Used to query the data that has been written into a [[MemorySink]].
- */
-case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
-  def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
-
-  private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
-
-  override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
-}
@@ -43,9 +43,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Logging {
+class MemorySink extends Table with SupportsWrite with MemorySinkBase with Logging {

override def name(): String = "MemorySinkV2"
override def name(): String = "MemorySink"

override def schema(): StructType = StructType(Nil)

@@ -69,7 +69,7 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log
}

override def buildForStreaming(): StreamingWrite = {
-      new MemoryStreamingWrite(MemorySinkV2.this, inputSchema, needTruncate)
+      new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate)
}
}
}
@@ -130,14 +130,14 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log
batches.clear()
}

override def toString(): String = "MemorySinkV2"
override def toString(): String = "MemorySink"
}

case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
extends WriterCommitMessage {}

class MemoryStreamingWrite(
-    val sink: MemorySinkV2, schema: StructType, needTruncate: Boolean)
+    val sink: MemorySink, schema: StructType, needTruncate: Boolean)
extends StreamingWrite {

override def createStreamingWriterFactory: MemoryWriterFactory = {
@@ -195,9 +195,9 @@ class MemoryDataWriter(partition: Int, schema: StructType)


/**
- * Used to query the data that has been written into a [[MemorySinkV2]].
+ * Used to query the data that has been written into a [[MemorySink]].
*/
-case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
+case class MemoryPlan(sink: MemorySink, override val output: Seq[Attribute]) extends LeafNode {
private val sizePerRow = EstimationUtils.getSizePerRow(output)

override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
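
As a rough illustration of the renamed sink, the sketch below drives `MemorySink` directly through the classes visible in this diff, the way a unit test might. It is hypothetical fragment-style code (e.g. inside a test body): the schema and row values are invented, and it assumes `MemoryStreamingWrite.commit` forwards the committed rows into the sink, which is not shown in this hunk.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.sources.{MemorySink, MemoryStreamingWrite, MemoryWriterCommitMessage}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Commit two epochs into the unified MemorySink and read them back
// through the MemorySinkBase query methods.
val schema = StructType(StructField("value", IntegerType) :: Nil)
val sink = new MemorySink()
val write = new MemoryStreamingWrite(sink, schema, needTruncate = false)

write.commit(0, Array(MemoryWriterCommitMessage(0, Seq(Row(1), Row(2)))))
write.commit(1, Array(MemoryWriterCommitMessage(0, Seq(Row(3)))))

assert(sink.latestBatchId.contains(1L))
assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3))
assert(sink.dataSinceBatch(0).map(_.getInt(0)) == Seq(3))
```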
@@ -254,16 +254,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
if (extraOptions.get("queryName").isEmpty) {
throw new AnalysisException("queryName must be specified for memory sink")
}
-      val (sink, resultDf) = trigger match {
-        case _: ContinuousTrigger =>
-          val s = new MemorySinkV2()
-          val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
-          (s, r)
-        case _ =>
-          val s = new MemorySink(df.schema, outputMode)
-          val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
-          (s, r)
-      }
+      val sink = new MemorySink()
+      val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
val chkpointLoc = extraOptions.get("checkpointLocation")
val recoverFromChkpoint = outputMode == OutputMode.Complete()
val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
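
Finally, a hedged end-to-end sketch of the path this hunk simplifies: with the trigger match removed, `writeStream.format("memory")` always resolves to the single `MemorySink`, whatever the trigger. The application name, query name, and source below are illustrative, not taken from the PR.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().master("local[2]").appName("memory-sink-demo").getOrCreate()
import spark.implicits._

// Any streaming source works; rate is convenient for a quick demo.
val input = spark.readStream.format("rate").option("rowsPerSecond", "5").load()

val query = input
  .select($"value")
  .writeStream
  .format("memory")
  .queryName("demo_table")   // required for the memory sink
  .outputMode("append")
  .trigger(Trigger.ProcessingTime("1 second"))
  .start()

Thread.sleep(3000)           // crude wait for a couple of micro-batches
spark.sql("SELECT count(*) FROM demo_table").show()
query.stop()
```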