[Python] Failed to convert 'float' to 'double' with using pandas_udf and pyspark #21514

@asfimport

Description

Hi everyone,

I would like to report a (potential) bug. I followed an official guide on [Usage Guide for Pandas with Apache Arrow](https://spark.apache.org/docs/2.4.0/sql-pyspark-pandas-with-arrow.html).

However, libarrow throws a type-conversion error (float to double). Here is the example and its output.

`pyarrow==0.12.1`

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col

spark = SparkSession.builder.appName('ReproduceBug').getOrCreate()

df = spark.createDataFrame(
    [(1, "a"), (1, "a"), (1, "b")],
    ("id", "value"))
df.show()
# Spark DataFrame
# +---+-----+
# | id|value|
# +---+-----+
# |  1|    a|
# |  1|    a|
# |  1|    b|
# +---+-----+

# Potential bug
@pandas_udf('double', PandasUDFType.SCALAR)
def compute_frequencies(value_col):
    total      = value_col.count()
    per_groups = value_col.groupby(value_col).transform('count')
    score      = per_groups / total
    return score

df.groupBy("id")\
  .agg(compute_frequencies(col('value')))\
  .show()

spark.stop()
```
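For reference, the per-column logic inside `compute_frequencies` behaves as expected on its own; a minimal sketch in plain pandas, using the same values as the example DataFrame:

```python
import pandas as pd

# Same logic as compute_frequencies, run outside Spark on plain pandas
value_col = pd.Series(["a", "a", "b"])

total      = value_col.count()                                     # 3
per_groups = value_col.groupby(value_col).transform("count")       # counts per value: 2, 2, 1
score      = per_groups / total                                    # 0.666..., 0.666..., 0.333...

print(score.dtype)  # float64
```

So the scores themselves come out as an ordinary float64 Series; the failure only appears once the result passes through the Arrow serializer.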

 

```
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-3-d4f781f64db1> in <module>
     32 
     33 df.groupBy("id")\
---> 34   .agg(compute_frequencies(col('value')))\
     35   .show()
     36 

/usr/local/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    376         """
    377         if isinstance(truncate, bool) and truncate:
--> 378             print(self._jdf.showString(n, 20, vertical))
    379         else:
    380             print(self._jdf.showString(n, int(truncate), vertical))

/usr/local/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/local/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o186.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 44 in stage 23.0 failed 1 times, most recent failure: Lost task 44.0 in stage 23.0 (TID 601, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 284, in dump_stream
    batch = _create_batch(series, self._timezone)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 253, in _create_batch
    arrs = [create_array(s, t) for s, t in series]
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 253, in <listcomp>
    arrs = [create_array(s, t) for s, t in series]
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 251, in create_array
    return pa.Array.from_pandas(s, mask=mask, type=t)
  File "pyarrow/array.pxi", line 536, in pyarrow.lib.Array.from_pandas
  File "pyarrow/array.pxi", line 176, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 85, in pyarrow.lib._ndarray_to_array
  File "pyarrow/error.pxi", line 81, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert 0    0.666667
1    0.666667
2    0.333333
Name: _0, dtype: float64 with type Series: tried to convert to double
```

Please let me know if you need any further information.

Environment: Linux 68b0517ddf1c 3.10.0-862.11.6.el7.x86_64 #1 SMP GNU/Linux

Reporter: Dat Nguyen

Note: This issue was originally created as ARROW-5016. Please see the migration documentation for further details.
