Skip to content

Commit 8e6e5bc

Browse files
author
Davies Liu
committed
fast path for single UDF
1 parent f6b7373 commit 8e6e5bc

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

python/pyspark/sql/tests.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,7 @@ def test_chained_udf(self):
315315
self.assertEqual(row[0], 6)
316316

317317
def test_multiple_udfs(self):
318+
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
318319
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
319320
self.assertEqual(tuple(row), (2, 4))
320321
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()

python/pyspark/worker.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,8 +91,9 @@ def read_udfs(pickleSer, infile):
9191
if num_udfs == 1:
9292
udf = udfs[0][2]
9393

94+
# fast path for single UDF
9495
def mapper(args):
95-
return (udf(*args),)
96+
return udf(*args)
9697
else:
9798
def mapper(args):
9899
return tuple(udf(*args[start:end]) for start, end, udf in udfs)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
7171

7272
val (pyFuncs, children) = udfs.map(collectFunctions).unzip
7373
val numArgs = children.map(_.length)
74-
val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
7574

7675
val pickle = new Pickler
7776
// flatten all the arguments
@@ -97,15 +96,26 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
9796
.compute(inputIterator, context.partitionId(), context)
9897

9998
val unpickle = new Unpickler
100-
val row = new GenericMutableRow(1)
99+
val mutableRow = new GenericMutableRow(1)
101100
val joined = new JoinedRow
101+
val resultType = if (udfs.length == 1) {
102+
udfs.head.dataType
103+
} else {
104+
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
105+
}
102106
val resultProj = UnsafeProjection.create(output, output)
103107

104108
outputIterator.flatMap { pickedResult =>
105109
val unpickledBatch = unpickle.loads(pickedResult)
106110
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
107111
}.map { result =>
108-
val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
112+
val row = if (udfs.length == 1) {
113+
// fast path for single UDF
114+
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
115+
mutableRow
116+
} else {
117+
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
118+
}
109119
resultProj(joined(queue.poll(), row))
110120
}
111121
}

0 commit comments

Comments
 (0)