Commit e966c38

LucaCanali authored and cloud-fan committed
[SPARK-34265][PYTHON][SQL] Instrument Python UDFs using SQL metrics
### What changes were proposed in this pull request?

This proposes to add SQLMetrics instrumentation for Python UDF execution, including Pandas UDFs and related operations such as MapInPandas and MapInArrow. The proposed metrics are:

- data sent to Python workers
- data returned from Python workers
- number of output rows

### Why are the changes needed?

This aims at improving monitoring and performance troubleshooting of Python UDFs. In particular, it is intended as an aid to answer performance-related questions such as: why is the UDF slow? How much work has been done so far?

### Does this PR introduce _any_ user-facing change?

SQL metrics are made available in the Web UI. See the following example:

![image1](https://issues.apache.org/jira/secure/attachment/13038693/PandasUDF_ArrowEvalPython_Metrics.png)

### How was this patch tested?

Manually tested; in addition, a Python unit test and a Scala unit test have been added. Example code used for testing:

```
from pyspark.sql.functions import col, pandas_udf
import time

@pandas_udf("long")
def test_pandas(col1):
    time.sleep(0.02)
    return col1 * col1

spark.udf.register("test_pandas", test_pandas)
spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
spark.sql("select max(test_pandas(col1)) from t1").collect()
```

This is used to test with more data pushed to the Python workers:

```
from pyspark.sql.functions import col, pandas_udf
import time

@pandas_udf("long")
def test_pandas(col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15,col16,col17):
    time.sleep(0.02)
    return col1

spark.udf.register("test_pandas", test_pandas)
spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
spark.sql("select max(test_pandas(col1,col1+1,col1+2,col1+3,col1+4,col1+5,col1+6,col1+7,col1+8,col1+9,col1+10,col1+11,col1+12,col1+13,col1+14,col1+15,col1+16)) from t1").collect()
```

This (from the Spark doc) has been used to test with MapInPandas, where the number of output rows is different from the number of input rows:

```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()
```

This is for testing BatchEvalPython and the metrics related to data transfer (bytes sent and received):

```
from pyspark.sql.functions import udf

@udf
def test_udf(col1, col2):
    return col1 * col1

spark.sql("select id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' col2 from range(10)").select(test_udf("id", "col2")).collect()
```

Closes #33559 from LucaCanali/pythonUDFKeySQLMetrics.

Authored-by: Luca Canali <luca.canali@cern.ch>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent b7a88cd commit e966c38

20 files changed, +193 −23 lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
@@ -484,6 +484,7 @@ def __hash__(self):
         "pyspark.sql.tests.pandas.test_pandas_udf_typehints",
         "pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
+        "pyspark.sql.tests.test_pandas_sqlmetrics",
         "pyspark.sql.tests.test_readwriter",
         "pyspark.sql.tests.test_serde",
         "pyspark.sql.tests.test_session",

docs/web-ui.md

Lines changed: 2 additions & 0 deletions
@@ -406,6 +406,8 @@ Here is the list of SQL metrics:
 <tr><td> <code>time to build hash map</code> </td><td> the time spent on building hash map </td><td> ShuffledHashJoin </td></tr>
 <tr><td> <code>task commit time</code> </td><td> the time spent on committing the output of a task after the writes succeed </td><td> any write operation on a file-based table </td></tr>
 <tr><td> <code>job commit time</code> </td><td> the time spent on committing the output of a job after the writes succeed </td><td> any write operation on a file-based table </td></tr>
+<tr><td> <code>data sent to Python workers</code> </td><td> the number of bytes of serialized data sent to the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas </td></tr>
+<tr><td> <code>data returned from Python workers</code> </td><td> the number of bytes of serialized data received back from the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas </td></tr>
 </table>
 
 ## Structured Streaming Tab
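
The two new rows above document how the byte counts surface in the SQL tab. For reference, here is a minimal Scala sketch (not part of this commit) of how the same metric names could be checked programmatically; it assumes an existing SparkSession named `spark` and code compiled inside the `org.apache.spark.sql` package (for example a test suite), since `sharedState` and `SQLAppStatusStore` are internal APIs. It mirrors what the Python test added by this commit does through py4j:

```scala
// Run a query that evaluates a Python/Pandas UDF first, then inspect the most recent
// SQL execution recorded by the (internal) SQL status store.
val statusStore = spark.sharedState.statusStore
val lastExecId = statusStore.executionsList().last.executionId
val metricNames = statusStore.execution(lastExecId).get.metrics.map(_.name)

assert(metricNames.contains("data sent to Python workers"))
assert(metricNames.contains("data returned from Python workers"))
assert(metricNames.contains("number of output rows"))
```
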
python/pyspark/sql/tests/test_pandas_sqlmetrics.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import cast
+
+from pyspark.sql.functions import pandas_udf
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class PandasSQLMetrics(ReusedSQLTestCase):
+    def test_pandas_sql_metrics_basic(self):
+        # SPARK-34265: Instrument Python UDFs using SQL metrics
+
+        python_sql_metrics = [
+            "data sent to Python workers",
+            "data returned from Python workers",
+            "number of output rows",
+        ]
+
+        @pandas_udf("long")
+        def test_pandas(col1):
+            return col1 * col1
+
+        self.spark.range(10).select(test_pandas("id")).collect()
+
+        statusStore = self.spark._jsparkSession.sharedState().statusStore()
+        lastExecId = statusStore.executionsList().last().executionId()
+        executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString()
+
+        for metric in python_sql_metrics:
+            self.assertIn(metric, executionMetrics)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_pandas_sqlmetrics import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala

Lines changed: 3 additions & 2 deletions
@@ -46,7 +46,7 @@ case class AggregateInPandasExec(
     udfExpressions: Seq[PythonUDF],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with PythonSQLMetrics {
 
   override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
@@ -163,7 +163,8 @@ case class AggregateInPandasExec(
       argOffsets,
       aggInputSchema,
       sessionLocalTimeZone,
-      pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
+      pythonRunnerConf,
+      pythonMetrics).compute(projectedRowIter, context.partitionId(), context)
 
     val joinedAttributes =
       groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
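The `PythonSQLMetrics` mix-in that the exec nodes in this commit extend is among the 20 changed files but is not shown in this excerpt. Below is a hedged sketch of what such a trait plausibly provides, based only on the usages visible in the surrounding hunks; the trait name here is deliberately different, and the `pythonDataReceived` key is an assumption (only `pythonDataSent` and `pythonNumRowsReceived` appear in the diffs shown):

```scala
// Sketch only: one SQLMetric per instrumented quantity, registered as the operator's
// SQL metrics so they show up in the SQL tab, and handed to the Python runners as a map.
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

trait PythonSQLMetricsSketch extends SparkPlan {
  val pythonMetrics: Map[String, SQLMetric] = Map(
    "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext, "data sent to Python workers"),
    "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext, "data returned from Python workers"),
    "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  override lazy val metrics: Map[String, SQLMetric] = pythonMetrics
}
```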

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala

Lines changed: 6 additions & 1 deletion
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
 import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
@@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner(
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
     outputSchema: StructType,
-    stateValueSchema: StructType)
+    stateValueSchema: StructType,
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
   with PythonArrowInput[InType]
   with PythonArrowOutput[OutType] {
@@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner(
     val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
 
     while (inputIterator.hasNext) {
+      val startData = dataOut.size()
       val (keyRow, groupState, dataIter) = inputIterator.next()
       assert(dataIter.hasNext, "should have at least one data row!")
       w.startNewGroup(keyRow, groupState)
@@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner(
       }
 
       w.finalizeGroup()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
     }
 
     w.finalizeData()
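The metering added here (and again in CoGroupedArrowPythonRunner further down) is a byte-count delta around each group written to the worker's data stream. A stripped-down illustration of the same pattern, with a hypothetical helper name and the group serialization abstracted into a callback:

```scala
// Illustrative only: measure the bytes serialized for one group by diffing the running
// byte count of the output stream before and after the write, then add the delta to the
// "pythonDataSent" SQLMetric, as the runner above does per group.
import java.io.DataOutputStream

import org.apache.spark.sql.execution.metric.SQLMetric

def writeOneGroupMetered(
    dataOut: DataOutputStream,
    writeGroup: DataOutputStream => Unit,
    pythonMetrics: Map[String, SQLMetric]): Unit = {
  val startData = dataOut.size()   // bytes written to the stream so far
  writeGroup(dataOut)              // stand-in for the Arrow serialization of one group
  val deltaData = dataOut.size() - startData
  pythonMetrics("pythonDataSent") += deltaData
}
```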

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 3 additions & 2 deletions
@@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan,
     evalType: Int)
-  extends EvalPythonExec {
+  extends EvalPythonExec with PythonSQLMetrics {
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       argOffsets,
       schema,
       sessionLocalTimeZone,
-      pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+      pythonRunnerConf,
+      pythonMetrics).compute(batchIter, context.partitionId(), context)
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,7 +33,8 @@ class ArrowPythonRunner(
     argOffsets: Array[Array[Int]],
     protected override val schema: StructType,
     protected override val timeZoneId: String,
-    protected override val workerConf: Map[String, String])
+    protected override val workerConf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 4 additions & 2 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * A physical plan that evaluates a [[PythonUDF]]
  */
 case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec {
+  extends EvalPythonExec with PythonSQLMetrics {
 
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
@@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
     }.grouped(100).map(x => pickle.dumps(x.toArray))
 
     // Output iterator for results from Python.
-    val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+    val outputIterator =
+      new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
@@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       val unpickledBatch = unpickle.loads(pickedResult)
       unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
     }.map { result =>
+      pythonMetrics("pythonNumRowsReceived") += 1
       if (udfs.length == 1) {
         // fast path for single UDF
         mutableRow(0) = fromJava(result)
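The per-row counter added above is all that the "number of output rows" metric needs in this non-Arrow path: every unpickled result bumps the metric as it flows through the iterator. A hedged sketch of the same idea as a standalone helper (the helper name is hypothetical):

```scala
// Illustrative only: bump an SQLMetric for each result row pulled from the Python
// worker, without materializing the iterator.
import org.apache.spark.sql.execution.metric.SQLMetric

def countOutputRows[T](results: Iterator[T], numRowsReceived: SQLMetric): Iterator[T] =
  results.map { row =>
    numRowsReceived += 1
    row
  }
```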

sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala

Lines changed: 7 additions & 1 deletion
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
@@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner(
     leftSchema: StructType,
     rightSchema: StructType,
     timeZoneId: String,
-    conf: Map[String, String])
+    conf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
   with BasicPythonArrowOutput {
@@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner(
       // For each we first send the number of dataframes in each group then send
       // first df, then send second df. End of data is marked by sending 0.
       while (inputIterator.hasNext) {
+        val startData = dataOut.size()
         dataOut.writeInt(2)
         val (nextLeft, nextRight) = inputIterator.next()
         writeGroup(nextLeft, leftSchema, dataOut, "left")
         writeGroup(nextRight, rightSchema, dataOut, "right")
+
+        val deltaData = dataOut.size() - startData
+        pythonMetrics("pythonDataSent") += deltaData
       }
       dataOut.writeInt(0)
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends SparkPlan with BinaryExecNode {
+  extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-
     val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
     val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
 
@@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec(
       StructType.fromAttributes(leftDedup),
       StructType.fromAttributes(rightDedup),
       sessionLocalTimeZone,
-      pythonRunnerConf)
+      pythonRunnerConf,
+      pythonMetrics)
 
     executePython(data, output, runner)
   }
