
Commit f0afafd

Davies Liu (davies) authored and committed
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?

This PR supports executing multiple Python UDFs within a single batch and also improves performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```

## How was this patch tested?

Added new tests.

Used the following script to benchmark 1, 2 and 3 UDFs:

```python
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```

Here are the results:

N | Before | After | Speedup
---- | ------ | ----- | -------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes (one per UDF per partition) before this patch and 4 processes (one per partition) after it. After this patch, evaluating multiple UDFs also uses less memory than before (less buffering).

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
1 parent 8de201b commit f0afafd

File tree

8 files changed: +233 −101 lines


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 49 additions & 15 deletions
```diff
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(
 
   @volatile private var _exception: Exception = null
 
-  private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-  private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+  private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+  private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
   setDaemon(true)
 
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
```
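For reference, the `isUDF` branch above defines a simple length-prefixed framing for the UDF commands sent to the Python worker. The standalone sketch below (a hypothetical helper, not part of Spark) mirrors that layout in Python, assuming the big-endian 4-byte integers that Java's `DataOutputStream.writeInt` produces.

```python
import struct

def frame_udf_command(udfs):
    """Illustrative only: mimic the framing the new PythonRunner writes for UDFs.

    `udfs` is a list of (arg_offsets, [command_bytes, ...]) pairs, one per
    independent UDF; each command_bytes is one pickled function in the chain.
    DataOutputStream.writeInt emits big-endian 4-byte ints, hence '>i'.
    """
    out = bytearray()
    out += struct.pack(">i", 1)            # isUDF flag
    out += struct.pack(">i", len(udfs))    # number of independent UDFs
    for arg_offsets, chained_commands in udfs:
        out += struct.pack(">i", len(arg_offsets))
        for offset in arg_offsets:
            out += struct.pack(">i", offset)
        out += struct.pack(">i", len(chained_commands))
        for command in chained_commands:
            out += struct.pack(">i", len(command))
            out += command
    return bytes(out)

# e.g. two UDFs: f0 reads column 0 of the projected row, f1 reads columns 1 and 2
example = frame_udf_command([([0], [b"<pickled f0>"]), ([1, 2], [b"<pickled f1>"])])
```

The worker-side counterpart of this framing is `read_udfs` in `python/pyspark/worker.py` below.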

python/pyspark/sql/functions.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
 # ---------------------------- User Defined Function ----------------------------------
 
 def _wrap_function(sc, func, returnType):
-    ser = AutoBatchedSerializer(PickleSerializer())
-    command = (func, returnType, ser)
+    command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
```
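Since the serializer is no longer baked into the pickled command, the worker now chooses its own batching (see `read_udfs` below). A rough, self-contained illustration of the new command shape, using plain `pickle` and a placeholder string in place of a real PySpark `DataType`:

```python
import pickle

# Stand-in for what _wrap_function now pickles: just the Python function and
# its return type, with no serializer in the command tuple.
def double(x):
    return x * 2

command = (double, "IntegerType")        # previously: (func, returnType, ser)
payload = pickle.dumps(command)
func, return_type = pickle.loads(payload)
print(func(21), return_type)             # 42 IntegerType
```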

python/pyspark/sql/tests.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -305,7 +305,7 @@ def test_udf2(self):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
-    def test_chained_python_udf(self):
+    def test_chained_udf(self):
         self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.sqlCtx.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ def test_chained_python_udf(self):
         [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_multiple_udfs(self):
+        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.assertEqual(tuple(row), (2, 4))
+        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        self.assertEqual(tuple(row), (4, 12))
+        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.assertEqual(tuple(row), (6, 5))
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
```

python/pyspark/worker.py

Lines changed: 52 additions & 16 deletions
```diff
@@ -29,7 +29,7 @@
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):
 
 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
+
+    func = lambda _, it: map(mapper, it)
+    ser = BatchedSerializer(PickleSerializer(), 100)
+    # profiling is not supported for UDF
+    return func, None, ser, ser
 
 
 def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        row_based = read_int(infile)
-        num_commands = read_int(infile)
-        if row_based:
-            profiler = None  # profiling is not supported for UDF
-            row_func = None
-            for i in range(num_commands):
-                f, returnType, deserializer = read_command(pickleSer, infile)
-                if row_func is None:
-                    row_func = f
-                else:
-                    row_func = chain(row_func, f)
-            serializer = deserializer
-            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
         else:
-            assert num_commands == 1
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
 
     init_time = time.time()
```
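To make the `eval`-built mapper concrete, here is a standalone sketch (plain Python, no Spark) of what `read_udfs` constructs when two UDFs arrive with argument offsets `[0]` and `[1, 2]`; the lambdas merely stand in for deserialized UDFs.

```python
# Two deserialized UDFs, keyed by the generated names f0, f1.
udfs = {"f0": lambda x: x * 2,
        "f1": lambda x, y: x + y}

# Offsets [0] and [1, 2] over the flattened, projected input row.
call_udf = ["f0(a[0])", "f1(a[1], a[2])"]
mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
# mapper_str == "lambda a: (f0(a[0]), f1(a[1], a[2]))"
mapper = eval(mapper_str, udfs)

print(mapper([3, 1, 2]))   # (6, 3): both UDFs evaluated over one input row
```

The generated lambda evaluates every UDF for a row in one pass, which is what lets a single Python worker per partition serve all UDFs in the batch.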

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udf, child, _) =>
-        python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+      case e @ python.EvaluatePython(udfs, child, _) =>
+        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala

Lines changed: 59 additions & 19 deletions
```diff
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 
 /**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
 * we drain the queue to find the original input row. Note that if the Python process is way too
 * slow, this could lead to the queue growing unbounded and eventually run out of memory.
 */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {
 
   def children: Seq[SparkPlan] = child :: Nil
 
-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
 
@@ -69,39 +70,78 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
-      val (pyFuncs, children) = collectFunctions(udf)
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(children, child.output)()
-      val fields = children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
 
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
+          }
         }.toArray
         pickle.dumps(toBePickled)
      }
 
      val context = TaskContext.get()
 
      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
        .compute(inputIterator, context.partitionId(), context)
 
      val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
+      val mutableRow = new GenericMutableRow(1)
      val joined = new JoinedRow
+      val resultType = if (udfs.length == 1) {
+        udfs.head.dataType
+      } else {
+        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+      }
      val resultProj = UnsafeProjection.create(output, output)
 
      outputIterator.flatMap { pickedResult =>
        val unpickledBatch = unpickle.loads(pickedResult)
        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
      }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        val row = if (udfs.length == 1) {
+          // fast path for single UDF
+          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+          mutableRow
+        } else {
+          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        }
        resultProj(joined(queue.poll(), row))
      }
    }
```
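The argument-flattening step above can be hard to read in diff form. The following Python sketch (illustrative only; strings stand in for Catalyst expressions and plain equality for `semanticEquals`) shows how each UDF's inputs become offsets into one shared, de-duplicated argument list, so a column used by several UDFs is projected and pickled only once per row.

```python
def flatten_arguments(inputs_per_udf):
    """Return (shared argument list, per-UDF offsets into that list)."""
    all_inputs = []      # stands in for allInputs: ArrayBuffer[Expression]
    arg_offsets = []
    for inputs in inputs_per_udf:
        offsets = []
        for expr in inputs:
            if expr in all_inputs:                # semanticEquals in the Scala code
                offsets.append(all_inputs.index(expr))
            else:
                all_inputs.append(expr)
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# Two UDFs, double(id) and add(id, id + 1): "id" is shared, not duplicated.
print(flatten_arguments([["id"], ["id", "id + 1"]]))
# (['id', 'id + 1'], [[0], [0, 1]])
```

These offsets are the `argOffsets` passed to `PythonRunner` and consumed by `read_single_udf` on the worker side.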
