
Commit d06c664

kszucs authored and xhochy committed
ARROW-3903: [Python] Random array generator for Arrow conversion and Parquet testing
Generate random schemas, arrays, chunked_arrays, columns, record_batches and tables. Slow, but it makes it quite easy to isolate corner cases (JIRA issues have already been created for the ones found). In follow-up PRs we should use these strategies to increase coverage; that will let us reduce the number of open issues. We could even use it to generate benchmark datasets periodically (only if we persist them somewhere).

Example usage:

Run 10 samples (dev profile):
`pytest -sv pyarrow/tests/test_strategies.py::test_tables --enable-hypothesis --hypothesis-show-statistics --hypothesis-profile=dev`

Print the generated examples (debug):
`pytest -sv pyarrow/tests/test_strategies.py::test_schemas --enable-hypothesis --hypothesis-show-statistics --hypothesis-profile=debug`

Author: Krisztián Szűcs <szucs.krisztian@gmail.com>

Closes #3301 from kszucs/ARROW-3903 and squashes the following commits:

ff6654c <Krisztián Szűcs> finalize
8b5e7ea <Krisztián Szűcs> rat
61fe01d <Krisztián Szűcs> strategies for chunked_arrays, columns, record batches; test the strategies themselves
bdb63df <Krisztián Szűcs> hypothesis array strategy
1 parent 74f3f5f commit d06c664
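
Besides the pytest commands in the commit message, the strategies can also be sampled directly in a Python session. A minimal sketch, not part of the commit, relying only on Hypothesis' public `.example()` debugging helper:

import pyarrow.tests.strategies as past

# .example() is Hypothesis' interactive debugging helper (not for use inside
# tests); it is handy for eyeballing what the strategies generate.
print(past.all_schemas.example())
print(past.all_arrays.example())
print(past.all_tables.example())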

File tree

4 files changed (+222, -18 lines)

python/pyarrow/table.pxi
python/pyarrow/tests/strategies.py
python/pyarrow/tests/test_array.py
python/pyarrow/tests/test_strategies.py

python/pyarrow/table.pxi

Lines changed: 3 additions & 3 deletions
@@ -1155,9 +1155,9 @@ cdef class Table(_PandasConvertible):
 
         Parameters
         ----------
-        arrays: list of pyarrow.Array or pyarrow.Column
+        arrays : list of pyarrow.Array or pyarrow.Column
             Equal-length arrays that should form the table.
-        names: list of str, optional
+        names : list of str, optional
             Names for the table columns. If Columns passed, will be
             inferred. If Arrays passed, this argument is required
         schema : Schema, default None
@@ -1224,7 +1224,7 @@ cdef class Table(_PandasConvertible):
 
         Parameters
         ----------
-        batches: sequence or iterator of RecordBatch
+        batches : sequence or iterator of RecordBatch
             Sequence of RecordBatch to be converted, all schemas must be equal
         schema : Schema, default None
             If not passed, will be inferred from the first RecordBatch
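
The docstring diff above only adjusts numpydoc spacing (a space before the colon). For readers unfamiliar with the two constructors being documented, here is an illustrative sketch of their use; the toy data is not part of the diff:

import pyarrow as pa

# Toy columns, chosen purely for illustration.
arrays = [pa.array([1, 2, 3]), pa.array(['a', 'b', 'c'])]

# from_arrays: names are required when plain Arrays are passed.
table = pa.Table.from_arrays(arrays, names=['ints', 'strs'])

# from_batches: the schema is inferred from the first RecordBatch.
batch = pa.RecordBatch.from_arrays(arrays, ['ints', 'strs'])
same_table = pa.Table.from_batches([batch])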

python/pyarrow/tests/strategies.py

Lines changed: 143 additions & 15 deletions
@@ -15,8 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pyarrow as pa
+import pytz
+import hypothesis as h
 import hypothesis.strategies as st
+import hypothesis.extra.numpy as npst
+import hypothesis.extra.pytz as tzst
+import numpy as np
+
+import pyarrow as pa
 
 
 # TODO(kszucs): alphanum_text, surrogate_text
@@ -69,12 +75,11 @@
     pa.time64('us'),
     pa.time64('ns')
 ])
-timestamp_types = st.sampled_from([
-    pa.timestamp('s'),
-    pa.timestamp('ms'),
-    pa.timestamp('us'),
-    pa.timestamp('ns')
-])
+timestamp_types = st.builds(
+    pa.timestamp,
+    unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
+    tz=tzst.timezones()
+)
 temporal_types = st.one_of(date_types, time_types, timestamp_types)
 
 primitive_types = st.one_of(
@@ -106,20 +111,21 @@ def complex_types(inner_strategy=primitive_types):
     return list_types(inner_strategy) | struct_types(inner_strategy)
 
 
-def nested_list_types(item_strategy=primitive_types):
-    return st.recursive(item_strategy, list_types)
+def nested_list_types(item_strategy=primitive_types, max_leaves=3):
+    return st.recursive(item_strategy, list_types, max_leaves=max_leaves)
 
 
-def nested_struct_types(item_strategy=primitive_types):
-    return st.recursive(item_strategy, struct_types)
+def nested_struct_types(item_strategy=primitive_types, max_leaves=3):
+    return st.recursive(item_strategy, struct_types, max_leaves=max_leaves)
 
 
-def nested_complex_types(inner_strategy=primitive_types):
-    return st.recursive(inner_strategy, complex_types)
+def nested_complex_types(inner_strategy=primitive_types, max_leaves=3):
+    return st.recursive(inner_strategy, complex_types, max_leaves=max_leaves)
 
 
-def schemas(type_strategy=primitive_types):
-    return st.builds(pa.schema, st.lists(fields(type_strategy)))
+def schemas(type_strategy=primitive_types, max_fields=None):
+    children = st.lists(fields(type_strategy), max_size=max_fields)
+    return st.builds(pa.schema, children)
 
 
 complex_schemas = schemas(complex_types())
@@ -128,3 +134,125 @@ def schemas(type_strategy=primitive_types):
 all_types = st.one_of(primitive_types, complex_types(), nested_complex_types())
 all_fields = fields(all_types)
 all_schemas = schemas(all_types)
+
+
+_default_array_sizes = st.integers(min_value=0, max_value=20)
+
+
+@st.composite
+def arrays(draw, type, size=None):
+    if isinstance(type, st.SearchStrategy):
+        type = draw(type)
+    elif not isinstance(type, pa.DataType):
+        raise TypeError('Type must be a pyarrow DataType')
+
+    if isinstance(size, st.SearchStrategy):
+        size = draw(size)
+    elif size is None:
+        size = draw(_default_array_sizes)
+    elif not isinstance(size, int):
+        raise TypeError('Size must be an integer')
+
+    shape = (size,)
+
+    if pa.types.is_list(type):
+        offsets = draw(npst.arrays(np.uint8(), shape=shape)).cumsum() // 20
+        offsets = np.insert(offsets, 0, 0, axis=0)  # prepend with zero
+        values = draw(arrays(type.value_type, size=int(offsets.sum())))
+        return pa.ListArray.from_arrays(offsets, values)
+
+    if pa.types.is_struct(type):
+        h.assume(len(type) > 0)
+        names, child_arrays = [], []
+        for field in type:
+            names.append(field.name)
+            child_arrays.append(draw(arrays(field.type, size=size)))
+        # fields' metadata are lost here, because from_arrays doesn't accept
+        # a fields argumentum, only names
+        return pa.StructArray.from_arrays(child_arrays, names=names)
+
+    if (pa.types.is_boolean(type) or pa.types.is_integer(type) or
+            pa.types.is_floating(type)):
+        values = npst.arrays(type.to_pandas_dtype(), shape=(size,))
+        return pa.array(draw(values), type=type)
+
+    if pa.types.is_null(type):
+        value = st.none()
+    elif pa.types.is_time(type):
+        value = st.times()
+    elif pa.types.is_date(type):
+        value = st.dates()
+    elif pa.types.is_timestamp(type):
+        tz = pytz.timezone(type.tz) if type.tz is not None else None
+        value = st.datetimes(timezones=st.just(tz))
+    elif pa.types.is_binary(type):
+        value = st.binary()
+    elif pa.types.is_string(type):
+        value = st.text()
+    elif pa.types.is_decimal(type):
+        # TODO(kszucs): properly limit the precision
+        # value = st.decimals(places=type.scale, allow_infinity=False)
+        h.reject()
+    else:
+        raise NotImplementedError(type)
+
+    values = st.lists(value, min_size=size, max_size=size)
+    return pa.array(draw(values), type=type)
+
+
+@st.composite
+def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
+    if isinstance(type, st.SearchStrategy):
+        type = draw(type)
+
+    # TODO(kszucs): remove it, field metadata is not kept
+    h.assume(not pa.types.is_struct(type))
+
+    chunk = arrays(type, size=chunk_size)
+    chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)
+
+    return pa.chunked_array(draw(chunks), type=type)
+
+
+def columns(type, min_chunks=0, max_chunks=None, chunk_size=None):
+    chunked_array = chunked_arrays(type, chunk_size=chunk_size,
+                                   min_chunks=min_chunks,
+                                   max_chunks=max_chunks)
+    return st.builds(pa.column, st.text(), chunked_array)
+
+
+@st.composite
+def record_batches(draw, type, rows=None, max_fields=None):
+    if isinstance(rows, st.SearchStrategy):
+        rows = draw(rows)
+    elif rows is None:
+        rows = draw(_default_array_sizes)
+    elif not isinstance(rows, int):
+        raise TypeError('Rows must be an integer')
+
+    schema = draw(schemas(type, max_fields=max_fields))
+    children = [draw(arrays(field.type, size=rows)) for field in schema]
+    # TODO(kszucs): the names and schame arguments are not consistent with
+    # Table.from_array's arguments
+    return pa.RecordBatch.from_arrays(children, names=schema)
+
+
+@st.composite
+def tables(draw, type, rows=None, max_fields=None):
+    if isinstance(rows, st.SearchStrategy):
+        rows = draw(rows)
+    elif rows is None:
+        rows = draw(_default_array_sizes)
+    elif not isinstance(rows, int):
+        raise TypeError('Rows must be an integer')
+
+    schema = draw(schemas(type, max_fields=max_fields))
+    children = [draw(arrays(field.type, size=rows)) for field in schema]
+    return pa.Table.from_arrays(children, schema=schema)
+
+
+all_arrays = arrays(all_types)
+all_chunked_arrays = chunked_arrays(all_types)
+all_columns = columns(all_types)
+all_record_batches = record_batches(all_types)
+all_tables = tables(all_types)
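
The size, rows, max_fields and max_leaves parameters added above exist so callers can keep generated examples small. A minimal sketch, not part of this commit, of how a follow-up test might bound example sizes; the specific bounds and the test name are illustrative assumptions:

import hypothesis as h
import hypothesis.strategies as st

import pyarrow as pa
import pyarrow.tests.strategies as past


# Bound rows to 0-5 and schemas to at most 3 fields so generation and
# shrinking stay fast; corner cases surfaced here are exactly what the
# strategies are meant to find.
@h.given(past.tables(past.primitive_types,
                     rows=st.integers(min_value=0, max_value=5),
                     max_fields=3))
def test_small_tables(table):
    assert isinstance(table, pa.Table)
    assert len(table.schema) == table.num_columns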

python/pyarrow/tests/test_array.py

Lines changed: 15 additions & 0 deletions
@@ -18,6 +18,8 @@
 
 import collections
 import datetime
+import hypothesis as h
+import hypothesis.strategies as st
 import pickle
 import pytest
 import struct
@@ -32,6 +34,7 @@
     pickle5 = None
 
 import pyarrow as pa
+import pyarrow.tests.strategies as past
 from pyarrow.pandas_compat import get_logical_type
 
 
@@ -802,6 +805,18 @@ def test_array_pickle(data, typ):
     assert array.equals(result)
 
 
+@h.given(
+    past.arrays(
+        past.all_types,
+        size=st.integers(min_value=0, max_value=10)
+    )
+)
+def test_pickling(arr):
+    data = pickle.dumps(arr)
+    restored = pickle.loads(data)
+    assert arr.equals(restored)
+
+
 @pickle_test_parametrize
 def test_array_pickle5(data, typ):
     # Test zero-copy pickling with protocol 5 (PEP 574)
python/pyarrow/tests/test_strategies.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import hypothesis as h
+
+import pyarrow as pa
+import pyarrow.tests.strategies as past
+
+
+@h.given(past.all_types)
+def test_types(ty):
+    assert isinstance(ty, pa.lib.DataType)
+
+
+@h.given(past.all_fields)
+def test_fields(field):
+    assert isinstance(field, pa.lib.Field)
+
+
+@h.given(past.all_schemas)
+def test_schemas(schema):
+    assert isinstance(schema, pa.lib.Schema)
+
+
+@h.given(past.all_arrays)
+def test_arrays(array):
+    assert isinstance(array, pa.lib.Array)
+
+
+@h.given(past.all_chunked_arrays)
+def test_chunked_arrays(chunked_array):
+    assert isinstance(chunked_array, pa.lib.ChunkedArray)
+
+
+@h.given(past.all_columns)
+def test_columns(column):
+    assert isinstance(column, pa.lib.Column)
+
+
+@h.given(past.all_record_batches)
+def test_record_batches(record_bath):
+    assert isinstance(record_bath, pa.lib.RecordBatch)
+
+
+@h.given(past.all_tables)
+def test_tables(table):
+    assert isinstance(table, pa.lib.Table)
