Commit f4d2e1e

Enabling serialization with pydata/sparse.
1 parent db1fb5a commit f4d2e1e

2 files changed: +50 −0 lines

python/pyarrow/serialization.py (30 additions, 0 deletions)

@@ -350,6 +350,35 @@ def _deserialize_scipy_sparse(data):
         pass
 
 
+# ----------------------------------------------------------------------
+# Set up serialization for pydata/sparse tensors.
+
+def _register_pydata_sparse_handlers(serialization_context):
+    try:
+        import sparse
+
+        def _serialize_pydata_sparse(obj):
+            if isinstance(obj, sparse.coo.core.COO):
+                return 'coo', pyarrow.SparseTensorCOO.from_numpy(
+                    obj.data, obj.coords.T, shape=obj.shape)
+
+        def _deserialize_pydata_sparse(data):
+            if data[0] == 'coo':
+                data_array, coords = data[1].to_numpy()
+                return sparse.COO(
+                    data=data_array[:, 0],
+                    coords=coords.T, shape=data[1].shape)
+
+        serialization_context.register_type(
+            sparse.coo.core.COO, 'sparse.coo.core.COO',
+            custom_serializer=_serialize_pydata_sparse,
+            custom_deserializer=_deserialize_pydata_sparse)
+
+    except ImportError:
+        # no pydata/sparse
+        pass
+
+
 def register_default_serialization_handlers(serialization_context):
 
     # ----------------------------------------------------------------------
@@ -403,6 +432,7 @@ def register_default_serialization_handlers(serialization_context):
     _register_collections_serialization_handlers(serialization_context)
     _register_custom_pandas_handlers(serialization_context)
     _register_scipy_handlers(serialization_context)
+    _register_pydata_sparse_handlers(serialization_context)
 
 
 def default_serialization_context():
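
With this change, default_serialization_context() picks up the new handler through register_default_serialization_handlers(), so a plain pa.serialize / deserialize round-trip works on pydata/sparse COO arrays. A minimal usage sketch, assuming a pyarrow build from this era with sparse tensor support and the sparse package installed (the array values are illustrative, not taken from the commit):

import numpy as np
import pyarrow as pa
import sparse

# coords is (ndim, nnz), as pydata/sparse expects.
coords = np.array([[0, 1, 1],
                   [0, 2, 3]])
data = np.array([1.0, 2.0, 3.0])
arr = sparse.COO(coords=coords, data=data, shape=(2, 4))

# The default context now recognizes sparse.COO, so this round-trips.
restored = pa.serialize(arr).deserialize()
assert np.array_equal(arr.todense(), restored.todense())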

python/pyarrow/tests/test_serialization.py (20 additions, 0 deletions)

@@ -51,6 +51,11 @@
 except ImportError:
     sparse = None
 
+try:
+    import sparse
+except ImportError:
+    sparse = None
+
 
 def assert_equal(obj1, obj2):
     if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
@@ -621,6 +626,21 @@ def test_scipy_sparse_tensor_csr_serialization():
     assert np.array_equal(sparse_array.toarray(), result.toarray())
 
 
+@pytest.mark.skipif(not sparse, reason="requires pydata/sparse")
+def test_pydata_sparse_sparse_tensor_coo_serialization():
+    data = np.array([1, 2, 3, 4, 5, 6, 7])
+    row = np.array([0, 0, 2, 3, 1, 3, 0])
+    col = np.array([0, 2, 0, 4, 5, 5, 0])
+    coords = np.vstack([row, col])
+    shape = (4, 6)
+
+    sparse_array = sparse.COO(data=data, coords=coords, shape=shape)
+    serialized = pa.serialize(sparse_array)
+    result = serialized.deserialize()
+
+    assert np.array_equal(sparse_array.todense(), result.todense())
+
+
 @pytest.mark.filterwarnings(
     "ignore:the matrix subclass:PendingDeprecationWarning")
 def test_numpy_matrix_serialization(tmpdir):
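
The handler can also be attached to a standalone context rather than the default one. A hedged sketch, assuming the private helper _register_pydata_sparse_handlers added above is importable from pyarrow.serialization and that pa.serialize / pa.deserialize accept a context= argument, as they did at the time:

import numpy as np
import pyarrow as pa
import sparse
from pyarrow.serialization import _register_pydata_sparse_handlers

# A fresh context with only the pydata/sparse handler registered.
context = pa.SerializationContext()
_register_pydata_sparse_handlers(context)

arr = sparse.COO(coords=np.array([[0, 3], [1, 5]]),
                 data=np.array([4.0, 5.0]), shape=(4, 6))

# Serialize to a buffer and read it back with the same explicit context.
buf = pa.serialize(arr, context=context).to_buffer()
restored = pa.deserialize(buf, context=context)
assert np.array_equal(arr.todense(), restored.todense())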
