36 changes: 35 additions & 1 deletion python/pyspark/sql/pandas/serializers.py
@@ -20,7 +20,7 @@
"""

from itertools import groupby
from typing import TYPE_CHECKING, Iterator, Optional
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set

import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -539,6 +539,7 @@ def __init__(
self._ndarray_as_list = ndarray_as_list
self._arrow_cast = arrow_cast
self._input_types = input_types
self._used_column_offsets: Optional[List[int]] = None

def arrow_to_pandas(self, arrow_column, idx):
import pyarrow.types as types
@@ -581,6 +582,39 @@ def arrow_to_pandas(self, arrow_column, idx):
)
return s

def set_used_columns(self, used_offsets: Set[int]) -> Optional[Dict[int, int]]:
"""
Restrict Arrow-to-pandas conversion to the given column offsets. Returns a dict
mapping each original column offset to its position in the converted column
list, or None when all columns are converted.
"""
if not used_offsets:
self._used_column_offsets = None
return None

self._used_column_offsets = sorted(used_offsets)
return {offset: i for i, offset in enumerate(self._used_column_offsets)}

def load_stream(self, stream) -> Iterator[List["pd.Series"]]:
"""
Deserialize ArrowRecordBatches into lists of pandas.Series (one list per batch),
converting only the columns selected via set_used_columns when a selection is set.
"""
import pandas as pd
import pyspark

batches = ArrowStreamSerializer.load_stream(self, stream)

for batch in batches:
if batch.num_columns == 0:
yield [pd.Series([pyspark._NoValue] * batch.num_rows)]
continue

col_idx = (
self._used_column_offsets
if self._used_column_offsets is not None
else range(batch.num_columns)
)
pandas_batches = [self.arrow_to_pandas(batch.column(i), i) for i in col_idx]
yield pandas_batches

def _create_struct_array(
self,
df: "pd.DataFrame",
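As a standalone illustration of the contract above (a minimal sketch that mirrors the logic of set_used_columns rather than invoking the serializer itself; the helper name compute_offset_remap and the example offsets are assumptions for this illustration):

from typing import Dict, List, Optional, Set

def compute_offset_remap(used_offsets: Set[int]) -> Optional[Dict[int, int]]:
    # Mirrors set_used_columns: sort the used original offsets and map each
    # original column offset to its position in the compacted column list.
    if not used_offsets:
        return None
    ordered: List[int] = sorted(used_offsets)
    return {offset: i for i, offset in enumerate(ordered)}

# A projection with 10 columns where the UDF only reads offsets 0 and 5:
remap = compute_offset_remap({5, 0})
assert remap == {0: 0, 5: 1}
# load_stream would then convert and yield only two pandas.Series per batch,
# in ascending offset order, instead of all 10 columns.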
57 changes: 57 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -2043,6 +2043,63 @@ def plus_two(iterator):
result = df.select(plus_two("id").alias("result")).collect()
self.assertEqual(expected, result)

def test_selective_column_conversion(self):
"""Test that selective conversion only converts used columns."""
# Create DataFrame with 10 columns, but UDF only uses 2
df = self.spark.range(100).select(*[col("id").alias(f"col_{i}") for i in range(10)])

@pandas_udf("long")
def add_cols(a: pd.Series, b: pd.Series) -> pd.Series:
return a + b

result = df.select(add_cols(col("col_0"), col("col_5")).alias("result")).collect()
expected = [Row(result=i + i) for i in range(100)]
self.assertEqual(result, expected)

def test_selective_conversion_all_columns(self):
"""Test when all columns are used."""
df = self.spark.range(100).select(col("id").alias("a"), (col("id") * 2).alias("b"))

@pandas_udf("long")
def add_all(a: pd.Series, b: pd.Series) -> pd.Series:
return a + b

result = df.select(add_all(col("a"), col("b")).alias("result")).collect()
expected = [Row(result=i + i * 2) for i in range(100)]
self.assertEqual(result, expected)

def test_selective_conversion_multiple_udfs(self):
"""Test with multiple UDFs using different columns."""
df = self.spark.range(100).select(
col("id").alias("a"),
(col("id") * 2).alias("b"),
(col("id") * 3).alias("c"),
)

@pandas_udf("long")
def use_a(a: pd.Series) -> pd.Series:
return a * 2

@pandas_udf("long")
def use_c(c: pd.Series) -> pd.Series:
return c + 1

result = df.select(use_a(col("a")).alias("r1"), use_c(col("c")).alias("r2")).collect()
expected = [Row(r1=i * 2, r2=i * 3 + 1) for i in range(100)]
self.assertEqual(result, expected)

def test_selective_conversion_single_column(self):
"""Test with a single column UDF from many columns."""
df = self.spark.range(100).select(*[col("id").alias(f"col_{i}") for i in range(5)])

@pandas_udf("long")
def double_it(x: pd.Series) -> pd.Series:
return x * 2

result = df.select(double_it(col("col_3")).alias("result")).collect()
expected = [Row(result=i * 2) for i in range(100)]
self.assertEqual(result, expected)


class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
@classmethod
41 changes: 33 additions & 8 deletions python/pyspark/worker.py
@@ -2916,6 +2916,16 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
for i in range(num_udfs)
]

# Only convert Arrow columns that are actually used by UDFs
offset_remap = None
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
if hasattr(ser, "set_used_columns"):
all_offsets = set()
for arg_offsets, _ in udfs:
if isinstance(arg_offsets, (list, tuple)):
all_offsets.update(arg_offsets)
offset_remap = ser.set_used_columns(all_offsets)

is_scalar_iter = eval_type in (
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
@@ -3392,15 +3402,30 @@ def mapper(batch_iter):
return result

else:
# Define mapper with or without index remapping for selective column conversion
if offset_remap is not None:
# Use remapped indices when selective conversion is enabled
def mapper(a):
result = tuple(
f(*[a[offset_remap[o]] for o in arg_offsets]) for arg_offsets, f in udfs
)
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
if len(result) == 1:
return result[0]
else:
return result

def mapper(a):
result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs)
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
if len(result) == 1:
return result[0]
else:
return result
else:
# Original mapper (no remapping)
def mapper(a):
result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs)
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
if len(result) == 1:
return result[0]
else:
return result

def func(_, it):
return map(mapper, it)
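To see how the remap is consumed, the argument lookup performed by the remapped mapper can be reduced to the following sketch (stand-in strings take the place of converted pandas.Series, and the column offsets are assumptions chosen for illustration):

# offset_remap as set_used_columns({0, 5}) would return it.
offset_remap = {0: 0, 5: 1}

# What load_stream now yields per batch: only the used columns, in offset order.
converted = ["series_for_col_0", "series_for_col_5"]

# A UDF invoked as f(col_5, col_0) still carries the original arg offsets...
arg_offsets = (5, 0)

# ...so the mapper translates each offset into its position in the compacted list.
args = [converted[offset_remap[o]] for o in arg_offsets]
assert args == ["series_for_col_5", "series_for_col_0"]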