
Commit bba8bb8

Yicong-Huang authored and zhengruifeng committed
[SPARK-54598][PYTHON] Extract logic to read UDFs
### What changes were proposed in this pull request?

This PR refactors the UDF reading logic in `read_udfs()` to eliminate code duplication. Currently, the logic for reading UDFs (functions and their argument offsets) is duplicated across multiple `eval_type` branches, with different patterns for the single-UDF vs. multiple-UDF cases.

### Why are the changes needed?

This duplication makes the code harder to maintain and increases the risk of inconsistencies. By centralizing the UDF reading logic at the beginning of `read_udfs()`, we can:

- Reduce code duplication
- Ensure consistent UDF reading behavior across all eval types
- Make it easier to add new eval types in the future

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that maintains backward compatibility. The API behavior remains the same from the user's perspective.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#53330 from Yicong-Huang/SPARK-54598/refactor/udf-fetching-logic.

Authored-by: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
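For a concrete picture of the refactor described above, here is a minimal before/after sketch; the names (`read_single_udf`, `pickleSer`, `runner_conf`, `read_int`) come from `python/pyspark/worker.py`, but the snippet is a simplified illustration, not the verbatim implementation:

```python
# Before: each eval_type branch read its own UDF(s) mid-branch, e.g.
#     arg_offsets, udf = read_single_udf(
#         pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
#     )
# repeated (with small variations) across a dozen branches.

# After: all UDFs are read once, up front, and branches index into the list.
num_udfs = read_int(infile)
udfs = [
    read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler)
    for i in range(num_udfs)
]
arg_offsets, udf = udfs[0]  # single-UDF branches simply unpack the first entry
```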
1 parent 191ce4c commit bba8bb8

File tree

1 file changed: +16 -40 lines changed


python/pyspark/worker.py

Lines changed: 16 additions & 40 deletions
```diff
@@ -2878,7 +2878,12 @@ def read_udfs(pickleSer, infile, eval_type):
     else:
         profiler = None
 
+    # Read all UDFs
     num_udfs = read_int(infile)
+    udfs = [
+        read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler)
+        for i in range(num_udfs)
+    ]
 
     is_scalar_iter = eval_type in (
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
```
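Each element of `udfs` is the `(arg_offsets, udf)` pair that `read_single_udf` returns, as the unpacking in the hunks below shows. A toy, self-contained illustration of that shape (the offsets and lambdas are made up):

```python
# Made-up stand-in for one entry of udfs: column offsets plus the callable.
udfs = [([0, 2], lambda x, y: x * y)]

arg_offsets, udf = udfs[0]  # how the single-UDF branches consume it

row = [3, "ignored", 4]                     # a toy input record
print(udf(*[row[o] for o in arg_offsets]))  # -> 12
```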
```diff
@@ -2896,9 +2901,7 @@ def read_udfs(pickleSer, infile, eval_type):
         if is_map_arrow_iter:
             assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
 
-        arg_offsets, udf = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, udf = udfs[0]
 
         def func(_, iterator):
             num_input_rows = 0
@@ -2994,9 +2997,7 @@ def extract_key_value_indexes(grouped_arg_offsets):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def mapper(series_iter):
```
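`extract_key_value_indexes` turns the flat `arg_offsets` into `[keyOffsets, dataOffsets]` pairs so the grouped-map branches can tell grouping attributes from data attributes (note the `parsed_offsets[0][0]` reads in the hunks below). A hypothetical illustration of that parsed shape, with made-up values:

```python
# Hypothetical output of extract_key_value_indexes(arg_offsets):
parsed_offsets = [[[0], [1, 2]]]  # one group: key column 0, data columns 1 and 2

key_offsets = parsed_offsets[0][0]    # -> [0]: grouping attributes
value_offsets = parsed_offsets[0][1]  # -> [1, 2]: data attributes
```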
```diff
@@ -3022,9 +3023,7 @@ def mapper(series_iter):
 
         # See TransformWithStateInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
         stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
@@ -3053,9 +3052,7 @@ def values_gen():
 
         # See TransformWithStateInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         # parsed offsets:
         # [
         #     [groupingKeyOffsets, dedupDataOffsets],
@@ -3091,9 +3088,7 @@ def values_gen():
 
         # See TransformWithStateInPySparkExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
         stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
@@ -3118,9 +3113,7 @@ def mapper(a):
 
         # See TransformWithStateInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         # parsed offsets:
         # [
         #     [groupingKeyOffsets, dedupDataOffsets],
@@ -3156,9 +3149,7 @@ def mapper(a):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def batch_from_offset(batch, offsets):
@@ -3187,9 +3178,7 @@ def mapper(a):
 
         # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def mapper(a):
@@ -3223,9 +3212,7 @@ def mapper(a):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -3242,9 +3229,7 @@ def mapper(a):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -3269,9 +3254,7 @@ def mapper(a):
         # support combining multiple UDFs.
         assert num_udfs == 1
 
-        arg_offsets, f = read_single_udf(
-            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-        )
+        arg_offsets, f = udfs[0]
 
         # Convert to iterator of batches: Iterator[pa.Array] for single column,
         # or Iterator[Tuple[pa.Array, ...]] for multiple columns
```
```diff
@@ -3283,13 +3266,6 @@ def mapper(a):
             return f(batch_iter)
 
     else:
-        udfs = []
-        for i in range(num_udfs):
-            udfs.append(
-                read_single_udf(
-                    pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler
-                )
-            )
 
         def mapper(a):
             result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs)
```
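The surviving `mapper` in the `else:` branch shows why the upfront list suffices for the multi-UDF case: each `(arg_offsets, f)` pair selects its own columns from the input record, and the results are tupled. A runnable toy version, with plain lists standing in for the real Arrow/pandas batches:

```python
# Two toy "UDFs" with their column offsets, shaped like read_single_udf output.
udfs = [
    ([0, 1], lambda x, y: x + y),  # consumes columns 0 and 1
    ([2], lambda s: s.upper()),    # consumes column 2
]

def mapper(a):
    # Mirrors the diff: apply every UDF to its own slice of the record.
    return tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs)

print(mapper([1, 2, "spark"]))  # -> (3, 'SPARK')
```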
