Infer Pandas string columns in Arrow conversion on Python 2
When serializing a Pandas dataframe with Arrow under Python 2, Arrow cannot
tell whether a string column should be serialized as string type or as
binary, because Python 2 stores strings as byte strings. As a result, Arrow
serializes string columns in Pandas dataframes as binary columns. This commit
inspects binary-inferred columns with Pandas' own type inference and rewrites
the Arrow schema to use string type where the data is actually text.

This workaround can be removed once we discontinue support for Python 2.

See the original PR [1] and the follow-up for 'mixed' type columns [2].

[1] #679
[2] #702
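
For illustration, a minimal Python 2 session showing the misinference this
commit works around (hypothetical values, not part of the diff):

    # Python 2 only: str values are byte strings, so Arrow infers binary.
    import pandas as pd
    import pyarrow as pa

    pdf = pd.DataFrame({"a": ["x", "y"]})
    schema = pa.Schema.from_pandas(pdf, preserve_index=False)
    print(schema[0].type)  # binary on Python 2, string on Python 3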
rshkv committed Mar 9, 2021
1 parent 1209fb7 commit bfa8d60
Showing 4 changed files with 77 additions and 1 deletion.
2 changes: 2 additions & 0 deletions FORK.md
@@ -21,6 +21,8 @@
 * [palantir/spark#381](https://github.com/palantir/spark/pull/381) Gradle plugin to easily create custom docker images for use with k8s
 * [palantir/spark#521](https://github.com/palantir/spark/pull/521) K8s local file mounting
 * [palantir/spark#600](https://github.com/palantir/spark/pull/600) K8s local deploy mode
+* Support Arrow-serialization of Python 2 strings [(#679)](https://github.com/palantir/spark/pull/679)
+* [palantir/spark#678](https://github.com/palantir/spark/issues/678) TODO: Revert #679 once we move off of Python 2



6 changes: 5 additions & 1 deletion python/pyspark/sql/pandas/conversion.py
@@ -382,7 +382,8 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
 from pyspark.sql.types import TimestampType
-from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
+from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type, \
+    _infer_binary_columns_as_arrow_string
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
     require_minimum_pyarrow_version

@@ -395,6 +396,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 # Create the Spark schema from list of names passed in with Arrow types
 if isinstance(schema, (list, tuple)):
     arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
+    # TODO(rshkv): Remove when we stop supporting Python 2 (#678)
+    if sys.version < '3':
+        arrow_schema = _infer_binary_columns_as_arrow_string(arrow_schema, pdf)
     struct = StructType()
     for name, field in zip(schema, arrow_schema):
         struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
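
The user-visible effect: on Python 2, creating a Spark DataFrame from a
Pandas dataframe with string columns now yields the same schema with and
without the Arrow path. A rough sketch, assuming an active SparkSession
named `spark` with spark.sql.execution.arrow.pyspark.enabled set to true:

    # Sketch only; `spark` is an assumed, already-configured SparkSession.
    import pandas as pd

    pdf = pd.DataFrame({"a": ["x"]})
    df = spark.createDataFrame(pdf)
    print(df.schema)  # column "a" comes back as StringType, not BinaryType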
33 changes: 33 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -20,6 +20,7 @@
 pandas instances during the type conversion.
 """
 
+import sys
 from pyspark.sql.types import *


@@ -107,6 +108,10 @@ def from_arrow_type(at):
 elif types.is_list(at):
     if types.is_timestamp(at.value_type):
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+    # TODO(rshkv): Support binary array values once we move off Python 2 (#678)
+    if sys.version < '3' and types.is_binary(at.value_type):
+        raise TypeError("Unsupported type in conversion from Arrow: " + str(at) +
+                        "\nPlease use Python3 for support of BinaryType in arrays.")
     spark_type = ArrayType(from_arrow_type(at.value_type))
 elif types.is_struct(at):
     if any(types.is_struct(field.type) for field in at):
@@ -266,3 +271,31 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
     :return pandas.Series where if it is a timestamp, has been converted to tz-naive
     """
     return _check_series_convert_timestamps_localize(s, timezone, None)
+
+
+# TODO(rshkv): Remove after we drop Python 2 support (#678)
+def _infer_binary_columns_as_arrow_string(schema, pandas_df):
+    """
+    Infer whether a Pandas column that Arrow considers binary should be
+    treated as string instead. This workaround is only necessary on Python 2.
+    """
+    import pandas as pd
+    import pyarrow as pa
+    import six
+
+    for field_index, field in enumerate(schema):
+        if field.type != pa.binary():
+            continue
+
+        inferred_dtype = pd.api.types.infer_dtype(pandas_df.iloc[:, field_index], skipna=True)
+        if inferred_dtype == 'string':
+            is_string_column = True
+        elif inferred_dtype == 'mixed' and len(pandas_df.index) > 0:
+            first_value = pandas_df.iloc[0, field_index]
+            is_string_column = isinstance(first_value, six.string_types)
+        else:
+            # Leave genuinely binary columns (and empty 'mixed' columns) untouched.
+            is_string_column = False
+
+        if is_string_column:
+            field_as_string = pa.field(field.name, pa.string())
+            schema = schema.set(field_index, field_as_string)
+
+    return schema
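
A rough sketch of the helper's behavior on Python 2 (hypothetical values;
the bytearray column only illustrates that real binary data is left alone):

    # Python 2 sketch: the str column is rewritten to string type,
    # while genuinely binary data keeps its binary type.
    import pandas as pd
    import pyarrow as pa

    pdf = pd.DataFrame({"a": ["x"], "b": [bytearray(b"raw")]})
    schema = pa.Schema.from_pandas(pdf, preserve_index=False)
    fixed = _infer_binary_columns_as_arrow_string(schema, pdf)
    print(fixed[0].type)  # string
    print(fixed[1].type)  # binary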
37 changes: 37 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -348,6 +348,43 @@ def test_createDataFrame_with_int_col_names(self):
         self.assertEqual(pdf_col_names, df.columns)
         self.assertEqual(pdf_col_names, df_arrow.columns)
 
+    def test_createDataFrame_with_str_col(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": ["x"]})
+
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        self.assertEqual(df.schema, df_arrow.schema)
+
+    def test_createDataFrame_with_str_array_col(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [["x"]]})
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
+            df, df_arrow = self._createDataFrame_toggle(pdf)
+        self.assertEqual(df.schema, df_arrow.schema)
+
+    def test_createDataFrame_with_str_struct_col(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [{"x": "x"}]})
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
+            df, df_arrow = self._createDataFrame_toggle(pdf)
+        self.assertEqual(df.schema, df_arrow.schema)
+
+    def test_createDataFrame_with_str_binary_mixed(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [u"unicode-value", "binary-under-python-2"]})
+
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        self.assertEqual(df.schema, df_arrow.schema)
+
+    def test_createDataFrame_with_real_binary(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [bytearray(b"a"), bytearray(b"c")]})
+
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        self.assertEqual(df.schema, df_arrow.schema)
+
     def test_createDataFrame_fallback_enabled(self):
         with QuietTest(self.sc):
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
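
These tests rely on the suite's existing `_createDataFrame_toggle` helper,
which builds the same DataFrame once with Arrow disabled and once enabled so
the two schemas can be compared. Roughly (simplified reconstruction, not
shown in this diff):

    # Simplified sketch of the existing test helper, for reference.
    def _createDataFrame_toggle(self, pdf, schema=None):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
        df_arrow = self.spark.createDataFrame(pdf, schema=schema)
        return df_no_arrow, df_arrow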
