[SPARK-24971][SQL] remove SupportsDeprecatedScanRow #21921

Closed

wants to merge 1 commit
@@ -39,11 +39,7 @@
* pruning), etc. Names of these interfaces start with `SupportsPushDown`.
* 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
* Names of these interfaces start with `SupportsReporting`.
* 3. Special scans. E.g, columnar scan, unsafe row scan, etc.
* Names of these interfaces start with `SupportsScan`. Note that a reader should only
* implement at most one of the special scans, if more than one special scans are implemented,
* only one of them would be respected, according to the priority list from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}.
* 3. Columnar scan if implements {@link SupportsScanColumnarBatch}.
*
* If an exception was throw when applying any of these query optimizations, the action will fail
* and no Spark job will be submitted.
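The updated contract leaves `InternalRow` as the only row-based scan type. As a rough, minimal sketch of what a v2 reader looks like under that contract — the `RangeReader`/`RangeInputPartition` names are made up for illustration, and the wiring into a `DataSourceV2`/`ReadSupport` provider is omitted:

```scala
import java.util.{List => JList}
import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{LongType, StructType}

// Hypothetical reader: two partitions, each producing a small range of longs.
class RangeReader extends DataSourceReader {
  override def readSchema(): StructType = new StructType().add("value", LongType)

  // Declaring the element type up front avoids any cast when converting to a Java list.
  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Seq[InputPartition[InternalRow]](
      new RangeInputPartition(0L, 5L),
      new RangeInputPartition(5L, 10L)
    ).asJava
}

// A partition is a serializable description of a chunk of work; its reader runs on an executor.
class RangeInputPartition(start: Long, end: Long) extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new RangeInputPartitionReader(start, end)
}

class RangeInputPartitionReader(start: Long, end: Long)
    extends InputPartitionReader[InternalRow] {
  private var current = start - 1

  override def next(): Boolean = { current += 1; current < end }

  // Values must use Catalyst's internal representations, not external Java/Scala types.
  override def get(): InternalRow = InternalRow(current)

  override def close(): Unit = {}
}
```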
@@ -28,8 +28,7 @@
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}
* for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data
* source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row}
* for data source readers that mix in {@link SupportsDeprecatedScanRow}.
* source readers that mix in {@link SupportsScanColumnarBatch}.
*/
@InterfaceStability.Evolving
public interface InputPartitionReader<T> extends Closeable {
@@ -20,9 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
@@ -31,7 +29,6 @@ import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
@@ -75,13 +72,8 @@ case class DataSourceV2ScanExec(
case _ => super.outputPartitioning
}

private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match {
case r: SupportsDeprecatedScanRow =>
r.planRowInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow]
}
case _ =>
reader.planInputPartitions().asScala
private lazy val partitions: Seq[InputPartition[InternalRow]] = {
reader.planInputPartitions().asScala
}

private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
@@ -131,27 +123,3 @@
}
}
}

class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
extends InputPartition[InternalRow] {

override def preferredLocations: Array[String] = partition.preferredLocations

override def createPartitionReader: InputPartitionReader[InternalRow] = {
new RowToUnsafeInputPartitionReader(
partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
}

class RowToUnsafeInputPartitionReader(
val rowReader: InputPartitionReader[Row],
encoder: ExpressionEncoder[Row])

extends InputPartitionReader[InternalRow] {

override def next: Boolean = rowReader.next

override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow]

override def close(): Unit = rowReader.close()
}
@@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
import org.apache.spark.util.NextIterator
@@ -104,8 +102,6 @@ object ContinuousDataSourceRDD {
reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
reader match {
case r: ContinuousInputPartitionReader[InternalRow] => r
case wrapped: RowToUnsafeInputPartitionReader =>
wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
case _ =>
throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
}
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.json4s.DefaultFormats
import org.json4s.jackson.Serialization

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.StructType
case class RateStreamPartitionOffset(
partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset

class RateStreamContinuousReader(options: DataSourceOptions)
extends ContinuousReader with SupportsDeprecatedScanRow {
class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader {
implicit val defaultFormats: DefaultFormats = DefaultFormats

val creationTime = System.currentTimeMillis()
@@ -67,7 +66,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)

override def getStartOffset(): Offset = offset

override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = {
override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = {
val partitionStartMap = offset match {
case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
case off =>
@@ -91,7 +90,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
i,
numPartitions,
perPartitionRate)
.asInstanceOf[InputPartition[Row]]
.asInstanceOf[InputPartition[InternalRow]]
Review comment (Contributor): Why is this cast necessary?

Reply (Contributor, author): I didn't dig into it, as the cast was already there. The reason seems to be that java.util.List isn't covariant.

Reply (Contributor): I don't think it's a good idea to leave casts. Can you check whether this can be avoided? I found in #21118 that many of the casts were unnecessary once variables had declared types, and it is much better to avoid explicit casts that work around the type system.

}.asJava
}
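On the covariance question in the thread above, a quick sketch of why the un-cast version is rejected — the class name below is made up, not from Spark. `java.util.List` is invariant in its element type as seen from Scala, so the element type has to be pinned to `InputPartition[InternalRow]` before `.asJava`, either by declaring it on the collection or by ascribing each element; `.asInstanceOf` merely silences the compiler.

```scala
import java.util.{List => JList}
import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}

// Hypothetical concrete partition type, standing in for RateStreamContinuousInputPartition.
class MyPartition extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] = ???
}

// Does not compile: a java.util.List[MyPartition] is not a java.util.List[InputPartition[InternalRow]].
// val bad: JList[InputPartition[InternalRow]] = Seq(new MyPartition).asJava

// Compiles with no cast: the Scala collection's element type is declared up front...
val declared: JList[InputPartition[InternalRow]] =
  Seq[InputPartition[InternalRow]](new MyPartition).asJava

// ...or each element is upcast with a type ascription, which the compiler still checks.
val ascribed: JList[InputPartition[InternalRow]] =
  Seq(new MyPartition: InputPartition[InternalRow]).asJava
```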

@@ -119,9 +118,10 @@ case class RateStreamContinuousInputPartition(
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
extends ContinuousInputPartition[Row] {
extends ContinuousInputPartition[InternalRow] {

override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = {
override def createContinuousReader(
offset: PartitionOffset): InputPartitionReader[InternalRow] = {
val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
require(rateStreamOffset.partition == partitionIndex,
s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
@@ -133,7 +133,7 @@ case class RateStreamContinuousInputPartition(
rowsPerSecond)
}

override def createPartitionReader(): InputPartitionReader[Row] =
override def createPartitionReader(): InputPartitionReader[InternalRow] =
new RateStreamContinuousInputPartitionReader(
startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
}
@@ -144,12 +144,12 @@ class RateStreamContinuousInputPartitionReader(
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
extends ContinuousInputPartitionReader[Row] {
extends ContinuousInputPartitionReader[InternalRow] {
private var nextReadTime: Long = startTimeMs
private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong

private var currentValue = startValue
private var currentRow: Row = null
private var currentRow: InternalRow = null

override def next(): Boolean = {
currentValue += increment
@@ -165,14 +165,14 @@
return false
}

currentRow = Row(
DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)),
currentRow = InternalRow(
DateTimeUtils.fromMillis(nextReadTime),
currentValue)

true
}

override def get: Row = currentRow
override def get: InternalRow = currentRow

override def close(): Unit = {}
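The `get` rewrite above also switches representations, not just types: `Row` carries external JVM values (java.sql.Timestamp, String), while `InternalRow` carries Catalyst's internal ones (microseconds since the epoch as Long, UTF8String), which is why the `DateTimeUtils.toJavaTimestamp` wrapper disappears. A small illustrative comparison, not part of the PR:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

val millis = System.currentTimeMillis()

// External row: java.sql.Timestamp and java.lang.String.
val external: Row =
  Row(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(millis)), "spark")

// Internal row: Long microseconds and UTF8String -- what planInputPartitions now has to produce.
val internal: InternalRow =
  InternalRow(DateTimeUtils.fromMillis(millis), UTF8String.fromString("spark"))
```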

@@ -23,19 +23,19 @@ import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.SortedMap
import scala.collection.mutable.ListBuffer

import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.{Encoder, Row, SQLContext}
import org.apache.spark.sql.{Encoder, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow}
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.RpcUtils
@@ -49,8 +49,7 @@ import org.apache.spark.util.RpcUtils
* the specified offset within the list, or null if that offset doesn't yet have a record.
*/
class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport
with SupportsDeprecatedScanRow {
extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
private implicit val formats = Serialization.formats(NoTypeHints)

protected val logicalPlan =
@@ -100,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
)
}

override def planRowInputPartitions(): ju.List[InputPartition[Row]] = {
override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
synchronized {
val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
endpointRef =
@@ -109,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
startOffset.partitionNums.map {
case (part, index) =>
new ContinuousMemoryStreamInputPartition(
endpointName, part, index): InputPartition[Row]
endpointName, part, index): InputPartition[InternalRow]
}.toList.asJava
}
}
@@ -141,7 +140,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
val buf = records(part)
val record = if (buf.size <= index) None else Some(buf(index))

context.reply(record.map(Row(_)))
context.reply(record.map(r => encoder.toRow(r).copy()))
}
}
}
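The switch from `Row(_)` to `encoder.toRow(r).copy()` above needs the defensive `copy()` because an `ExpressionEncoder`'s `toRow` generally reuses a single mutable row buffer, so any result retained past the next conversion (here, one shipped over RPC) should be copied first. A hedged sketch of the pitfall, not from the PR:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val enc = ExpressionEncoder[String]()

val kept    = enc.toRow("first").copy() // independent copy, safe to retain
val aliased = enc.toRow("second")       // may point at the encoder's reusable buffer
enc.toRow("third")                      // can overwrite the buffer `aliased` points at

// `kept` still holds "first"; `aliased` is not guaranteed to still hold "second".
```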
@@ -164,7 +163,7 @@ object ContinuousMemoryStream {
class ContinuousMemoryStreamInputPartition(
driverEndpointName: String,
partition: Int,
startOffset: Int) extends InputPartition[Row] {
startOffset: Int) extends InputPartition[InternalRow] {
override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader =
new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset)
}
@@ -177,14 +176,14 @@ class ContinuousMemoryStreamInputPartition(
class ContinuousMemoryStreamInputPartitionReader(
driverEndpointName: String,
partition: Int,
startOffset: Int) extends ContinuousInputPartitionReader[Row] {
startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] {
private val endpoint = RpcUtils.makeDriverRef(
driverEndpointName,
SparkEnv.get.conf,
SparkEnv.get.rpcEnv)

private var currentOffset = startOffset
private var current: Option[Row] = None
private var current: Option[InternalRow] = None

// Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
// we have to do a bit of error prone work to get it into every thread used by continuous
@@ -204,15 +203,15 @@ class ContinuousMemoryStreamInputPartitionReader(
true
}

override def get(): Row = current.get
override def get(): InternalRow = current.get

override def close(): Unit = {}

override def getOffset: ContinuousMemoryStreamPartitionOffset =
ContinuousMemoryStreamPartitionOffset(partition, currentOffset)

private def getRecord: Option[Row] =
endpoint.askSync[Option[Row]](
private def getRecord: Option[InternalRow] =
endpoint.askSync[Option[InternalRow]](
GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset)))
}

@@ -29,6 +29,7 @@ import org.apache.commons.io.IOUtils
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources.v2.DataSourceOptions
@@ -38,7 +39,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{ManualClock, SystemClock}

class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
extends MicroBatchReader with SupportsDeprecatedScanRow with Logging {
extends MicroBatchReader with Logging {
import RateStreamProvider._

private[sources] val clock = {
@@ -134,7 +135,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
LongOffset(json.toLong)
}

override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = {
override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = {
val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
@@ -169,7 +170,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
(0 until numPartitions).map { p =>
new RateStreamMicroBatchInputPartition(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
: InputPartition[Row]
: InputPartition[InternalRow]
Review comment (Contributor): Is this needed? Doesn't RateStreamMicroBatchInputPartition implement InputPartition[InternalRow]?

Reply (Contributor, author): Ditto (same answer as for the cast above) — it was already there.

Reply (Contributor): This is fine since it isn't a cast, but it's generally better to check whether these are still necessary after refactoring.

}.toList.asJava
}
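To the point in the thread above: the trailing `: InputPartition[InternalRow]` is a type ascription rather than a cast — the compiler verifies the upcast, whereas `asInstanceOf` is unchecked until runtime. A tiny illustration with a stand-in class, not Spark code:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Stand-in for a concrete partition class such as RateStreamMicroBatchInputPartition.
class RowPartition extends InputPartition[InternalRow] {
  override def createPartitionReader(): InputPartitionReader[InternalRow] = ???
}

val p = new RowPartition

// Ascription (as in the diff): a compiler-checked upcast -- rejected unless the subtyping holds.
val fine: Seq[InputPartition[InternalRow]] = Seq(p: InputPartition[InternalRow])
// val nope = p: InputPartition[ColumnarBatch]   // would not compile

// asInstanceOf: always compiles; a wrong cast surfaces only at runtime, if at all (erasure).
val risky = p.asInstanceOf[InputPartition[ColumnarBatch]]
```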

@@ -188,9 +189,9 @@ class RateStreamMicroBatchInputPartition(
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
relativeMsPerValue: Double) extends InputPartition[Row] {
relativeMsPerValue: Double) extends InputPartition[InternalRow] {

override def createPartitionReader(): InputPartitionReader[Row] =
override def createPartitionReader(): InputPartitionReader[InternalRow] =
new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd,
localStartTimeMs, relativeMsPerValue)
}
@@ -201,22 +202,18 @@ class RateStreamMicroBatchInputPartitionReader(
rangeStart: Long,
rangeEnd: Long,
localStartTimeMs: Long,
relativeMsPerValue: Double) extends InputPartitionReader[Row] {
relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] {
private var count: Long = 0

override def next(): Boolean = {
rangeStart + partitionId + numPartitions * count < rangeEnd
}

override def get(): Row = {
override def get(): InternalRow = {
val currValue = rangeStart + partitionId + numPartitions * count
count += 1
val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
Row(
DateTimeUtils.toJavaTimestamp(
DateTimeUtils.fromMillis(relative + localStartTimeMs)),
currValue
)
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue)
}

override def close(): Unit = {}