@@ -223,19 +223,51 @@ def pytest_configure(config):
223
223
default_format = default_format ))
224
224
225
225
226
+ def generate_test_name (item ):
227
+ """
228
+ Generate a unique name for this test.
229
+ """
230
+ if item .cls is not None :
231
+ name = f"{ item .module .__name__ } .{ item .cls .__name__ } .{ item .name } "
232
+ else :
233
+ name = f"{ item .module .__name__ } .{ item .name } "
234
+ return name
235
+
236
+
237
+ def wrap_array_interceptor (plugin , item ):
238
+ """
239
+ Intercept and store arrays returned by test functions.
240
+ """
241
+ # Only intercept array on marked array tests
242
+ if item .get_closest_marker ('array_compare' ) is not None :
243
+
244
+ # Use the full test name as a key to ensure correct array is being retrieved
245
+ test_name = generate_test_name (item )
246
+
247
+ def array_interceptor (store , obj ):
248
+ def wrapper (* args , ** kwargs ):
249
+ store .return_value [test_name ] = obj (* args , ** kwargs )
250
+ return wrapper
251
+
252
+ item .obj = array_interceptor (plugin , item .obj )
253
+
254
+
226
255
class ArrayComparison (object ):
227
256
228
257
def __init__ (self , config , reference_dir = None , generate_dir = None , default_format = 'text' ):
229
258
self .config = config
230
259
self .reference_dir = reference_dir
231
260
self .generate_dir = generate_dir
232
261
self .default_format = default_format
262
+ self .return_value = {}
233
263
234
- def pytest_runtest_setup (self , item ):
264
+ @pytest .hookimpl (hookwrapper = True )
265
+ def pytest_runtest_call (self , item ):
235
266
236
267
compare = item .get_closest_marker ('array_compare' )
237
268
238
269
if compare is None :
270
+ yield
239
271
return
240
272
241
273
file_format = compare .kwargs .get ('file_format' , self .default_format )
@@ -255,85 +287,75 @@ def pytest_runtest_setup(self, item):
255
287
256
288
write_kwargs = compare .kwargs .get ('write_kwargs' , {})
257
289
258
- original = item .function
290
+ reference_dir = compare .kwargs .get ('reference_dir' , None )
291
+ if reference_dir is None :
292
+ if self .reference_dir is None :
293
+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
294
+ else :
295
+ reference_dir = self .reference_dir
296
+ else :
297
+ if not reference_dir .startswith (('http://' , 'https://' )):
298
+ reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
259
299
260
- @wraps (item .function )
261
- def item_function_wrapper (* args , ** kwargs ):
300
+ baseline_remote = reference_dir .startswith ('http' )
262
301
263
- reference_dir = compare .kwargs .get ('reference_dir' , None )
264
- if reference_dir is None :
265
- if self .reference_dir is None :
266
- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), 'reference' )
267
- else :
268
- reference_dir = self .reference_dir
269
- else :
270
- if not reference_dir .startswith (('http://' , 'https://' )):
271
- reference_dir = os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir )
272
-
273
- baseline_remote = reference_dir .startswith ('http' )
274
-
275
- # Run test and get figure object
276
- import inspect
277
- if inspect .ismethod (original ): # method
278
- array = original (* args [1 :], ** kwargs )
279
- else : # function
280
- array = original (* args , ** kwargs )
281
-
282
- # Find test name to use as plot name
283
- filename = compare .kwargs .get ('filename' , None )
284
- if filename is None :
285
- if single_reference :
286
- filename = original .__name__ + '.' + extension
287
- else :
288
- filename = item .name + '.' + extension
289
- filename = filename .replace ('[' , '_' ).replace (']' , '_' )
290
- filename = filename .replace ('_.' + extension , '.' + extension )
291
-
292
- # What we do now depends on whether we are generating the reference
293
- # files or simply running the test.
294
- if self .generate_dir is None :
295
-
296
- # Save the figure
297
- result_dir = tempfile .mkdtemp ()
298
- test_array = os .path .abspath (os .path .join (result_dir , filename ))
299
-
300
- FORMATS [file_format ].write (test_array , array , ** write_kwargs )
301
-
302
- # Find path to baseline array
303
- if baseline_remote :
304
- baseline_file_ref = _download_file (reference_dir + filename )
305
- else :
306
- baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
307
-
308
- if not os .path .exists (baseline_file_ref ):
309
- raise Exception ("""File not found for comparison test
310
- Generated file:
311
- \t {test}
312
- This is expected for new tests.""" .format (
313
- test = test_array ))
314
-
315
- # setuptools may put the baseline arrays in non-accessible places,
316
- # copy to our tmpdir to be sure to keep them in case of failure
317
- baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
318
- shutil .copyfile (baseline_file_ref , baseline_file )
319
-
320
- identical , msg = FORMATS [file_format ].compare (baseline_file , test_array , atol = atol , rtol = rtol )
321
-
322
- if identical :
323
- shutil .rmtree (result_dir )
324
- else :
325
- raise Exception (msg )
302
+ # Run test and get array object
303
+ wrap_array_interceptor (self , item )
304
+ yield
305
+ test_name = generate_test_name (item )
306
+ if test_name not in self .return_value :
307
+ # Test function did not complete successfully
308
+ return
309
+ array = self .return_value [test_name ]
310
+
311
+ # Find test name to use as plot name
312
+ filename = compare .kwargs .get ('filename' , None )
313
+ if filename is None :
314
+ filename = item .name + '.' + extension
315
+ if not single_reference :
316
+ filename = filename .replace ('[' , '_' ).replace (']' , '_' )
317
+ filename = filename .replace ('_.' + extension , '.' + extension )
318
+
319
+ # What we do now depends on whether we are generating the reference
320
+ # files or simply running the test.
321
+ if self .generate_dir is None :
322
+
323
+ # Save the figure
324
+ result_dir = tempfile .mkdtemp ()
325
+ test_array = os .path .abspath (os .path .join (result_dir , filename ))
326
+
327
+ FORMATS [file_format ].write (test_array , array , ** write_kwargs )
326
328
329
+ # Find path to baseline array
330
+ if baseline_remote :
331
+ baseline_file_ref = _download_file (reference_dir + filename )
327
332
else :
333
+ baseline_file_ref = os .path .abspath (os .path .join (os .path .dirname (item .fspath .strpath ), reference_dir , filename ))
328
334
329
- if not os .path .exists (self .generate_dir ):
330
- os .makedirs (self .generate_dir )
335
+ if not os .path .exists (baseline_file_ref ):
336
+ raise Exception ("""File not found for comparison test
337
+ Generated file:
338
+ \t {test}
339
+ This is expected for new tests.""" .format (
340
+ test = test_array ))
331
341
332
- FORMATS [file_format ].write (os .path .abspath (os .path .join (self .generate_dir , filename )), array , ** write_kwargs )
342
+ # setuptools may put the baseline arrays in non-accessible places,
343
+ # copy to our tmpdir to be sure to keep them in case of failure
344
+ baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
345
+ shutil .copyfile (baseline_file_ref , baseline_file )
333
346
334
- pytest .skip ("Skipping test, since generating data" )
347
+ identical , msg = FORMATS [file_format ].compare (baseline_file , test_array , atol = atol , rtol = rtol )
348
+
349
+ if identical :
350
+ shutil .rmtree (result_dir )
351
+ else :
352
+ raise Exception (msg )
335
353
336
- if item .cls is not None :
337
- setattr (item .cls , item .function .__name__ , item_function_wrapper )
338
354
else :
339
- item .obj = item_function_wrapper
355
+
356
+ if not os .path .exists (self .generate_dir ):
357
+ os .makedirs (self .generate_dir )
358
+
359
+ FORMATS [file_format ].write (os .path .abspath (os .path .join (self .generate_dir , filename )), array , ** write_kwargs )
360
+
361
+ pytest .skip ("Skipping test, since generating data" )
0 commit comments