Skip to content

Commit 059d8b6

Browse files
authored
GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
This PR adds `__reduce__` method to the `FixedShapeTensorType`. * Closes: #35599 Authored-by: AlenkaF <frim.alenka@gmail.com> Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent 98559fe commit 059d8b6

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

python/pyarrow/tests/test_extension_type.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,3 +1306,19 @@ def test_extension_to_pandas_storage_type(registered_period_type):
13061306
# Check the usage of types_mapper
13071307
result = table.to_pandas(types_mapper=pd.ArrowDtype)
13081308
assert isinstance(result["ext"].dtype, pd.ArrowDtype)
1309+
1310+
1311+
def test_tensor_type_is_picklable():
1312+
# GH-35599
1313+
1314+
expected_type = pa.fixed_shape_tensor(pa.int32(), (2, 2))
1315+
result = pickle.loads(pickle.dumps(expected_type))
1316+
1317+
assert result == expected_type
1318+
1319+
arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
1320+
storage = pa.array(arr, pa.list_(pa.int32(), 4))
1321+
expected_arr = pa.ExtensionArray.from_storage(expected_type, storage)
1322+
result = pickle.loads(pickle.dumps(expected_arr))
1323+
1324+
assert result == expected_arr

python/pyarrow/types.pxi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,10 @@ cdef class FixedShapeTensorType(BaseExtensionType):
15861586
def __arrow_ext_class__(self):
15871587
return FixedShapeTensorArray
15881588

1589+
def __reduce__(self):
1590+
return fixed_shape_tensor, (self.value_type, self.shape,
1591+
self.dim_names, self.permutation)
1592+
15891593

15901594
cdef class PyExtensionType(ExtensionType):
15911595
"""

0 commit comments

Comments
 (0)