Skip to content

Commit 528f3a7

Browse files
zhengruifeng authored and
HyukjinKwon committed
[SPARK-53301][PYTHON] Differentiate type hints of Pandas UDF and Arrow UDF
### What changes were proposed in this pull request? Differentiate type hints of Pandas UDF and Arrow UDF ### Why are the changes needed? The `arrow_udf` can work with a pandas udf, and the `pandas_udf` can work with an arrow udf, because the eval type inference didn't differentiate the pandas udf and arrow udf. But this is supposed to fail. before: ``` In [1]: import pyarrow as pa ...: ...: from pyspark.sql import functions as sf ...: from pyspark.sql.functions import arrow_udf, pandas_udf ...: ...: df = spark.range(10).withColumn("v", sf.col("id") + 1) ...: ...: ...: pandas_udf("long") ...: def multiply_arrow_func(a: pa.Array, b: pa.Array) -> pa.Array: ...: assert isinstance(a, pa.Array) ...: assert isinstance(b, pa.Array) ...: return pa.compute.multiply(a, b) ...: ...: In [2]: df.select("id", "v", multiply_arrow_func("id", "v").alias("m")).show() ...: +---+---+---+ | id| v| m| +---+---+---+ | 0| 1| 0| | 1| 2| 2| | 2| 3| 6| | 3| 4| 12| | 4| 5| 20| | 5| 6| 30| | 6| 7| 42| | 7| 8| 56| | 8| 9| 72| | 9| 10| 90| +---+---+---+ ``` after ``` In [2]: ...: pandas_udf("long") ...: ...: def multiply_arrow_func(a: pa.Array, b: pa.Array) -> pa.Array: ...: ...: assert isinstance(a, pa.Array) ...: ...: assert isinstance(b, pa.Array) ...: ...: return pa.compute.multiply(a, b) ...: --------------------------------------------------------------------------- PySparkNotImplementedError Traceback (most recent call last) ... PySparkNotImplementedError: [UNSUPPORTED_SIGNATURE] Unsupported signature: (a: pyarrow.lib.Array, b: pyarrow.lib.Array) -> pyarrow.lib.Array. ``` ### Does this PR introduce _any_ user-facing change? no, arrow_udf is not yet released ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes apache#52054 from zhengruifeng/arrow_pandas_type_hint. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 64bd0e2 commit 528f3a7

File tree

4 files changed

+161
-56
lines changed

4 files changed

+161
-56
lines changed

python/pyspark/sql/pandas/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
807807
type_hints = get_type_hints(f)
808808
except NameError:
809809
type_hints = {}
810-
evalType = infer_eval_type(signature(f), type_hints)
810+
evalType = infer_eval_type(signature(f), type_hints, kind)
811811
assert evalType is not None
812812

813813
if evalType is None:

python/pyspark/sql/pandas/typehints.py

Lines changed: 132 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,19 @@
3131
)
3232

3333

34-
def infer_eval_type(
35-
sig: Signature, type_hints: Dict[str, Any]
36-
) -> Union[
37-
"PandasScalarUDFType",
38-
"PandasScalarIterUDFType",
39-
"PandasGroupedAggUDFType",
40-
"ArrowScalarUDFType",
41-
"ArrowScalarIterUDFType",
42-
"ArrowGroupedAggUDFType",
43-
]:
34+
def infer_pandas_eval_type(
35+
sig: Signature,
36+
type_hints: Dict[str, Any],
37+
) -> Optional[Union["PandasScalarUDFType", "PandasScalarIterUDFType", "PandasGroupedAggUDFType"]]:
4438
"""
4539
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
4640
:class:`inspect.Signature` instance and type hints.
4741
"""
48-
from pyspark.sql.pandas.functions import PandasUDFType, ArrowUDFType
42+
from pyspark.sql.pandas.functions import PandasUDFType
4943

5044
require_minimum_pandas_version()
51-
require_minimum_pyarrow_version()
5245

5346
import pandas as pd
54-
import pyarrow as pa
5547

5648
annotations = {}
5749
for param in sig.parameters.values():
@@ -85,9 +77,8 @@ def infer_eval_type(
8577
)
8678
for a in parameters_sig
8779
) and (return_annotation == pd.Series or return_annotation == pd.DataFrame)
88-
89-
# pa.Array, ... -> pa.Array
90-
is_arrow_array = all(a == pa.Array for a in parameters_sig) and (return_annotation == pa.Array)
80+
if is_series_or_frame:
81+
return PandasUDFType.SCALAR
9182

9283
# Iterator[Tuple[Series, Frame or Union[DataFrame, Series], ...] -> Iterator[Series or Frame]
9384
is_iterator_tuple_series_or_frame = (
@@ -110,21 +101,8 @@ def infer_eval_type(
110101
return_annotation, parameter_check_func=lambda a: a == pd.DataFrame or a == pd.Series
111102
)
112103
)
113-
114-
# Iterator[Tuple[pa.Array, ...] -> Iterator[pa.Array]
115-
is_iterator_tuple_array = (
116-
len(parameters_sig) == 1
117-
and check_iterator_annotation( # Iterator
118-
parameters_sig[0],
119-
parameter_check_func=lambda a: check_tuple_annotation( # Tuple
120-
a,
121-
parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
122-
),
123-
)
124-
and check_iterator_annotation(
125-
return_annotation, parameter_check_func=lambda a: a == pa.Array
126-
)
127-
)
104+
if is_iterator_tuple_series_or_frame:
105+
return PandasUDFType.SCALAR_ITER
128106

129107
# Iterator[Series, Frame or Union[DataFrame, Series]] -> Iterator[Series or Frame]
130108
is_iterator_series_or_frame = (
@@ -143,18 +121,8 @@ def infer_eval_type(
143121
return_annotation, parameter_check_func=lambda a: a == pd.DataFrame or a == pd.Series
144122
)
145123
)
146-
147-
# Iterator[pa.Array] -> Iterator[pa.Array]
148-
is_iterator_array = (
149-
len(parameters_sig) == 1
150-
and check_iterator_annotation(
151-
parameters_sig[0],
152-
parameter_check_func=lambda a: a == pa.Array,
153-
)
154-
and check_iterator_annotation(
155-
return_annotation, parameter_check_func=lambda a: a == pa.Array
156-
)
157-
)
124+
if is_iterator_series_or_frame:
125+
return PandasUDFType.SCALAR_ITER
158126

159127
# Series, Frame or Union[DataFrame, Series], ... -> Any
160128
is_series_or_frame_agg = all(
@@ -173,32 +141,141 @@ def infer_eval_type(
173141
and not check_iterator_annotation(return_annotation)
174142
and not check_tuple_annotation(return_annotation)
175143
)
144+
if is_series_or_frame_agg:
145+
return PandasUDFType.GROUPED_AGG
146+
147+
return None
148+
149+
150+
def infer_arrow_eval_type(
151+
sig: Signature, type_hints: Dict[str, Any]
152+
) -> Optional[Union["ArrowScalarUDFType", "ArrowScalarIterUDFType", "ArrowGroupedAggUDFType"]]:
153+
"""
154+
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
155+
:class:`inspect.Signature` instance and type hints.
156+
"""
157+
from pyspark.sql.pandas.functions import ArrowUDFType
158+
159+
require_minimum_pyarrow_version()
160+
161+
import pyarrow as pa
162+
163+
annotations = {}
164+
for param in sig.parameters.values():
165+
if param.annotation is not param.empty:
166+
annotations[param.name] = type_hints.get(param.name, param.annotation)
167+
168+
# Check if all arguments have type hints
169+
parameters_sig = [
170+
annotations[parameter] for parameter in sig.parameters if parameter in annotations
171+
]
172+
if len(parameters_sig) != len(sig.parameters):
173+
raise PySparkValueError(
174+
errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
175+
messageParameters={"target": "all parameters", "sig": str(sig)},
176+
)
177+
178+
# Check if the return has a type hint
179+
return_annotation = type_hints.get("return", sig.return_annotation)
180+
if sig.empty is return_annotation:
181+
raise PySparkValueError(
182+
errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
183+
messageParameters={"target": "the return type", "sig": str(sig)},
184+
)
185+
186+
# pa.Array, ... -> pa.Array
187+
is_arrow_array = all(a == pa.Array for a in parameters_sig) and (return_annotation == pa.Array)
188+
if is_arrow_array:
189+
return ArrowUDFType.SCALAR
190+
191+
# Iterator[Tuple[pa.Array, ...] -> Iterator[pa.Array]
192+
is_iterator_tuple_array = (
193+
len(parameters_sig) == 1
194+
and check_iterator_annotation( # Iterator
195+
parameters_sig[0],
196+
parameter_check_func=lambda a: check_tuple_annotation( # Tuple
197+
a,
198+
parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
199+
),
200+
)
201+
and check_iterator_annotation(
202+
return_annotation, parameter_check_func=lambda a: a == pa.Array
203+
)
204+
)
205+
if is_iterator_tuple_array:
206+
return ArrowUDFType.SCALAR_ITER
207+
208+
# Iterator[pa.Array] -> Iterator[pa.Array]
209+
is_iterator_array = (
210+
len(parameters_sig) == 1
211+
and check_iterator_annotation(
212+
parameters_sig[0],
213+
parameter_check_func=lambda a: a == pa.Array,
214+
)
215+
and check_iterator_annotation(
216+
return_annotation, parameter_check_func=lambda a: a == pa.Array
217+
)
218+
)
219+
if is_iterator_array:
220+
return ArrowUDFType.SCALAR_ITER
176221

177222
# pa.Array, ... -> Any
178223
is_array_agg = all(a == pa.Array for a in parameters_sig) and (
179224
return_annotation != pa.Array
180225
and not check_iterator_annotation(return_annotation)
181226
and not check_tuple_annotation(return_annotation)
182227
)
183-
184-
if is_series_or_frame:
185-
return PandasUDFType.SCALAR
186-
elif is_arrow_array:
187-
return ArrowUDFType.SCALAR
188-
elif is_iterator_tuple_series_or_frame or is_iterator_series_or_frame:
189-
return PandasUDFType.SCALAR_ITER
190-
elif is_iterator_tuple_array or is_iterator_array:
191-
return ArrowUDFType.SCALAR_ITER
192-
elif is_series_or_frame_agg:
193-
return PandasUDFType.GROUPED_AGG
194-
elif is_array_agg:
228+
if is_array_agg:
195229
return ArrowUDFType.GROUPED_AGG
230+
231+
return None
232+
233+
234+
def infer_eval_type(
235+
sig: Signature,
236+
type_hints: Dict[str, Any],
237+
kind: str = "all",
238+
) -> Union[
239+
"PandasScalarUDFType",
240+
"PandasScalarIterUDFType",
241+
"PandasGroupedAggUDFType",
242+
"ArrowScalarUDFType",
243+
"ArrowScalarIterUDFType",
244+
"ArrowGroupedAggUDFType",
245+
]:
246+
"""
247+
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
248+
:class:`inspect.Signature` instance and type hints.
249+
"""
250+
assert kind in ["pandas", "arrow", "all"], "kind should be either 'pandas', 'arrow' or 'all'"
251+
252+
eval_type: Optional[
253+
Union[
254+
"PandasScalarUDFType",
255+
"PandasScalarIterUDFType",
256+
"PandasGroupedAggUDFType",
257+
"ArrowScalarUDFType",
258+
"ArrowScalarIterUDFType",
259+
"ArrowGroupedAggUDFType",
260+
]
261+
] = None
262+
if kind == "pandas":
263+
eval_type = infer_pandas_eval_type(sig, type_hints)
264+
elif kind == "arrow":
265+
eval_type = infer_arrow_eval_type(sig, type_hints)
196266
else:
267+
eval_type = infer_pandas_eval_type(sig, type_hints) or infer_arrow_eval_type(
268+
sig, type_hints
269+
)
270+
271+
if eval_type is None:
197272
raise PySparkNotImplementedError(
198273
errorClass="UNSUPPORTED_SIGNATURE",
199274
messageParameters={"signature": str(sig)},
200275
)
201276

277+
return eval_type
278+
202279

203280
def check_tuple_annotation(
204281
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None

python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from pyspark.sql import functions as sf
2222
from pyspark.testing.utils import (
23+
have_pandas,
24+
pandas_requirement_message,
2325
have_pyarrow,
2426
pyarrow_requirement_message,
2527
have_numpy,
@@ -323,6 +325,19 @@ def func(col: Union["pa.Array", "pa.Array"], *, col2: "pa.Array") -> "pa.Array":
323325
infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
324326
)
325327

328+
@unittest.skipIf(not have_pandas, pandas_requirement_message)
329+
def test_negative_with_pandas_udf(self):
330+
import pandas as pd
331+
332+
with self.assertRaisesRegex(
333+
Exception,
334+
"Unsupported signature:.*pandas.core.series.Series.",
335+
):
336+
337+
@arrow_udf("long")
338+
def multiply_pandas(a: pd.Series, b: pd.Series) -> pd.Series:
339+
return a * b
340+
326341

327342
if __name__ == "__main__":
328343
from pyspark.sql.tests.arrow.test_arrow_udf_typehints import * # noqa: #401

python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,19 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.
377377
infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR
378378
)
379379

380+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
381+
def test_negative_with_arrow_udf(self):
382+
import pyarrow as pa
383+
384+
with self.assertRaisesRegex(
385+
Exception,
386+
"Unsupported signature:.*pyarrow.lib.Array.",
387+
):
388+
389+
@pandas_udf("long")
390+
def multiply_arrow(a: pa.Array, b: pa.Array) -> pa.Array:
391+
return pa.compute.multiply(a, b)
392+
380393

381394
if __name__ == "__main__":
382395
from pyspark.sql.tests.pandas.test_pandas_udf_typehints import * # noqa: #401

0 commit comments

Comments
 (0)