15
15
# specific language governing permissions and limitations
16
16
# under the License.
17
17
18
- import pyarrow as pa
18
+ import pytz
19
+ import hypothesis as h
19
20
import hypothesis .strategies as st
21
+ import hypothesis .extra .numpy as npst
22
+ import hypothesis .extra .pytz as tzst
23
+ import numpy as np
24
+
25
+ import pyarrow as pa
20
26
21
27
22
28
# TODO(kszucs): alphanum_text, surrogate_text
69
75
pa .time64 ('us' ),
70
76
pa .time64 ('ns' )
71
77
])
72
- timestamp_types = st .sampled_from ([
73
- pa .timestamp ('s' ),
74
- pa .timestamp ('ms' ),
75
- pa .timestamp ('us' ),
76
- pa .timestamp ('ns' )
77
- ])
78
+ timestamp_types = st .builds (
79
+ pa .timestamp ,
80
+ unit = st .sampled_from (['s' , 'ms' , 'us' , 'ns' ]),
81
+ tz = tzst .timezones ()
82
+ )
78
83
temporal_types = st .one_of (date_types , time_types , timestamp_types )
79
84
80
85
primitive_types = st .one_of (
@@ -106,20 +111,21 @@ def complex_types(inner_strategy=primitive_types):
106
111
return list_types (inner_strategy ) | struct_types (inner_strategy )
107
112
108
113
109
- def nested_list_types (item_strategy = primitive_types ):
110
- return st .recursive (item_strategy , list_types )
114
+ def nested_list_types (item_strategy = primitive_types , max_leaves = 3 ):
115
+ return st .recursive (item_strategy , list_types , max_leaves = max_leaves )
111
116
112
117
113
- def nested_struct_types (item_strategy = primitive_types ):
114
- return st .recursive (item_strategy , struct_types )
118
+ def nested_struct_types (item_strategy = primitive_types , max_leaves = 3 ):
119
+ return st .recursive (item_strategy , struct_types , max_leaves = max_leaves )
115
120
116
121
117
- def nested_complex_types (inner_strategy = primitive_types ):
118
- return st .recursive (inner_strategy , complex_types )
122
+ def nested_complex_types (inner_strategy = primitive_types , max_leaves = 3 ):
123
+ return st .recursive (inner_strategy , complex_types , max_leaves = max_leaves )
119
124
120
125
121
- def schemas (type_strategy = primitive_types ):
122
- return st .builds (pa .schema , st .lists (fields (type_strategy )))
126
+ def schemas (type_strategy = primitive_types , max_fields = None ):
127
+ children = st .lists (fields (type_strategy ), max_size = max_fields )
128
+ return st .builds (pa .schema , children )
123
129
124
130
125
131
complex_schemas = schemas (complex_types ())
@@ -128,3 +134,125 @@ def schemas(type_strategy=primitive_types):
128
134
all_types = st .one_of (primitive_types , complex_types (), nested_complex_types ())
129
135
all_fields = fields (all_types )
130
136
all_schemas = schemas (all_types )
137
+
138
+
139
+ _default_array_sizes = st .integers (min_value = 0 , max_value = 20 )
140
+
141
+
142
+ @st .composite
143
+ def arrays (draw , type , size = None ):
144
+ if isinstance (type , st .SearchStrategy ):
145
+ type = draw (type )
146
+ elif not isinstance (type , pa .DataType ):
147
+ raise TypeError ('Type must be a pyarrow DataType' )
148
+
149
+ if isinstance (size , st .SearchStrategy ):
150
+ size = draw (size )
151
+ elif size is None :
152
+ size = draw (_default_array_sizes )
153
+ elif not isinstance (size , int ):
154
+ raise TypeError ('Size must be an integer' )
155
+
156
+ shape = (size ,)
157
+
158
+ if pa .types .is_list (type ):
159
+ offsets = draw (npst .arrays (np .uint8 (), shape = shape )).cumsum () // 20
160
+ offsets = np .insert (offsets , 0 , 0 , axis = 0 ) # prepend with zero
161
+ values = draw (arrays (type .value_type , size = int (offsets .sum ())))
162
+ return pa .ListArray .from_arrays (offsets , values )
163
+
164
+ if pa .types .is_struct (type ):
165
+ h .assume (len (type ) > 0 )
166
+ names , child_arrays = [], []
167
+ for field in type :
168
+ names .append (field .name )
169
+ child_arrays .append (draw (arrays (field .type , size = size )))
170
+ # fields' metadata are lost here, because from_arrays doesn't accept
171
+ # a fields argumentum, only names
172
+ return pa .StructArray .from_arrays (child_arrays , names = names )
173
+
174
+ if (pa .types .is_boolean (type ) or pa .types .is_integer (type ) or
175
+ pa .types .is_floating (type )):
176
+ values = npst .arrays (type .to_pandas_dtype (), shape = (size ,))
177
+ return pa .array (draw (values ), type = type )
178
+
179
+ if pa .types .is_null (type ):
180
+ value = st .none ()
181
+ elif pa .types .is_time (type ):
182
+ value = st .times ()
183
+ elif pa .types .is_date (type ):
184
+ value = st .dates ()
185
+ elif pa .types .is_timestamp (type ):
186
+ tz = pytz .timezone (type .tz ) if type .tz is not None else None
187
+ value = st .datetimes (timezones = st .just (tz ))
188
+ elif pa .types .is_binary (type ):
189
+ value = st .binary ()
190
+ elif pa .types .is_string (type ):
191
+ value = st .text ()
192
+ elif pa .types .is_decimal (type ):
193
+ # TODO(kszucs): properly limit the precision
194
+ # value = st.decimals(places=type.scale, allow_infinity=False)
195
+ h .reject ()
196
+ else :
197
+ raise NotImplementedError (type )
198
+
199
+ values = st .lists (value , min_size = size , max_size = size )
200
+ return pa .array (draw (values ), type = type )
201
+
202
+
203
+ @st .composite
204
+ def chunked_arrays (draw , type , min_chunks = 0 , max_chunks = None , chunk_size = None ):
205
+ if isinstance (type , st .SearchStrategy ):
206
+ type = draw (type )
207
+
208
+ # TODO(kszucs): remove it, field metadata is not kept
209
+ h .assume (not pa .types .is_struct (type ))
210
+
211
+ chunk = arrays (type , size = chunk_size )
212
+ chunks = st .lists (chunk , min_size = min_chunks , max_size = max_chunks )
213
+
214
+ return pa .chunked_array (draw (chunks ), type = type )
215
+
216
+
217
+ def columns (type , min_chunks = 0 , max_chunks = None , chunk_size = None ):
218
+ chunked_array = chunked_arrays (type , chunk_size = chunk_size ,
219
+ min_chunks = min_chunks ,
220
+ max_chunks = max_chunks )
221
+ return st .builds (pa .column , st .text (), chunked_array )
222
+
223
+
224
+ @st .composite
225
+ def record_batches (draw , type , rows = None , max_fields = None ):
226
+ if isinstance (rows , st .SearchStrategy ):
227
+ rows = draw (rows )
228
+ elif rows is None :
229
+ rows = draw (_default_array_sizes )
230
+ elif not isinstance (rows , int ):
231
+ raise TypeError ('Rows must be an integer' )
232
+
233
+ schema = draw (schemas (type , max_fields = max_fields ))
234
+ children = [draw (arrays (field .type , size = rows )) for field in schema ]
235
+ # TODO(kszucs): the names and schame arguments are not consistent with
236
+ # Table.from_array's arguments
237
+ return pa .RecordBatch .from_arrays (children , names = schema )
238
+
239
+
240
+ @st .composite
241
+ def tables (draw , type , rows = None , max_fields = None ):
242
+ if isinstance (rows , st .SearchStrategy ):
243
+ rows = draw (rows )
244
+ elif rows is None :
245
+ rows = draw (_default_array_sizes )
246
+ elif not isinstance (rows , int ):
247
+ raise TypeError ('Rows must be an integer' )
248
+
249
+ schema = draw (schemas (type , max_fields = max_fields ))
250
+ children = [draw (arrays (field .type , size = rows )) for field in schema ]
251
+ return pa .Table .from_arrays (children , schema = schema )
252
+
253
+
254
+ all_arrays = arrays (all_types )
255
+ all_chunked_arrays = chunked_arrays (all_types )
256
+ all_columns = columns (all_types )
257
+ all_record_batches = record_batches (all_types )
258
+ all_tables = tables (all_types )
0 commit comments