@@ -43,8 +43,7 @@ import org.apache.spark.util.ArrayImplicits._
 * A user-defined Python data source. This is used by the Python API.
 * Defines the interaction between Python and JVM.
 *
- * @param dataSourceCls
- *   The Python data source class.
+ * @param dataSourceCls The Python data source class.
 */
case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {

@@ -95,7 +94,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
      pythonResult: PythonDataSourceReader,
      outputSchema: StructType,
      isStreaming: Boolean): PythonDataSourceReadInfo = {
-    new PartitionRunner(
+    new UserDefinedPythonDataSourcePartitionRunner(
      createPythonFunction(pythonResult.reader),
      UserDefinedPythonDataSource.readInputSchema,
      outputSchema,
@@ -327,10 +326,64 @@ private class UserDefinedPythonDataSourceRunner(
  }
}

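+// Holds the pickled Python data source reader instance and whether it serves a streaming read.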
+case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
+
+/**
+ * Instantiate the reader of a Python data source.
+ *
+ * @param func
+ *   a Python data source instance
+ * @param outputSchema
+ *   output schema of the Python data source
+ * @param isStreaming
+ *   whether it is a streaming read
+ */
+private class UserDefinedPythonDataSourceReaderRunner(
+    func: PythonFunction,
+    outputSchema: StructType,
+    isStreaming: Boolean)
+  extends PythonPlannerRunner[PythonDataSourceReader](func) {
+
+  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
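+  // It is expected to unpickle the data source instance, instantiate its reader for the given
+  // output schema (a stream reader when `isStreaming` is set), and send back the pickled reader.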
+  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send Python data source
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
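+    // Send whether this is a streaming read.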
+    dataOut.writeBoolean(isStreaming)
+  }
+
+  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
+    // Receive the pickled reader or an exception raised in the Python worker.
+    val length = dataIn.readInt()
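+    // The worker replies with a length-prefixed payload; a special sentinel length signals
+    // that an exception was raised and an error message follows instead of a pickled reader.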
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
+    }
+
+    // Receive the pickled reader.
+    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
+  }
+}
+
case class PythonFilterPushdownResult(
    reader: PythonDataSourceReader,
    isFilterPushed: collection.Seq[Boolean])

+/**
+ * Push down filters to a Python data source.
+ *
+ * @param reader
+ *   a Python data source reader instance
+ * @param filters
+ *   all filters to be pushed down
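+ * @note Filters that cannot be serialized for the Python worker are simply not pushed down;
+ *   `PythonFilterPushdownResult.isFilterPushed` records which of the given filters were pushed.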
+ */
private class UserDefinedPythonDataSourceFilterPushdownRunner(
    reader: PythonFunction,
    filters: collection.Seq[Filter])
@@ -346,7 +399,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
      case (filter, i) =>
        filter match {
          case filter @ org.apache.spark.sql.sources.EqualTo(_, value: Int) =>
-            val columnPath = filter.v2references.head
+            val columnPath = filter.v2references.head
            Some(SerializedFilter("EqualTo", columnPath, value, i))
          case _ =>
            None
@@ -381,7 +434,7 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
    }

-    // Receive the pickled 'reader'.
+    // Receive the pickled reader.
    val pickledReader: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)

    // Receive the pushed filters as a list of indices.
@@ -399,50 +452,6 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
  }
}

-case class PythonDataSourceReader(reader: Array[Byte], isStreaming: Boolean)
-
-/**
- * Send information to a Python process to plan a Python data source read.
- *
- * @param func
- *   an Python data source instance
- * @param outputSchema
- *   output schema of the Python data source
- */
-private class UserDefinedPythonDataSourceReaderRunner(
-    func: PythonFunction,
-    outputSchema: StructType,
-    isStreaming: Boolean)
-  extends PythonPlannerRunner[PythonDataSourceReader](func) {
-
-  // See the logic in `pyspark.sql.worker.data_source_get_reader.py`.
-  override val workerModule = "pyspark.sql.worker.data_source_get_reader"
-
-  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
-    // Send Python data source
-    PythonWorkerUtils.writePythonFunction(func, dataOut)
-
-    // Send output schema
-    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
-
-    dataOut.writeBoolean(isStreaming)
-  }
-
-  override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReader = {
-    // Receive the picked reader or an exception raised in Python worker.
-    val length = dataIn.readInt()
-    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
-      val msg = PythonWorkerUtils.readUTF(dataIn)
-      throw QueryCompilationErrors.pythonDataSourceError(action = "plan", tpe = "read", msg = msg)
-    }
-
-    // Receive the pickled 'read' function.
-    val pickledFunction: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
-
-    PythonDataSourceReader(reader = pickledFunction, isStreaming = isStreaming)
-  }
-}
-
case class PythonDataSourceReadInfo(
    func: Array[Byte],
    partitions: Seq[Array[Byte]])
@@ -459,7 +468,7 @@ case class PythonDataSourceReadInfo(
 * @param isStreaming
 *   whether it is a streaming read
 */
-private class PartitionRunner(
+private class UserDefinedPythonDataSourcePartitionRunner(
    reader: PythonFunction,
    inputSchema: StructType,
    outputSchema: StructType,