Commit 8dc1adf

Author: Davies Liu

improve performance, address comments

1 parent 8e6e5bc

File tree: 3 files changed, 86 additions and 50 deletions


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 30 additions & 15 deletions
@@ -77,30 +77,42 @@ private[spark] case class PythonFunction(
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])

+/**
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

-object PythonRunner {
+private[spark] object PythonRunner {
   def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Seq(Seq(0)))
   }
 }

 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[Seq[PythonFunction]],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
     isUDF: Boolean,
-    numArgs: Seq[Int])
+    argOffsets: Seq[Seq[Int]])
   extends Logging {

+  require(funcs.length == argOffsets.length, "numArgs should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.head.envVars
-  private val pythonExec = funcs.head.head.pythonExec
-  private val pythonVer = funcs.head.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer

-  private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator

   def compute(
       inputIterator: Iterator[_],
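The "chained" ordering in the ChainedPythonFunctions wrapper above means the first function in funcs is applied to the raw arguments and each later function consumes the previous result (bottom to top). A minimal Python sketch of that composition follows; it assumes a chain helper shaped like the one in pyspark/worker.py, whose exact definition is not shown in this diff.

    # Minimal sketch of "chained Python functions (from bottom to top)".
    # chain() is assumed to mirror pyspark.worker.chain; not the literal Spark code.
    def chain(f, g):
        # apply f first, then g on f's result
        return lambda *a: g(f(*a))

    def row_func_from(funcs):
        # funcs[0] is the bottom of the chain, funcs[-1] the top
        row_func = None
        for f in funcs:
            row_func = f if row_func is None else chain(row_func, f)
        return row_func

    if __name__ == "__main__":
        inner = lambda x: x + 1      # bottom: applied to the raw argument
        outer = lambda y: y * 10     # top: applied to inner's result
        composed = row_func_from([inner, outer])
        print(composed(2))           # 30, i.e. outer(inner(2))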
@@ -240,8 +252,8 @@ private[spark] class PythonRunner(

   @volatile private var _exception: Exception = null

-  private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
-  private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))
+  private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+  private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

   setDaemon(true)

@@ -295,17 +307,20 @@ private[spark] class PythonRunner(
       if (isUDF) {
         dataOut.writeInt(1)
         dataOut.writeInt(funcs.length)
-        funcs.zip(numArgs).foreach { case (fs, numArg) =>
-          dataOut.writeInt(numArg)
-          dataOut.writeInt(fs.length)
-          fs.foreach { f =>
+        funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+          dataOut.writeInt(offsets.length)
+          offsets.foreach { offset =>
+            dataOut.writeInt(offset)
+          }
+          dataOut.writeInt(chained.funcs.length)
+          chained.funcs.foreach { f =>
             dataOut.writeInt(f.command.length)
             dataOut.write(f.command)
           }
         }
       } else {
         dataOut.writeInt(0)
-        val command = funcs.head.head.command
+        val command = funcs.head.funcs.head.command
         dataOut.writeInt(command.length)
         dataOut.write(command)
       }
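With argOffsets, the stream written to the Python worker in the UDF case is: the int 1 (UDF mode), the number of UDFs, and then for each UDF the number of argument offsets, the offsets themselves, the number of chained functions, and each function's pickled command prefixed by its length. The sketch below round-trips that framing with plain big-endian 4-byte ints (the format java.io.DataOutputStream.writeInt produces); it is an illustration of the layout, not PySpark's actual serializer code.

    # Sketch of the framing PythonRunner writes for the UDF case, plus a matching
    # read side. Big-endian 4-byte ints, as written by java.io.DataOutputStream.
    import io
    import struct

    def write_int(out, v):
        out.write(struct.pack(">i", v))

    def read_int(inp):
        return struct.unpack(">i", inp.read(4))[0]

    def write_udfs(out, udfs):
        # udfs: list of (arg_offsets, [pickled_command, ...]) per UDF
        write_int(out, 1)                     # isUDF flag
        write_int(out, len(udfs))             # number of independent UDFs
        for arg_offsets, commands in udfs:
            write_int(out, len(arg_offsets))  # how many argument offsets
            for offset in arg_offsets:
                write_int(out, offset)
            write_int(out, len(commands))     # how many chained functions
            for command in commands:
                write_int(out, len(command))
                out.write(command)

    def read_udfs(inp):
        assert read_int(inp) == 1
        result = []
        for _ in range(read_int(inp)):
            arg_offsets = [read_int(inp) for _ in range(read_int(inp))]
            commands = [inp.read(read_int(inp)) for _ in range(read_int(inp))]
            result.append((arg_offsets, commands))
        return result

    buf = io.BytesIO()
    write_udfs(buf, [([0], [b"f0"]), ([1, 2], [b"g0", b"g1"])])
    buf.seek(0)
    print(read_udfs(buf))   # [([0], [b'f0']), ([1, 2], [b'g0', b'g1'])]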

python/pyspark/worker.py

Lines changed: 19 additions & 17 deletions
@@ -63,11 +63,13 @@ def chain(f, g):


 def wrap_udf(f, return_type):
-    return lambda *a: return_type.toInternal(f(*a))
+    toInternal = return_type.toInternal
+    return lambda *a: toInternal(f(*a))


 def read_single_udf(pickleSer, infile):
     num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
     for i in range(read_int(infile)):
         f, return_type = read_command(pickleSer, infile)
@@ -76,27 +78,27 @@ def read_single_udf(pickleSer, infile):
         else:
             row_func = chain(row_func, f)
     # the last returnType will be the return type of UDF
-    return num_arg, wrap_udf(row_func, return_type)
+    return arg_offsets, wrap_udf(row_func, return_type)


 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    udfs = []
-    offset = 0
-    for i in range(num_udfs):
-        num_arg, udf = read_single_udf(pickleSer, infile)
-        udfs.append((offset, offset + num_arg, udf))
-        offset += num_arg
-
     if num_udfs == 1:
-        udf = udfs[0][2]
-
         # fast path for single UDF
-        def mapper(args):
-            return udf(*args)
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
     else:
-        def mapper(args):
-            return tuple(udf(*args[start:end]) for start, end, udf in udfs)
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)

     func = lambda _, it: map(mapper, it)
     ser = AutoBatchedSerializer(PickleSerializer())
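The new read_udfs builds the multi-UDF mapper as a source string and compiles it with eval, so the per-row dispatch is a single generated lambda rather than a Python-level loop over (start, end, udf) slices. Below is a standalone illustration of the same trick; the functions f0/f1 and their offsets are made-up stand-ins for the unpickled UDFs, not values Spark would send.

    # Standalone illustration of the eval-generated mapper in read_udfs.
    # f0/f1 and the offsets are hypothetical examples.
    udf_by_name = {
        "f0": lambda x: x + 1,        # takes the column at offset 0
        "f1": lambda x, y: x * y,     # takes the columns at offsets 1 and 2
    }
    arg_offsets = {"f0": [0], "f1": [1, 2]}

    call_udf = []
    for name in ("f0", "f1"):
        args = ", ".join("a[%d]" % o for o in arg_offsets[name])
        call_udf.append("%s(%s)" % (name, args))

    # Generates: lambda a: (f0(a[0]), f1(a[1], a[2]))
    mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
    mapper = eval(mapper_str, udf_by_name)

    print(mapper_str)
    print(mapper([3, 4, 5]))   # (4, 20)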
@@ -149,8 +151,8 @@ def main(infile, outfile):
            _broadcastRegistry.pop(bid)

    _accumulatorRegistry.clear()
-    is_udf = read_int(infile)
-    if is_udf:
+    is_sql_udf = read_int(infile)
+    if is_sql_udf:
         func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
     else:
         func, profiler, deserializer, serializer = read_command(pickleSer, infile)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala

Lines changed: 37 additions & 18 deletions
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python

 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer

 import net.razorvine.pickle.{Pickler, Unpickler}

 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}


 /**
@@ -45,15 +46,15 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c

   def children: Seq[SparkPlan] = child :: Nil

-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }


@@ -69,30 +70,48 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

-      val (pyFuncs, children) = udfs.map(collectFunctions).unzip
-      val numArgs = children.map(_.length)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

-      val pickle = new Pickler
+      // Most of the inputs are primitives, do not use memo for better performance
+      val pickle = new Pickler(false)
       // flatten all the arguments
-      val allChildren = children.flatMap(x => x)
-      val currentRow = newMutableProjection(allChildren, child.output)()
-      val fields = allChildren.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }
+      }
+      val projection = newMutableProjection(allInputs, child.output)()

       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
-      val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+      val inputIterator = iter.grouped(1024).map { inputRows =>
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          val fields = new Array[Any](row.numFields)
+          var i = 0
+          while (i < row.numFields) {
+            val dt = dataTypes(i)
+            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+            i += 1
+          }
+          fields
         }.toArray
         pickle.dumps(toBePickled)
       }

       val context = TaskContext.get()

       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
         .compute(inputIterator, context.partitionId(), context)

       val unpickle = new Unpickler
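On the Scala side, argOffsets is computed by flattening every UDF's inputs into one list of distinct expressions (semanticEquals is used for duplicate detection), so an expression shared by several UDFs is projected and pickled only once; each UDF then receives the indices of its arguments within that flattened row. Below is a rough Python model of the same bookkeeping, using plain string equality where Spark uses semanticEquals.

    # Rough model of the argOffsets computation in BatchPythonEvaluation:
    # deduplicate the inputs of all UDFs and record each UDF's argument indices.
    def compute_arg_offsets(inputs_per_udf):
        all_inputs = []    # flattened, deduplicated inputs (what gets projected and sent)
        arg_offsets = []   # per UDF, the offsets of its arguments in all_inputs
        for inputs in inputs_per_udf:
            offsets = []
            for expr in inputs:
                if expr in all_inputs:            # stands in for semanticEquals
                    offsets.append(all_inputs.index(expr))
                else:
                    all_inputs.append(expr)
                    offsets.append(len(all_inputs) - 1)
            arg_offsets.append(offsets)
        return all_inputs, arg_offsets

    # Two UDFs sharing the column "a": it is sent to Python only once.
    cols, offsets = compute_arg_offsets([["a"], ["a", "b"]])
    print(cols)      # ['a', 'b']
    print(offsets)   # [[0], [0, 1]]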
