Skip to content

Commit 97ba62a

Browse files
author
dhirschf
committed
Add some serialization tests for arrow
1 parent 6b6af6c commit 97ba62a

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pandas as pd
2+
import pytest
3+
4+
pa = pytest.importorskip('pyarrow')
5+
6+
from dask import delayed
7+
from distributed.protocol import arrow
8+
from distributed.utils_test import gen_cluster
9+
from distributed.protocol import deserialize, serialize
10+
from distributed.protocol.serialize import class_serializers, typename
11+
12+
13+
df = pd.DataFrame({'A': list('abc'), 'B': [1,2,3]})
14+
tbl = pa.Table.from_pandas(df, preserve_index=False)
15+
batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
16+
17+
18+
@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
19+
def test_roundtrip(obj):
20+
# Test that the serialize/deserialize functions actually
21+
# work independent of distributed
22+
header, frames = serialize(obj)
23+
new_obj = deserialize(header, frames)
24+
assert obj.equals(new_obj)
25+
26+
27+
@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
28+
def test_typename(obj):
29+
# The typename used to register the custom serialization is hardcoded
30+
# ensure that the typename hasn't changed
31+
assert typename(type(obj)) in class_serializers
32+
33+
34+
def echo(arg):
35+
return arg
36+
37+
38+
@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
39+
def test_submit(obj):
40+
@gen_cluster(client=True)
41+
def run_test(client, scheduler, worker1, worker2):
42+
fut = client.submit(echo, obj)
43+
assert obj.equals(fut.result())
44+
run_test()
45+
46+
47+
@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
48+
def test_scatter(obj):
49+
@gen_cluster(client=True)
50+
def run_test(client, scheduler, worker1, worker2):
51+
obj_fut = client.scatter(obj)
52+
fut = client.submit(echo, obj_fut)
53+
assert obj.equals(fut.result())
54+
run_test()
55+
56+
57+
@pytest.mark.parametrize('obj', [batch, tbl], ids=["RecordBatch", "Table"])
58+
def test_delayed(obj):
59+
@gen_cluster(client=True)
60+
def run_test(client, scheduler, worker1, worker2):
61+
delayed_obj = delayed(obj)
62+
fut = client.submit(echo, delayed_obj)
63+
assert obj.equals(fut.result())
64+
run_test()

0 commit comments

Comments
 (0)