
Commit 6dd9f8d

rdblue authored and otterc committed
[SPARK-23325] Use InternalRow when reading with DataSourceV2.
This updates the DataSourceV2 API to use InternalRow instead of Row for the default case with no scan mix-ins.

Support for readers that produce Row is added through SupportsDeprecatedScanRow, which matches the previous API. Readers that used Row now implement this class and should be migrated to InternalRow. Readers that previously implemented SupportsScanUnsafeRow have been migrated to use no SupportsScan mix-ins and produce InternalRow.

This uses existing tests.

Author: Ryan Blue <blue@apache.org>

Closes apache#21118 from rdblue/SPARK-23325-datasource-v2-internal-row.

Ref: LIHADOOP-48531 RB=1855575 A=
1 parent 3934584 commit 6dd9f8d

27 files changed: +207 -198 lines
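The commit message above describes two reader styles after this change. Below is a minimal sketch of the new default path, assuming a hypothetical source (ExampleReader and ExamplePartition are illustrative names, not part of this patch): the reader implements plain DataSourceReader and returns InputPartition[InternalRow] from planInputPartitions(), with no SupportsScan* mix-in. Provider wiring (DataSourceV2/ReadSupport) is omitted; a sketch of the deprecated Row path follows the renamed interface further below.

import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical reader for illustration only: produces InternalRow with no mix-in.
class ExampleReader extends DataSourceReader {
  override def readSchema(): StructType =
    StructType(Seq(StructField("partition", IntegerType), StructField("value", IntegerType)))

  override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
    // One InputPartition per planned split; the type ascription upcasts to the interface type,
    // the same pattern the Kafka readers below use instead of asInstanceOf.
    (0 until 2).map(i => new ExamplePartition(i): InputPartition[InternalRow]).asJava
  }
}

class ExamplePartition(id: Int) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new InputPartitionReader[InternalRow] {
      private var value = -1
      override def next(): Boolean = { value += 1; value < 3 }
      override def get(): InternalRow = InternalRow(id, value) // rows are InternalRow, not Row
      override def close(): Unit = ()
    }
}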

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
Lines changed: 9 additions & 7 deletions

@@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.reader._
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
     metadataPath: String,
     initialOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends ContinuousReader with SupportsScanUnsafeRow with Logging {
+  extends ContinuousReader with Logging {
 
   private lazy val session = SparkSession.getActiveSession.get
   private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
     KafkaSourceOffset(JsonUtils.partitionOffsets(json))
   }
 
-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     import scala.collection.JavaConverters._
 
     val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -107,8 +108,8 @@ class KafkaContinuousReader(
     startOffsets.toSeq.map {
       case (topicPartition, start) =>
        KafkaContinuousInputPartition(
-          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
-          .asInstanceOf[InputPartition[UnsafeRow]]
+          topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss
+        ): InputPartition[InternalRow]
     }.asJava
   }
 
@@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
+    failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] {
 
-  override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
+  override def createContinuousReader(
+      offset: PartitionOffset): InputPartitionReader[InternalRow] = {
     val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
     require(kafkaOffset.topicPartition == topicPartition,
       s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
@@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader(
     startOffset: Long,
     kafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
+    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
   private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
   private val converter = new KafkaRecordToUnsafeRowConverter
 
external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
Lines changed: 11 additions & 10 deletions

@@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
+  extends MicroBatchReader with Logging {
 
   private var startPartitionOffsets: PartitionOffsetMap = _
   private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +102,7 @@ private[kafka010] class KafkaMicroBatchReader(
     }
   }
 
-  override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     // Find the new partitions, and get their earliest offsets
     val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
     val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -142,11 +143,11 @@ private[kafka010] class KafkaMicroBatchReader(
     val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
 
     // Generate factories based on the offset ranges
-    val factories = offsetRanges.map { range =>
+    offsetRanges.map { range =>
       new KafkaMicroBatchInputPartition(
-        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
-    }
-    factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
+        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
+      ): InputPartition[InternalRow]
+    }.asJava
   }
 
   override def getStartOffset: Offset = {
@@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
+    reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] {
 
   override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
 
-  override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
     new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
       failOnDataLoss, reuseKafkaConsumer)
 }
@@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
+    reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging {
 
   private val consumer = KafkaDataConsumer.acquire(
     offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
Lines changed: 1 addition & 1 deletion

@@ -673,7 +673,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
         Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
         Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
       )
-      val factories = reader.planUnsafeInputPartitions().asScala
+      val factories = reader.planInputPartitions().asScala
        .map(_.asInstanceOf[KafkaMicroBatchInputPartition])
      withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
        assert(factories.size == numPartitionsGenerated)

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
Lines changed: 3 additions & 3 deletions

@@ -20,7 +20,7 @@
 import java.util.List;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.ReadSupport;
 import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -43,7 +43,7 @@
  * Names of these interfaces start with `SupportsScan`. Note that a reader should only
  * implement at most one of the special scans, if more than one special scans are implemented,
  * only one of them would be respected, according to the priority list from high to low:
- * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
+ * {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}.
  *
  * If an exception was throw when applying any of these query optimizations, the action will fail
  * and no Spark job will be submitted.
@@ -76,5 +76,5 @@ public interface DataSourceReader {
  * If this method fails (by throwing an exception), the action will fail and no Spark job will be
  * submitted.
  */
-  List<InputPartition<Row>> planInputPartitions();
+  List<InputPartition<InternalRow>> planInputPartitions();
 }

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
Lines changed: 4 additions & 3 deletions

@@ -26,9 +26,10 @@
  * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for
  * outputting data for a RDD partition.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
- * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition
- * readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
+ * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data
+ * source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row}
+ * for data source readers that mix in {@link SupportsDeprecatedScanRow}.
  */
 @InterfaceStability.Evolving
 public interface InputPartitionReader<T> extends Closeable {

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java renamed to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java
Lines changed: 9 additions & 16 deletions

@@ -17,30 +17,23 @@
 
 package org.apache.spark.sql.sources.v2.reader;
 
-import java.util.List;
-
 import org.apache.spark.annotation.InterfaceStability;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.InternalRow;
+
+import java.util.List;
 
 /**
  * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
- * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side.
- * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get
- * changed in the future Spark versions.
+ * interface to output {@link Row} instead of {@link InternalRow}.
+ * This is an experimental and unstable interface.
  */
 @InterfaceStability.Unstable
-public interface SupportsScanUnsafeRow extends DataSourceReader {
-
-  @Override
-  default List<InputPartition<Row>> planInputPartitions() {
+public interface SupportsDeprecatedScanRow extends DataSourceReader {
+  default List<InputPartition<InternalRow>> planInputPartitions() {
     throw new IllegalStateException(
-        "planInputPartitions not supported by default within SupportsScanUnsafeRow");
+        "planInputPartitions not supported by default within SupportsDeprecatedScanRow");
   }
 
-  /**
-   * Similar to {@link DataSourceReader#planInputPartitions()},
-   * but returns data in unsafe row format.
-   */
-  List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
+  List<InputPartition<Row>> planRowInputPartitions();
 }
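A minimal sketch of the deprecated path defined above, assuming a hypothetical legacy source (LegacyRowReader and LegacyRowPartition are illustrative names, not part of this patch): the reader mixes in SupportsDeprecatedScanRow, keeps producing Row from planRowInputPartitions(), and leaves the inherited planInputPartitions() default (which throws) alone; Spark converts the returned Rows to InternalRow internally.

import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical Row-based reader kept on the deprecated mix-in.
class LegacyRowReader extends SupportsDeprecatedScanRow {
  override def readSchema(): StructType = StructType(Seq(StructField("value", IntegerType)))

  override def planRowInputPartitions(): JList[InputPartition[Row]] =
    Seq(new LegacyRowPartition(Seq(1, 2, 3)): InputPartition[Row]).asJava
}

class LegacyRowPartition(values: Seq[Int]) extends InputPartition[Row] {
  override def createPartitionReader(): InputPartitionReader[Row] =
    new InputPartitionReader[Row] {
      private val iter = values.iterator
      private var current: Row = _
      override def next(): Boolean = {
        if (iter.hasNext) { current = Row(iter.next()); true } else false
      }
      override def get(): Row = current
      override def close(): Unit = ()
    }
}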

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 import java.util.List;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 /**
@@ -30,7 +30,7 @@
 @InterfaceStability.Evolving
 public interface SupportsScanColumnarBatch extends DataSourceReader {
   @Override
-  default List<InputPartition<Row>> planInputPartitions() {
+  default List<InputPartition<InternalRow>> planInputPartitions() {
     throw new IllegalStateException(
         "planInputPartitions not supported by default within SupportsScanColumnarBatch.");
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
Lines changed: 9 additions & 8 deletions

@@ -75,12 +75,13 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }
 
-  private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
-    case _ =>
-      reader.planInputPartitions().asScala.map {
-        new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
+  private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match {
+    case r: SupportsDeprecatedScanRow =>
+      r.planRowInputPartitions().asScala.map {
+        new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow]
       }
+    case _ =>
+      reader.planInputPartitions().asScala
   }
 
   private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
@@ -132,11 +133,11 @@ case class DataSourceV2ScanExec(
 }
 
 class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
-  extends InputPartition[UnsafeRow] {
+  extends InputPartition[InternalRow] {
 
   override def preferredLocations: Array[String] = partition.preferredLocations
 
-  override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
+  override def createPartitionReader: InputPartitionReader[InternalRow] = {
     new RowToUnsafeInputPartitionReader(
       partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
   }
@@ -146,7 +147,7 @@ class RowToUnsafeInputPartitionReader(
     val rowReader: InputPartitionReader[Row],
     encoder: ExpressionEncoder[Row])
 
-  extends InputPartitionReader[UnsafeRow] {
+  extends InputPartitionReader[InternalRow] {
 
   override def next: Boolean = rowReader.next
 
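The wrapper classes above carry Row-producing partitions through the InternalRow-typed scan by running each Row through a RowEncoder resolved and bound against the read schema (the wrapper presumably applies encoder.toRow to each Row pulled from rowReader). A standalone sketch of that conversion, with an illustrative schema and value; RowConversionSketch is a hypothetical name, not part of this patch:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object RowConversionSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("value", IntegerType)))
    // Same setup as the wrapper: an encoder resolved and bound to the read schema.
    val encoder = RowEncoder(schema).resolveAndBind()
    // Convert an external Row into the InternalRow the scan now expects.
    val converted: InternalRow = encoder.toRow(Row(42))
    println(converted.getInt(0)) // 42
  }
}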

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
Lines changed: 5 additions & 8 deletions

@@ -124,16 +124,13 @@ object DataSourceV2Strategy extends Strategy {
       val filterCondition = postScanFilters.reduceLeftOption(And)
       val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
 
-      val withProjection = if (withFilter.output != project) {
-        ProjectExec(project, withFilter)
-      } else {
-        withFilter
-      }
-
-      withProjection :: Nil
+      // always add the projection, which will produce unsafe rows required by some operators
+      ProjectExec(project, withFilter) :: Nil
 
     case r: StreamingDataSourceV2Relation =>
-      DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
+      // ensure there is a projection, which will produce unsafe rows required by some operators
+      ProjectExec(r.output,
+        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil
 
     case WriteToDataSourceV2(writer, query) =>
       WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
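The planner change above keeps a ProjectExec on top of every DataSourceV2 scan so that operators that require UnsafeRow still receive one, even though sources may now emit any InternalRow implementation. A hypothetical standalone illustration of that invariant (names and the one-column schema are illustrative, not from the patch):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.IntegerType

object ProjectionSketch {
  def main(args: Array[String]): Unit = {
    // A projection over column 0 of a single int-column schema.
    val project = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, nullable = false)))
    val fromSource: InternalRow = InternalRow(7) // a generic row, as a source may now emit
    val unsafe: UnsafeRow = project(fromSource)  // what operators above the projection see
    println(unsafe.getInt(0)) // 7
  }
}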

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
Lines changed: 11 additions & 12 deletions

@@ -20,15 +20,15 @@ package org.apache.spark.sql.execution.streaming.continuous
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
 import org.apache.spark.util.NextIterator
 
 class ContinuousDataSourceRDDPartition(
     val index: Int,
-    val inputPartition: InputPartition[UnsafeRow])
+    val inputPartition: InputPartition[InternalRow])
   extends Partition with Serializable {
 
   // This is semantically a lazy val - it's initialized once the first time a call to
@@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
-  extends RDD[UnsafeRow](sc, Nil) {
+    private val readerInputPartitions: Seq[InputPartition[InternalRow]])
+  extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }
@@ -64,7 +64,7 @@ class ContinuousDataSourceRDD(
   * Initialize the shared reader for this partition if needed, then read rows from it until
   * it returns null to signal the end of the epoch.
   */
-  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     // If attempt number isn't 0, this is a task retry, which we don't support.
     if (context.attemptNumber() != 0) {
       throw new ContinuousTaskRetryException()
@@ -74,15 +74,14 @@ class ContinuousDataSourceRDD(
       val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
       if (partition.queueReader == null) {
         partition.queueReader =
-          new ContinuousQueuedDataReader(
-            partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
+          new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
       }
 
       partition.queueReader
     }
 
-    new NextIterator[UnsafeRow] {
-      override def getNext(): UnsafeRow = {
+    new NextIterator[InternalRow] {
+      override def getNext(): InternalRow = {
         readerForPartition.next() match {
           case null =>
             finished = true
@@ -102,9 +101,9 @@ class ContinuousDataSourceRDD(
 
 object ContinuousDataSourceRDD {
   private[continuous] def getContinuousReader(
-      reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
+      reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
     reader match {
-      case r: ContinuousInputPartitionReader[UnsafeRow] => r
+      case r: ContinuousInputPartitionReader[InternalRow] => r
       case wrapped: RowToUnsafeInputPartitionReader =>
        wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
      case _ =>
