@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
23
23
import net .razorvine .pickle .{Pickler , Unpickler }
24
24
25
25
import org .apache .spark .TaskContext
26
- import org .apache .spark .api .python .{ChainedPythonFunctions , PythonFunction , PythonRunner }
26
+ import org .apache .spark .api .python .{ChainedPythonFunctions , PythonRunner }
27
27
import org .apache .spark .rdd .RDD
28
28
import org .apache .spark .sql .catalyst .InternalRow
29
29
import org .apache .spark .sql .catalyst .expressions ._
@@ -72,8 +72,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
72
72
73
73
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
74
74
75
- // Most of the inputs are primitives, do not use memo for better performance
76
- val pickle = new Pickler (false )
77
75
// flatten all the arguments
78
76
val allInputs = new ArrayBuffer [Expression ]
79
77
val dataTypes = new ArrayBuffer [DataType ]
@@ -89,21 +87,30 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
89
87
}.toArray
90
88
}.toArray
91
89
val projection = newMutableProjection(allInputs, child.output)()
90
+ val schema = StructType (dataTypes.map(dt => StructField (" " , dt)))
91
+ val needConversion = dataTypes.exists(EvaluatePython .needConversionInPython)
92
92
93
+ // enable memo iff we serialize the row with schema (schema and class should be memoized)
94
+ val pickle = new Pickler (needConversion)
93
95
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
94
96
// For each row, add it to the queue.
95
97
val inputIterator = iter.grouped(100 ).map { inputRows =>
96
98
val toBePickled = inputRows.map { inputRow =>
97
99
queue.add(inputRow)
98
100
val row = projection(inputRow)
99
- val fields = new Array [Any ](row.numFields)
100
- var i = 0
101
- while (i < row.numFields) {
102
- val dt = dataTypes(i)
103
- fields(i) = EvaluatePython .toJava(row.get(i, dt), dt)
104
- i += 1
101
+ if (needConversion) {
102
+ EvaluatePython .toJava(row, schema)
103
+ } else {
104
+ // fast path for types that do not need conversion in Python
105
+ val fields = new Array [Any ](row.numFields)
106
+ var i = 0
107
+ while (i < row.numFields) {
108
+ val dt = dataTypes(i)
109
+ fields(i) = EvaluatePython .toJava(row.get(i, dt), dt)
110
+ i += 1
111
+ }
112
+ fields
105
113
}
106
- fields
107
114
}.toArray
108
115
pickle.dumps(toBePickled)
109
116
}
0 commit comments