
Commit 41c6a88

Fixed bug
1 parent 407f672 commit 41c6a88

3 files changed: +45 -12 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala (+7 -5)

@@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
 
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2}
 import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
@@ -431,7 +431,11 @@ class MicroBatchExecution(
             s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
               s"${Utils.truncatedString(dataPlan.output, ",")}")
           replacements ++= output.zip(dataPlan.output)
-          dataPlan
+
+          val aliases = output.zip(dataPlan.output).map { case (to, from) =>
+            Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
+          }
+          Project(aliases, dataPlan)
         }.getOrElse {
           LocalRelation(output, isStreaming = true)
         }
@@ -440,8 +444,6 @@ class MicroBatchExecution(
     // Rewire the plan to use the new attributes that were returned by the source.
     val replacementMap = AttributeMap(replacements)
     val newAttributePlan = newBatchesPlan transformAllExpressions {
-      case a: Attribute if replacementMap.contains(a) =>
-        replacementMap(a).withMetadata(a.metadata)
       case ct: CurrentTimestamp =>
        CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
          ct.dataType)
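
Note: with this change the per-batch data plan is no longer spliced in by rewriting attributes after the fact; instead it is wrapped in a Project of Aliases that keep the streaming relation's original names and exprIds while reading values (and metadata) from the new batch. A minimal standalone sketch of that idea follows; the object and method names here are illustrative, not part of the commit:

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

object BatchPlanAliasing {
  // Re-expose dataPlan's columns under the attribute ids the streaming relation
  // originally advertised, instead of substituting dataPlan's attributes into
  // the surrounding query plan.
  def reattachOutput(relationOutput: Seq[Attribute], dataPlan: LogicalPlan): LogicalPlan = {
    val aliases: Seq[NamedExpression] = relationOutput.zip(dataPlan.output).map { case (to, from) =>
      Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
    }
    Project(aliases, dataPlan)
  }
}

Because the exprIds stay fixed per relation, downstream operators keep resolving against stable ids even when the same source feeds both sides of a join.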

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala (+14 -6)

@@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.LeafExecNode
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
@@ -42,7 +42,7 @@ object StreamingRelation {
  * passing to [[StreamExecution]] to run a query.
  */
 case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
-  extends LeafNode {
+  extends LeafNode with MultiInstanceRelation {
   override def isStreaming: Boolean = true
   override def toString: String = sourceName
 
@@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
   override def computeStats(): Statistics = Statistics(
     sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes)
   )
+
+  override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))
 }
 
 /**
@@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
 case class StreamingExecutionRelation(
     source: BaseStreamingSource,
     output: Seq[Attribute])(session: SparkSession)
-  extends LeafNode {
+  extends LeafNode with MultiInstanceRelation {
 
   override def isStreaming: Boolean = true
   override def toString: String = source.toString
@@ -74,6 +76,8 @@ case class StreamingExecutionRelation(
   override def computeStats(): Statistics = Statistics(
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
   )
+
+  override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
 }
 
 // We have to pack in the V1 data source as a shim, for the case when a source implements
@@ -92,13 +96,15 @@ case class StreamingRelationV2(
     extraOptions: Map[String, String],
     output: Seq[Attribute],
     v1Relation: Option[StreamingRelation])(session: SparkSession)
-  extends LeafNode {
+  extends LeafNode with MultiInstanceRelation {
   override def isStreaming: Boolean = true
   override def toString: String = sourceName
 
   override def computeStats(): Statistics = Statistics(
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
   )
+
+  override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
 }
 
 /**
@@ -108,7 +114,7 @@ case class ContinuousExecutionRelation(
     source: ContinuousReadSupport,
     extraOptions: Map[String, String],
     output: Seq[Attribute])(session: SparkSession)
-  extends LeafNode {
+  extends LeafNode with MultiInstanceRelation {
 
   override def isStreaming: Boolean = true
   override def toString: String = source.toString
@@ -120,6 +126,8 @@ case class ContinuousExecutionRelation(
   override def computeStats(): Statistics = Statistics(
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
   )
+
+  override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
 }
 
 /**
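
Note: mixing MultiInstanceRelation into these leaf nodes lets the analyzer call newInstance() to obtain a copy with fresh attribute exprIds whenever the same relation appears more than once in a plan, which is exactly what a self-join produces. A minimal sketch of the pattern with a hypothetical leaf node (not a class from the commit):

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Re-creating the node with fresh exprIds lets the analyzer tell the two
// occurrences of the relation apart instead of seeing conflicting duplicates.
case class ExampleStreamingLeaf(output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
  override def isStreaming: Boolean = true
  override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
}

Only the output attributes get new ids; the underlying source and session are carried over unchanged by copy.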

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala (+24 -1)

@@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
-import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
 import org.apache.spark.sql.functions._
@@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
     assert(e.toString.contains("Stream stream joins without equality predicate is not supported"))
   }
 
+  test("stream stream self join") {
+    val input = MemoryStream[Int]
+    val df = input.toDF
+    val join =
+      df.select('value % 5 as "key", 'value).join(
+        df.select('value % 5 as "key", 'value), "key")
+
+    testStream(join)(
+      AddData(input, 1, 2),
+      CheckAnswer((1, 1, 1), (2, 2, 2)),
+      StopStream,
+      StartStream(),
+      AddData(input, 3, 6),
+      /*
+      (1, 1)     (1, 1)
+      (2, 2)  x  (2, 2)  =  (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6)
+      (1, 6)     (1, 6)
+      */
+      CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6)))
+  }
+
   test("locality preferences of StateStoreAwareZippedRDD") {
     import StreamingSymmetricHashJoinHelper._

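For context, the new test drives the self-join through MemoryStream; a rough user-facing query of the same shape is sketched below. The rate source, column names, and console sink are assumptions for illustration and are not taken from the commit:

import org.apache.spark.sql.SparkSession

object SelfJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("self-join-example").getOrCreate()
    import spark.implicits._

    val stream = spark.readStream.format("rate").load()           // built-in source with (timestamp, value)
    val keyed = stream.select(($"value" % 5).as("key"), $"value")  // derive a join key, as in the test
    val joined = keyed.join(keyed, "key")                          // stream-stream self-join on "key"

    joined.writeStream.format("console").outputMode("append").start().awaitTermination()
  }
}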