
Commit 8f986c3

[SPARK-24073][SQL] Rename DataReaderFactory to InputPartition.
Renames:

* `DataReaderFactory` to `InputPartition`
* `DataReader` to `InputPartitionReader`
* `createDataReaderFactories` to `planInputPartitions`
* `createUnsafeDataReaderFactories` to `planUnsafeInputPartitions`
* `createBatchDataReaderFactories` to `planBatchInputPartitions`

This fixes the changes in SPARK-23219, which renamed ReadTask to DataReaderFactory. The intent of that change was to make the read and write APIs match (the write side uses DataWriterFactory), but the underlying problem is that the two classes are not equivalent. ReadTask/DataReader function as Iterable/Iterator: one InputPartition is a specific partition of the data to be read, in contrast to DataWriterFactory, where the same factory instance is used in all write tasks. InputPartition's purpose is to manage the lifecycle of the associated reader, now called InputPartitionReader, with an explicit create operation to mirror the close operation. This was no longer clear from the API, because DataReaderFactory appeared to be more generic than it is and it was not clear why a set of them was produced for a read.

Tested with existing tests, which have been updated to use the new names.

Author: Ryan Blue <blue@apache.org>

Closes apache#21145 from rdblue/SPARK-24073-revert-data-reader-factory-rename.
1 parent 6fbee72 commit 8f986c3
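To make the renamed contract concrete, here is a minimal, hypothetical sketch (not part of this patch) of the Iterable/Iterator relationship described in the commit message: one InputPartition per RDD partition, each creating its own InputPartitionReader on an executor. The class names are invented, and the get() accessor is carried over from the pre-rename DataReader contract rather than shown in this commit.

import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;

// One specific partition of the data to be read; serialized and sent to an executor.
class RangeInputPartition implements InputPartition<Row> {
  private final int start;
  private final int end;

  RangeInputPartition(int start, int end) {
    this.start = start;
    this.end = end;
  }

  // Like Iterable.iterator(): creates the reader whose lifecycle this partition manages.
  @Override
  public InputPartitionReader<Row> createPartitionReader() {
    return new RangeReader(start, end);
  }
}

// The per-partition reader, created and closed on the executor; it need not be serializable.
class RangeReader implements InputPartitionReader<Row> {
  private int current;
  private final int end;

  RangeReader(int start, int end) {
    this.current = start - 1;
    this.end = end;
  }

  @Override
  public boolean next() {   // advance to the next record, false when exhausted
    current += 1;
    return current < end;
  }

  @Override
  public Row get() {        // current record (accessor assumed from the old DataReader API)
    return RowFactory.create(current);
  }

  @Override
  public void close() {     // explicit close, mirroring the explicit create above
    // release per-partition resources here
  }
}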

File tree

13 files changed: +123 −123 lines changed


sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java

Lines changed: 5 additions & 5 deletions
@@ -31,8 +31,8 @@
  * {@link ReadSupport#createReader(DataSourceOptions)} or
  * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
  * It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link DataReaderFactory}s that are returned by
- * {@link #createDataReaderFactories()}.
+ * logic is delegated to {@link InputPartition}s that are returned by
+ * {@link #planInputPartitions()}.
  *
  * There are mainly 3 kinds of query optimizations:
  * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
@@ -62,8 +62,8 @@ public interface DataSourceReader {
   StructType readSchema();
 
   /**
-   * Returns a list of reader factories. Each factory is responsible for creating a data reader to
-   * output data for one RDD partition. That means the number of factories returned here is same as
+   * Returns a list of read tasks. Each task is responsible for creating a data reader to
+   * output data for one RDD partition. That means the number of tasks returned here is same as
    * the number of RDD partitions this scan outputs.
    *
    * Note that, this may not be a full scan if the data source reader mixes in other optimization
@@ -73,5 +73,5 @@ public interface DataSourceReader {
    * If this method fails (by throwing an exception), the action would fail and no Spark job was
    * submitted.
    */
-  List<DataReaderFactory<Row>> createDataReaderFactories();
+  List<InputPartition<Row>> planInputPartitions();
 }
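For illustration only (not from the patch): a hypothetical DataSourceReader whose planInputPartitions() returns two partitions, so the resulting scan RDD has exactly two partitions. It reuses the invented RangeInputPartition sketched after the commit message; the single int column in readSchema() is likewise made up.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class RangeDataSourceReader implements DataSourceReader {

  @Override
  public StructType readSchema() {
    // Schema of the data produced by this scan: one int column.
    return new StructType().add("i", DataTypes.IntegerType);
  }

  @Override
  public List<InputPartition<Row>> planInputPartitions() {
    // Two input partitions -> the scan produces a two-partition RDD.
    return Arrays.asList(
        new RangeInputPartition(0, 5),
        new RangeInputPartition(5, 10));
  }
}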

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java renamed to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java

Lines changed: 8 additions & 8 deletions
@@ -22,20 +22,20 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
+ * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
  * responsible for creating the actual data reader. The relationship between
- * {@link DataReaderFactory} and {@link DataReader}
+ * {@link InputPartition} and {@link InputPartitionReader}
  * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
  *
- * Note that, the reader factory will be serialized and sent to executors, then the data reader
- * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be
- * serializable and {@link DataReader} doesn't need to be.
+ * Note that input partitions will be serialized and sent to executors, then the partition reader
+ * will be created on executors and do the actual reading. So {@link InputPartition} must be
+ * serializable and {@link InputPartitionReader} doesn't need to be.
  */
 @InterfaceStability.Evolving
-public interface DataReaderFactory<T> extends Serializable {
+public interface InputPartition<T> extends Serializable {
 
   /**
-   * The preferred locations where the data reader returned by this reader factory can run faster,
+   * The preferred locations where the data reader returned by this partition can run faster,
    * but Spark does not guarantee to run the data reader on these locations.
    * The implementations should make sure that it can be run on any location.
    * The location is a string representing the host name.
@@ -57,5 +57,5 @@ default String[] preferredLocations() {
   * If this method fails (by throwing an exception), the corresponding Spark task would fail and
   * get retried until hitting the maximum retry times.
   */
-  DataReader<T> createDataReader();
+  InputPartitionReader<T> createPartitionReader();
 }
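As the javadoc above notes, preferredLocations() is only a scheduling hint. A hypothetical partition that knows which host holds its data could override it as sketched here (not from the patch; it extends the invented RangeInputPartition example and assumes both classes sit in the same package).

// Locality hint only: Spark may still run the read task on another host,
// so createPartitionReader() must work anywhere.
class LocalizedRangeInputPartition extends RangeInputPartition {
  private final String host;

  LocalizedRangeInputPartition(int start, int end, String host) {
    super(start, end);
    this.host = host;
  }

  @Override
  public String[] preferredLocations() {
    return new String[] { host };
  }
}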

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java renamed to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java

Lines changed: 2 additions & 2 deletions
@@ -23,15 +23,15 @@
 import org.apache.spark.annotation.InterfaceStability;
 
 /**
- * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
+ * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for
  * outputting data for a RDD partition.
  *
  * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
  * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
  * readers that mix in {@link SupportsScanUnsafeRow}.
  */
 @InterfaceStability.Evolving
-public interface DataReader<T> extends Closeable {
+public interface InputPartitionReader<T> extends Closeable {
 
   /**
    * Proceed to next record, returns false if there is no more records.
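As a usage sketch (not from the patch), this is how such a reader is consumed, mirroring what DataSourceRDD.compute does further down: create the reader from its partition on the executor, iterate with next()/get(), and always close(). The get() accessor and the checked IOException are assumed from the pre-rename DataReader contract; the helper class and method names are invented.

import java.io.IOException;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;

final class ReadPartitionExample {

  // Drain one partition: the explicit create is mirrored by the explicit close.
  static void readAll(InputPartition<Row> partition) throws IOException {
    try (InputPartitionReader<Row> reader = partition.createPartitionReader()) {
      while (reader.next()) {       // proceed to the next record
        Row row = reader.get();     // current record (assumed accessor)
        System.out.println(row);
      }
    }                               // close() runs even if reading fails
  }
}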

sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java

Lines changed: 4 additions & 4 deletions
@@ -33,14 +33,14 @@
 public interface SupportsScanUnsafeRow extends DataSourceReader {
 
   @Override
-  default List<DataReaderFactory<Row>> createDataReaderFactories() {
+  default List<InputPartition<Row>> planInputPartitions() {
     throw new IllegalStateException(
-      "createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
+      "planInputPartitions not supported by default within SupportsScanUnsafeRow");
   }
 
   /**
-   * Similar to {@link DataSourceV2Reader#createDataReaderFactories()},
+   * Similar to {@link DataSourceReader#planInputPartitions()},
    * but returns data in unsafe row format.
    */
-  List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories();
+  List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala

Lines changed: 7 additions & 7 deletions
@@ -17,29 +17,29 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.InputPartition
 
-class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
   extends Partition with Serializable
 
 class DataSourceRDD[T : ClassTag](
     sc: SparkContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
+    @transient private val readerFactories: Seq[InputPartition[T]])
   extends RDD[T](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[T] = {
-    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
+    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
+      .createPartitionReader()
     context.addTaskCompletionListener(_ => reader.close())
     val iter = new Iterator[T] {
       private[this] var valuePrepared = false
@@ -63,6 +63,6 @@ class DataSourceRDD[T : ClassTag](
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
+    split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala

Lines changed: 1 addition & 16 deletions
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import java.util.Objects
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.sources.v2.reader._
 
 /**
@@ -48,17 +46,4 @@ trait DataSourceReaderHolder {
     }
     Seq(output, reader.getClass, reader.readSchema(), filters)
   }
-
-  def canEqual(other: Any): Boolean
-
-  override def equals(other: Any): Boolean = other match {
-    case other: DataSourceReaderHolder =>
-      canEqual(other) && metadata.length == other.metadata.length &&
-        metadata.zip(other.metadata).forall { case (l, r) => l == r }
-    case _ => false
-  }
-
-  override def hashCode(): Int = {
-    metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
-  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala

Lines changed: 34 additions & 19 deletions
@@ -43,22 +43,34 @@ case class DataSourceV2ScanExec(
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
 
-  override def references: AttributeSet = AttributeSet.empty
+  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+  override def equals(other: Any): Boolean = other match {
+    case other: DataSourceV2ScanExec =>
+      output == other.output && reader.getClass == other.reader.getClass && options == other.options
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    Seq(output, source, options).hashCode()
+  }
+
+  private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
+    case _ =>
+      reader.planInputPartitions().asScala.map {
+        new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
+      }
+  }
+
+  private lazy val inputRDD: RDD[InternalRow] = reader match {
+    case _ =>
+      new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]]
+  }
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val readTasks: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
-      case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
-      case _ =>
-        reader.createDataReaderFactories().asScala.map {
-          new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
-        }.asJava
-    }
-
-    val inputRDD = new DataSourceRDD(sparkContext, readTasks)
-      .asInstanceOf[RDD[InternalRow]]
     val numOutputRows = longMetric("numOutputRows")
     inputRDD.map { r =>
      numOutputRows += 1
@@ -67,19 +79,22 @@
   }
 }
 
-class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
-  extends DataReaderFactory[UnsafeRow] {
+class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
+  extends InputPartition[UnsafeRow] {
 
-  override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
+  override def preferredLocations: Array[String] = partition.preferredLocations
 
-  override def createDataReader: DataReader[UnsafeRow] = {
-    new RowToUnsafeDataReader(
-      rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
+  override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
+    new RowToUnsafeInputPartitionReader(
+      partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
   }
 }
 
-class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
-  extends DataReader[UnsafeRow] {
+class RowToUnsafeInputPartitionReader(
+    val rowReader: InputPartitionReader[Row],
+    encoder: ExpressionEncoder[Row])
+
+  extends InputPartitionReader[UnsafeRow] {
 
   override def next: Boolean = rowReader.next
 
sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java

Lines changed: 11 additions & 11 deletions
@@ -79,8 +79,8 @@ public Filter[] pushedFilters() {
     }
 
     @Override
-    public List<DataReaderFactory<Row>> createDataReaderFactories() {
-      List<DataReaderFactory<Row>> res = new ArrayList<>();
+    public List<InputPartition<Row>> planInputPartitions() {
+      List<InputPartition<Row>> res = new ArrayList<>();
 
       Integer lowerBound = null;
       for (Filter filter : filters) {
@@ -94,33 +94,33 @@ public List<DataReaderFactory<Row>> createDataReaderFactories() {
       }
 
       if (lowerBound == null) {
-        res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema));
-        res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+        res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema));
+        res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
       } else if (lowerBound < 4) {
-        res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema));
-        res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+        res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema));
+        res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
       } else if (lowerBound < 9) {
-        res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema));
+        res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema));
       }
 
       return res;
     }
   }
 
-  static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+  static class JavaAdvancedInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
     private int start;
     private int end;
     private StructType requiredSchema;
 
-    JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) {
+    JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) {
       this.start = start;
       this.end = end;
       this.requiredSchema = requiredSchema;
     }
 
     @Override
-    public DataReader<Row> createDataReader() {
-      return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
+    public InputPartitionReader<Row> createPartitionReader() {
+      return new JavaAdvancedInputPartition(start - 1, end, requiredSchema);
     }
 
     @Override
sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 import org.apache.spark.sql.sources.v2.DataSourceV2;
 import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
 import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
 import org.apache.spark.sql.types.StructType;
 
 public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
@@ -42,7 +42,7 @@ public StructType readSchema() {
     }
 
     @Override
-    public List<DataReaderFactory<Row>> createDataReaderFactories() {
+    public List<InputPartition<Row>> planInputPartitions() {
       return java.util.Collections.emptyList();
     }
   }

sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java

Lines changed: 9 additions & 9 deletions
@@ -25,8 +25,8 @@
 import org.apache.spark.sql.sources.v2.DataSourceV2;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.ReadSupport;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
 import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
 import org.apache.spark.sql.types.StructType;
 
@@ -41,25 +41,25 @@ public StructType readSchema() {
     }
 
     @Override
-    public List<DataReaderFactory<Row>> createDataReaderFactories() {
+    public List<InputPartition<Row>> planInputPartitions() {
       return java.util.Arrays.asList(
-        new JavaSimpleDataReaderFactory(0, 5),
-        new JavaSimpleDataReaderFactory(5, 10));
+        new JavaSimpleInputPartition(0, 5),
+        new JavaSimpleInputPartition(5, 10));
     }
   }
 
-  static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+  static class JavaSimpleInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
     private int start;
     private int end;
 
-    JavaSimpleDataReaderFactory(int start, int end) {
+    JavaSimpleInputPartition(int start, int end) {
      this.start = start;
      this.end = end;
    }
 
    @Override
-    public DataReader<Row> createDataReader() {
-      return new JavaSimpleDataReaderFactory(start - 1, end);
+    public InputPartitionReader<Row> createPartitionReader() {
+      return new JavaSimpleInputPartition(start - 1, end);
    }
 
    @Override