
Commit d4a757a

clean up UserDefinedPythonDataSource.scala
1 parent 5064aa3 commit d4a757a

File tree

1 file changed: +59 -50 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala

Lines changed: 59 additions & 50 deletions
@@ -43,8 +43,7 @@ import org.apache.spark.util.ArrayImplicits._
  * A user-defined Python data source. This is used by the Python API.
  * Defines the interation between Python and JVM.
  *
- * @param dataSourceCls
- *   The Python data source class.
+ * @param dataSourceCls The Python data source class.
  */
 case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {

@@ -95,7 +94,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonResult: PythonDataSourceReader,
       outputSchema: StructType,
       isStreaming: Boolean): PythonDataSourceReadInfo = {
-    new PartitionRunner(
+    new UserDefinedPythonDataSourcePartitionRunner(
       createPythonFunction(pythonResult.reader),
       UserDefinedPythonDataSource.readInputSchema,
       outputSchema,
@@ -327,10 +326,64 @@ private class UserDefinedPythonDataSourceRunner(
   }
 }
 
+case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
+
+/**
+ * Instantiate the reader of a Python data source.
+ *
+ * @param func
+ *   a Python data source instance
+ * @param outputSchema
+ *   output schema of the Python data source
+ * @param isStreaming
+ *   whether it is a streaming read
+ */
+private class UserDefinedPythonDataSourceReaderRunner(
+    func: PythonFunction,
+    outputSchema: StructType,
+    isStreaming: Boolean)
+  extends PythonPlannerRunner[PythonDataSourceReader](func) {
+
+  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
+  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send Python data source
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    dataOut.writeBoolean(isStreaming)
+  }
+
+  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
+    // Receive the picked reader or an exception raised in Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
+    }
+
+    // Receive the pickled reader.
+    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
+  }
+}
+
 case class PythonFilterPushdownResult(
     reader: PythonDataSourceReader,
     isFilterPushed: collection.Seq[Boolean])
 
+/**
+ * Push down filters to a Python data source.
+ *
+ * @param reader
+ *   a Python data source reader instance
+ * @param filters
+ *   all filters to be pushed down
+ */
 private class UserDefinedPythonDataSourceFilterPushdownRunner(
     reader: PythonFunction,
     filters: collection.Seq[Filter])
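
For orientation, the moved UserDefinedPythonDataSourceReaderRunner follows the same wire protocol as the other planner runners in this file: the JVM side reads one int and either treats it as the byte length of the pickled reader or, if it equals SpecialLengths.PYTHON_EXCEPTION_THROWN, reads an error message instead. Below is a minimal, self-contained sketch of that framing; ExceptionMarker, writeResult, and readResult are illustrative stand-ins, not Spark's actual SpecialLengths or PythonWorkerUtils helpers, and the error-message encoding is simplified.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object FramingSketch {
  // Stand-in for SpecialLengths.PYTHON_EXCEPTION_THROWN; the real constant lives in
  // org.apache.spark.api.python and its exact value is not shown in this diff.
  val ExceptionMarker: Int = -1

  // Writer side: either a length-prefixed payload (the pickled reader) or
  // the exception marker followed by a length-prefixed UTF-8 error message.
  def writeResult(out: DataOutputStream, result: Either[String, Array[Byte]]): Unit = result match {
    case Right(payload) =>
      out.writeInt(payload.length)
      out.write(payload)
    case Left(error) =>
      out.writeInt(ExceptionMarker)
      val bytes = error.getBytes(StandardCharsets.UTF_8)
      out.writeInt(bytes.length)
      out.write(bytes)
  }

  // Reader side: mirrors receiveFromPython, the first int decides payload vs. error.
  def readResult(in: DataInputStream): Either[String, Array[Byte]] = {
    val length = in.readInt()
    if (length == ExceptionMarker) {
      val msgLen = in.readInt()
      val msgBytes = new Array[Byte](msgLen)
      in.readFully(msgBytes)
      Left(new String(msgBytes, StandardCharsets.UTF_8))
    } else {
      val payload = new Array[Byte](length)
      in.readFully(payload)
      Right(payload)
    }
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    writeResult(new DataOutputStream(buffer), Right("pickled-reader".getBytes(StandardCharsets.UTF_8)))
    val roundTripped = readResult(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)))
    println(roundTripped.map(bytes => new String(bytes, StandardCharsets.UTF_8)))
  }
}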
@@ -346,7 +399,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
       case (filter, i) =>
         filter match {
           case filter @ org.apache.spark.sql.sources.EqualTo(_, value: Int) =>
-            val columnPath = filter.v2references.head
+            val columnPath = filter.v2references.head
             Some(SerializedFilter("EqualTo", columnPath, value, i))
           case _ =>
             None
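
The hunk above is the pushdown path that turns supported source filters into a pickle-friendly form before shipping them to the Python worker. A rough, self-contained sketch of that translation follows, assuming spark-sql is on the classpath; SerializedFilterSketch and the serialize helper are hypothetical stand-ins for the SerializedFilter used by the real runner, and the sketch uses the public references field where the actual code uses v2references for nested column paths.

import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull}

// Illustrative stand-in for the serialized filter shape seen in the diff
// (filter name, referenced column, value, and the filter's index in the original list).
case class SerializedFilterSketch(name: String, column: String, value: Any, index: Int)

object FilterPushdownSketch {
  // Mirrors the selection logic above: only integer EqualTo filters are serialized,
  // everything else maps to None and stays on the Spark side.
  def serialize(filters: Seq[Filter]): Seq[Option[SerializedFilterSketch]] =
    filters.zipWithIndex.map {
      case (f @ EqualTo(_, value: Int), i) =>
        Some(SerializedFilterSketch("EqualTo", f.references.head, value, i))
      case _ =>
        None
    }

  def main(args: Array[String]): Unit = {
    val filters = Seq(EqualTo("id", 42), IsNotNull("name"), EqualTo("name", "a"))
    // Only EqualTo("id", 42) qualifies; the other two map to None.
    println(serialize(filters))
  }
}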
@@ -381,7 +434,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
       throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
     }
 
-    // Receive the pickled 'reader'.
+    // Receive the pickled reader.
     val pickledReader: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
 
     // Receive the pushed filters as a list of indices.
@@ -399,50 +452,6 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
   }
 }
 
-case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
-
-/**
- * Send information to a Python process to plan a Python data source read.
- *
- * @param func
- *   an Python data source instance
- * @param outputSchema
- *   output schema of the Python data source
- */
-private class UserDefinedPythonDataSourceReaderRunner(
-    func: PythonFunction,
-    outputSchema: StructType,
-    isStreaming: Boolean)
-  extends PythonPlannerRunner[PythonDataSourceReader](func) {
-
-  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
-  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
-
-  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
-    // Send Python data source
-    PythonWorkerUtils.writePythonFunction(func, dataOut)
-
-    // Send output schema
-    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
-
-    dataOut.writeBoolean(isStreaming)
-  }
-
-  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
-    // Receive the picked reader or an exception raised in Python worker.
-    val length = dataIn.readInt()
-    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
-      val msg = PythonWorkerUtils.readUTF(dataIn)
-      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
-    }
-
-    // Receive the pickled 'read' function.
-    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
-
-    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
-  }
-}
-
 case class PythonDataSourceReadInfo(
     func: Array[Byte],
     partitions: Seq[Array[Byte]])
@@ -459,7 +468,7 @@ case class PythonDataSourceReadInfo(
  * @param isStreaming
  *   whether it is a streaming read
  */
-private class PartitionRunner(
+private class UserDefinedPythonDataSourcePartitionRunner(
     reader: PythonFunction,
     inputSchema: StructType,
     outputSchema: StructType,
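
Taken together, the runners in this file hand pickled state from one planning step to the next: the reader runner produces a PythonDataSourceReader, the filter pushdown runner consumes one and returns a new one alongside the pushed-filter flags, and the renamed UserDefinedPythonDataSourcePartitionRunner turns it into a PythonDataSourceReadInfo carrying the pickled read function and partitions. The sketch below is only a schematic of that hand-off, inferred from the types in this diff; the *Sketch case classes and the three placeholder functions are not the actual runners, and each real step involves a round trip to a Python worker.

object PlanFlowSketch {
  // Simplified stand-ins for PythonDataSourceReader, PythonFilterPushdownResult,
  // and PythonDataSourceReadInfo from the diff.
  case class ReaderSketch(reader: Array[Byte], isStreaming: Boolean)
  case class FilterPushdownSketch(reader: ReaderSketch, isFilterPushed: Seq[Boolean])
  case class ReadInfoSketch(func: Array[Byte], partitions: Seq[Array[Byte]])

  // Placeholder for UserDefinedPythonDataSourceReaderRunner: pickled data source in, reader out.
  def getReader(dataSource: Array[Byte], isStreaming: Boolean): ReaderSketch =
    ReaderSketch(dataSource, isStreaming)

  // Placeholder for UserDefinedPythonDataSourceFilterPushdownRunner: reader in,
  // updated reader plus per-filter pushed flags out.
  def pushFilters(reader: ReaderSketch, filterCount: Int): FilterPushdownSketch =
    FilterPushdownSketch(reader, Seq.fill(filterCount)(false))

  // Placeholder for UserDefinedPythonDataSourcePartitionRunner: reader in,
  // pickled read function and partitions out.
  def planPartitions(reader: ReaderSketch): ReadInfoSketch =
    ReadInfoSketch(reader.reader, Seq(reader.reader))

  def main(args: Array[String]): Unit = {
    val pickledDataSource = Array[Byte](1, 2, 3)
    val reader = getReader(pickledDataSource, isStreaming = false)
    val pushed = pushFilters(reader, filterCount = 2)
    val readInfo = planPartitions(pushed.reader)
    println(s"partitions planned: ${readInfo.partitions.length}")
  }
}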

0 commit comments