
Commit 21bbd6b

[SPARK-29248][SQL] Add PhysicalWriteInfo with number of partitions
1 parent 3d45779 commit 21bbd6b

19 files changed, 119 additions and 48 deletions

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
 import java.{util => ju}

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
 import org.apache.spark.sql.types.StructType

@@ -40,7 +40,7 @@ private[kafka010] class KafkaBatchWrite(

   validateQuery(schema.toAttributes, producerParams, topic)

-  override def createBatchWriterFactory(): KafkaBatchWriterFactory =
+  override def createBatchWriterFactory(info: PhysicalWriteInfo): KafkaBatchWriterFactory =
     KafkaBatchWriterFactory(topic, producerParams, schema)

   override def commit(messages: Array[WriterCommitMessage]): Unit = {}

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
 import java.{util => ju}

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
 import org.apache.spark.sql.types.StructType
@@ -41,7 +41,8 @@ private[kafka010] class KafkaStreamingWrite(

   validateQuery(schema.toAttributes, producerParams, topic)

-  override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
+  override def createStreamingWriterFactory(
+      info: PhysicalWriteInfo): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java

Lines changed: 5 additions & 3 deletions
@@ -23,8 +23,8 @@
  * An interface that defines how to write the data to data source for batch processing.
  *
  * The writing procedure is:
- * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all
- *    the partitions of the input data(RDD).
+ * 1. Create a writer factory by {@link #createBatchWriterFactory(PhysicalWriteInfo)}, serialize
+ *    and send it to all the partitions of the input data(RDD).
  * 2. For each partition, create the data writer, and write the data of the partition with this
  *    writer. If all the data are written successfully, call {@link DataWriter#commit()}. If
  *    exception happens during the writing, call {@link DataWriter#abort()}.
@@ -45,8 +45,10 @@ public interface BatchWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param info Physical information about the input data that will be written to this table.
    */
-  DataWriterFactory createBatchWriterFactory();
+  DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info);

   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for
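
Note for connector authors: with this change, a BatchWrite learns the number of input partitions when its factory is created. Below is a minimal Scala sketch of how a connector might forward that value; ExampleBatchWrite and ExampleWriterFactory are hypothetical names used for illustration, not part of this commit.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

// Hypothetical connector classes; the names are illustrative only.
class ExampleBatchWrite extends BatchWrite {

  // Spark now passes the number of input partitions before the job starts,
  // so the factory can be sized or validated up front.
  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
    new ExampleWriterFactory(info.numPartitions())

  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

class ExampleWriterFactory(expectedPartitions: Int) extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
    // partitionId ranges over [0, expectedPartitions); a real connector would
    // return its own DataWriter implementation here.
    ???
  }
}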

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriterFactory.java

Lines changed: 3 additions & 2 deletions
@@ -24,8 +24,9 @@
 import org.apache.spark.sql.catalyst.InternalRow;

 /**
- * A factory of {@link DataWriter} returned by {@link BatchWrite#createBatchWriterFactory()},
- * which is responsible for creating and initializing the actual data writer at executor side.
+ * A factory of {@link DataWriter} returned by
+ * {@link BatchWrite#createBatchWriterFactory(PhysicalWriteInfo)}, which is responsible for
+ * creating and initializing the actual data writer at executor side.
  *
  * Note that, the writer factory will be serialized and sent to executors, then the data writer
  * will be created on executors and do the actual writing. So this interface must be

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PhysicalWriteInfo.java

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
+
+/**
+ * This interface contains physical write information that data sources can use when
+ * generating a {@link DataWriterFactory} or a {@link StreamingDataWriterFactory}.
+ */
+@Evolving
+public interface PhysicalWriteInfo {
+  /**
+   * The number of partitions of the input data that is going to be written.
+   */
+  int numPartitions();
+}
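
Because PhysicalWriteInfo is a plain interface, connector or test code can construct an instance directly when calling the new factory methods outside a Spark job. A minimal sketch in Scala; the partition count of 4 and the commented call are assumptions for illustration only.

import org.apache.spark.sql.connector.write.PhysicalWriteInfo

// Hand-rolled PhysicalWriteInfo for exercising a factory method directly in a test;
// the partition count 4 is an arbitrary example value.
val testInfo: PhysicalWriteInfo = new PhysicalWriteInfo {
  override def numPartitions(): Int = 4
}

// e.g. someBatchWrite.createBatchWriterFactory(testInfo)  // someBatchWrite is hypothetical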

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java

Lines changed: 3 additions & 2 deletions
@@ -23,11 +23,12 @@
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;

 /**
  * A factory of {@link DataWriter} returned by
- * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating
- * and initializing the actual data writer at executor side.
+ * {@link StreamingWrite#createStreamingWriterFactory(PhysicalWriteInfo)}, which is responsible for
+ * creating and initializing the actual data writer at executor side.
  *
  * Note that, the writer factory will be serialized and sent to executors, then the data writer
  * will be created on executors and do the actual writing. So this interface must be

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java

Lines changed: 6 additions & 3 deletions
@@ -19,14 +19,15 @@

 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;

 /**
  * An interface that defines how to write the data to data source in streaming queries.
  *
  * The writing procedure is:
- * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to
- *    all the partitions of the input data(RDD).
+ * 1. Create a writer factory by {@link #createStreamingWriterFactory(PhysicalWriteInfo)},
+ *    serialize and send it to all the partitions of the input data(RDD).
  * 2. For each epoch in each partition, create the data writer, and write the data of the epoch in
  *    the partition with this writer. If all the data are written successfully, call
  *    {@link DataWriter#commit()}. If exception happens during the writing, call
@@ -48,8 +49,10 @@ public interface StreamingWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param info Information about the RDD that will be written to this data writer
    */
-  StreamingDataWriterFactory createStreamingWriterFactory();
+  StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo info);

   /**
    * Commits this writing job for the specified epoch with a list of commit messages. The commit
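
The streaming side mirrors the batch side: a sink now receives the partition count when its factory is created. A minimal Scala sketch of a hypothetical sink follows; ExampleStreamingWrite and ExampleStreamingFactory are illustrative names, not part of this commit.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}

// Hypothetical streaming sink; the names are illustrative only.
class ExampleStreamingWrite extends StreamingWrite {

  override def createStreamingWriterFactory(
      info: PhysicalWriteInfo): StreamingDataWriterFactory = {
    // Every epoch of this query will be written by info.numPartitions() tasks.
    new ExampleStreamingFactory(info.numPartitions())
  }

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

class ExampleStreamingFactory(expectedPartitions: Int) extends StreamingDataWriterFactory {
  override def createWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = {
    // A real sink would return its own DataWriter implementation here.
    ???
  }
}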

sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PhysicalWriteInfoImpl.scala

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+private[sql] case class PhysicalWriteInfoImpl(numPartitions: Int) extends PhysicalWriteInfo

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ class InMemoryTable(
   }

   private abstract class TestBatchWrite extends BatchWrite {
-    override def createBatchWriterFactory(): DataWriterFactory = {
+    override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
       BufferedRowsWriterFactory
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala

Lines changed: 5 additions & 4 deletions
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.StructType
@@ -58,7 +58,8 @@ private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate
 }

 private[noop] object NoopBatchWrite extends BatchWrite {
-  override def createBatchWriterFactory(): DataWriterFactory = NoopWriterFactory
+  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
+    NoopWriterFactory
   override def commit(messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(messages: Array[WriterCommitMessage]): Unit = {}
 }
@@ -74,8 +75,8 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
 }

 private[noop] object NoopStreamingWrite extends StreamingWrite {
-  override def createStreamingWriterFactory(): StreamingDataWriterFactory =
-    NoopStreamingDataWriterFactory
+  override def createStreamingWriterFactory(
+      info: PhysicalWriteInfo): StreamingDataWriterFactory = NoopStreamingDataWriterFactory
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@ import org.apache.hadoop.mapreduce.Job

 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.FileCommitProtocol
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
 import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult}
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats

@@ -44,7 +44,7 @@ class FileBatchWrite(
     committer.abortJob(job)
   }

-  override def createBatchWriterFactory(): DataWriterFactory = {
+  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
     FileWriterFactory(description, committer)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala

Lines changed: 16 additions & 13 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -353,28 +353,31 @@ trait V2TableWriteExec extends UnaryExecNode {
   override def output: Seq[Attribute] = Nil

   protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
-    val writerFactory = batchWrite.createBatchWriterFactory()
-    val useCommitCoordinator = batchWrite.useCommitCoordinator
-    val rdd = query.execute()
-    // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
-    // partition rdd to make sure we at least set up one write task to write the metadata.
-    val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
-      sparkContext.parallelize(Array.empty[InternalRow], 1)
-    } else {
-      rdd
+    val rdd: RDD[InternalRow] = {
+      val tempRdd = query.execute()
+      // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
+      // partition rdd to make sure we at least set up one write task to write the metadata.
+      if (tempRdd.partitions.length == 0) {
+        sparkContext.parallelize(Array.empty[InternalRow], 1)
+      } else {
+        tempRdd
+      }
     }
-    val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
+    val writerFactory = batchWrite.createBatchWriterFactory(
+      PhysicalWriteInfoImpl(rdd.getNumPartitions))
+    val useCommitCoordinator = batchWrite.useCommitCoordinator
+    val messages = new Array[WriterCommitMessage](rdd.partitions.length)
     val totalNumRowsAccumulator = new LongAccumulator()

     logInfo(s"Start processing data source write support: $batchWrite. " +
       s"The input RDD has ${messages.length} partitions.")

     try {
       sparkContext.runJob(
-        rddWithNonEmptyPartitions,
+        rdd,
         (context: TaskContext, iter: Iterator[InternalRow]) =>
           DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
-        rddWithNonEmptyPartitions.partitions.indices,
+        rdd.partitions.indices,
         (index, result: DataWritingSparkTaskResult) => {
           val commitMessage = result.writerCommitMessage
           messages(index) = commitMessage

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala

Lines changed: 5 additions & 2 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.write.PhysicalWriteInfoImpl
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.streaming.StreamExecution
@@ -38,8 +39,10 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl
   override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = write.createStreamingWriterFactory()
-    val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
+    val queryRdd = query.execute()
+    val writerFactory = write.createStreamingWriterFactory(
+      PhysicalWriteInfoImpl(queryRdd.getNumPartitions))
+    val rdd = new ContinuousWriteRDD(queryRdd, writerFactory)

     logInfo(s"Start processing data source write support: $write. " +
       s"The input RDD has ${rdd.partitions.length} partitions.")

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -38,7 +38,8 @@ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get

-  def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory
+  def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+    PackedRowWriterFactory

   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.write.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.execution.python.PythonForeachWriter
 import org.apache.spark.sql.types.StructType
@@ -72,7 +72,8 @@ case class ForeachWriterTable[T](
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
     override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

-    override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+    override def createStreamingWriterFactory(
+        info: PhysicalWriteInfo): StreamingDataWriterFactory = {
       val rowConverter: InternalRow => T = converter match {
         case Left(enc) =>
           val boundEnc = enc.resolveAndBind(

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.streaming.sources

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}

 /**
@@ -36,8 +36,8 @@ class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends B
     writeSupport.abort(eppchId, messages)
   }

-  override def createBatchWriterFactory(): DataWriterFactory = {
-    new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory())
+  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+    new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory(info))
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.types.StructType
@@ -140,7 +140,7 @@ class MemoryStreamingWrite(
     val sink: MemorySink, schema: StructType, needTruncate: Boolean)
   extends StreamingWrite {

-  override def createStreamingWriterFactory: MemoryWriterFactory = {
+  override def createStreamingWriterFactory(info: PhysicalWriteInfo): MemoryWriterFactory = {
     MemoryWriterFactory(schema)
   }
