
Commit 94f5f4f

itholic authored and dongjoon-hyun committed
[SPARK-43704][CONNECT][PS] Support MultiIndex for to_series()
### What changes were proposed in this pull request?

This PR proposes to support `MultiIndex` for `to_series()`.

### Why are the changes needed?

So far, `to_series()` for a `MultiIndex` does not work properly, because the underlying data structure differs between pandas and Spark. See the examples in the next section for more detail.

### Does this PR introduce _any_ user-facing change?

**Before**

```python
>>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
>>> psmidx.to_series()
A  B    {'__index_level_0__': 'A', '__index_level_1__'...
   C    {'__index_level_0__': 'A', '__index_level_1__'...
B  C    {'__index_level_0__': 'B', '__index_level_1__'...
dtype: object
```

**After**

```python
>>> psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
>>> psmidx.to_series()
A  B    [A, B]
   C    [A, C]
B  C    [B, C]
dtype: object
```

### How was this patch tested?

Enabled the existing UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43228 from itholic/SPARK-43704.

Authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 8d3199d commit 94f5f4f
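As a quick illustration of what the fix means at the Spark level (not part of the commit; variable names below are made up), the series backing a `MultiIndex` is now typed as an `ArrayType` column rather than a struct, which is what the `pyspark.pandas.indexing` change further down checks for:

```python
# Hypothetical check, assuming a running Spark session with the pandas API on Spark.
import pyspark.pandas as ps
from pyspark.sql.types import ArrayType, StringType

psmidx = ps.MultiIndex.from_tuples([("A", "B"), ("A", "C"), ("B", "C")])
psser = psmidx.to_series()

# After this change the underlying Spark column is an array of the index levels,
# so its Spark type is an ArrayType of strings instead of a StructType.
assert isinstance(psser.spark.data_type, ArrayType)
assert isinstance(psser.spark.data_type.elementType, StringType)
```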

3 files changed (+21, -9 lines)

python/pyspark/pandas/indexes/base.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -916,7 +916,19 @@ def to_series(self, name: Optional[Name] = None) -> Series:
             data_fields=[field],
             column_label_names=None,
         )
-        return first_series(DataFrame(internal))
+
+        result = first_series(DataFrame(internal))
+        if self._internal.index_level == 1:
+            return result
+        else:
+            # MultiIndex
+            def struct_to_array(scol: Column) -> Column:
+                field_names = result._internal.spark_type_for(
+                    scol
+                ).fieldNames()  # type: ignore[attr-defined]
+                return F.array([scol[field] for field in field_names])
+
+            return result.spark.transform(struct_to_array)
 
     def to_frame(self, index: bool = True, name: Optional[Name] = None) -> DataFrame:
         """
```

python/pyspark/pandas/indexing.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -1077,12 +1077,16 @@ def _select_rows_by_slice(
 
             return reduce(lambda x, y: x & y, conds), None, None
         else:
-            from pyspark.sql.types import StructType
+            from pyspark.sql.types import ArrayType, StructType
 
             index = self._psdf_or_psser.index
-            index_data_type = [  # type: ignore[assignment]
-                f.dataType for f in cast(StructType, index.to_series().spark.data_type)
-            ]
+            data_type = index.to_series().spark.data_type
+            if isinstance(data_type, StructType):
+                index_data_type = [f.dataType for f in data_type]  # type: ignore[assignment]
+            elif isinstance(data_type, ArrayType):
+                index_data_type = [  # type: ignore[assignment]
+                    data_type.elementType for _ in range(index._internal.index_level)
+                ]
 
             start = rows_sel.start
             if start is not None:
```
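The new branch derives one Spark type per index level from either representation: field types from a `StructType`, or the shared `elementType` repeated per level for an `ArrayType`. A minimal sketch of that logic in isolation (the helper name `level_types` is made up; only the two branches from the diff are shown):

```python
from typing import List

from pyspark.sql.types import ArrayType, DataType, StringType, StructField, StructType


def level_types(data_type: DataType, index_level: int) -> List[DataType]:
    # StructType: one field per index level, keep each field's type.
    if isinstance(data_type, StructType):
        return [f.dataType for f in data_type]
    # ArrayType (MultiIndex after this change): every level shares the element type.
    elif isinstance(data_type, ArrayType):
        return [data_type.elementType for _ in range(index_level)]
    raise TypeError(f"unexpected index type: {data_type}")


print(level_types(ArrayType(StringType()), 2))  # -> two StringType entries
print(level_types(StructType([StructField("a", StringType()), StructField("b", StringType())]), 2))
# -> two StringType entries as well
```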

python/pyspark/pandas/tests/connect/indexes/test_parity_base.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -29,10 +29,6 @@ class IndexesParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43704): Enable IndexesParityTests.test_to_series.")
-    def test_to_series(self):
-        super().test_to_series()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.indexes.test_parity_base import *  # noqa: F401
```
