diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6203d4d19d866..076226865f3a7 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -510,8 +510,8 @@ def _create_batch(self, series): # If it returns a pd.Series, it should throw an error. if not isinstance(s, pd.DataFrame): raise PySparkValueError( - "A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s)) + "Invalid return type. Please make sure that the UDF returns a " + "pandas.DataFrame when the specified return type is StructType." ) arrs.append(self._create_struct_array(s, t)) else: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 6720dfc37d0cc..228fc30b497cc 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -339,6 +339,19 @@ def noop(s: pd.Series) -> pd.Series: self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second") self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123)) + def test_pandas_udf_return_type_error(self): + import pandas as pd + + @pandas_udf("s string") + def upper(s: pd.Series) -> pd.Series: + return s.str.upper() + + df = self.spark.createDataFrame([("a",)], schema="s string") + + self.assertRaisesRegex( + PythonException, "Invalid return type", df.select(upper("s")).collect + ) + class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase): pass