Skip to content

Commit 04e365c

Browse files
committed
Test inside pytest_runtest_call hook
1 parent 1af2acd commit 04e365c

File tree

1 file changed

+96
-74
lines changed

1 file changed

+96
-74
lines changed

pytest_arraydiff/plugin.py

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,51 @@ def pytest_configure(config):
223223
default_format=default_format))
224224

225225

226+
def generate_test_name(item):
227+
"""
228+
Generate a unique name for this test.
229+
"""
230+
if item.cls is not None:
231+
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
232+
else:
233+
name = f"{item.module.__name__}.{item.name}"
234+
return name
235+
236+
237+
def wrap_array_interceptor(plugin, item):
238+
"""
239+
Intercept and store arrays returned by test functions.
240+
"""
241+
# Only intercept array on marked array tests
242+
if item.get_closest_marker('array_compare') is not None:
243+
244+
# Use the full test name as a key to ensure correct array is being retrieved
245+
test_name = generate_test_name(item)
246+
247+
def array_interceptor(store, obj):
248+
def wrapper(*args, **kwargs):
249+
store.return_value[test_name] = obj(*args, **kwargs)
250+
return wrapper
251+
252+
item.obj = array_interceptor(plugin, item.obj)
253+
254+
226255
class ArrayComparison(object):
227256

228257
def __init__(self, config, reference_dir=None, generate_dir=None, default_format='text'):
229258
self.config = config
230259
self.reference_dir = reference_dir
231260
self.generate_dir = generate_dir
232261
self.default_format = default_format
262+
self.return_value = {}
233263

234-
def pytest_runtest_setup(self, item):
264+
@pytest.hookimpl(hookwrapper=True)
265+
def pytest_runtest_call(self, item):
235266

236267
compare = item.get_closest_marker('array_compare')
237268

238269
if compare is None:
270+
yield
239271
return
240272

241273
file_format = compare.kwargs.get('file_format', self.default_format)
@@ -255,85 +287,75 @@ def pytest_runtest_setup(self, item):
255287

256288
write_kwargs = compare.kwargs.get('write_kwargs', {})
257289

258-
original = item.function
290+
reference_dir = compare.kwargs.get('reference_dir', None)
291+
if reference_dir is None:
292+
if self.reference_dir is None:
293+
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
294+
else:
295+
reference_dir = self.reference_dir
296+
else:
297+
if not reference_dir.startswith(('http://', 'https://')):
298+
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
259299

260-
@wraps(item.function)
261-
def item_function_wrapper(*args, **kwargs):
300+
baseline_remote = reference_dir.startswith('http')
262301

263-
reference_dir = compare.kwargs.get('reference_dir', None)
264-
if reference_dir is None:
265-
if self.reference_dir is None:
266-
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
267-
else:
268-
reference_dir = self.reference_dir
269-
else:
270-
if not reference_dir.startswith(('http://', 'https://')):
271-
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
272-
273-
baseline_remote = reference_dir.startswith('http')
274-
275-
# Run test and get figure object
276-
import inspect
277-
if inspect.ismethod(original): # method
278-
array = original(*args[1:], **kwargs)
279-
else: # function
280-
array = original(*args, **kwargs)
281-
282-
# Find test name to use as plot name
283-
filename = compare.kwargs.get('filename', None)
284-
if filename is None:
285-
if single_reference:
286-
filename = original.__name__ + '.' + extension
287-
else:
288-
filename = item.name + '.' + extension
289-
filename = filename.replace('[', '_').replace(']', '_')
290-
filename = filename.replace('_.' + extension, '.' + extension)
291-
292-
# What we do now depends on whether we are generating the reference
293-
# files or simply running the test.
294-
if self.generate_dir is None:
295-
296-
# Save the figure
297-
result_dir = tempfile.mkdtemp()
298-
test_array = os.path.abspath(os.path.join(result_dir, filename))
299-
300-
FORMATS[file_format].write(test_array, array, **write_kwargs)
301-
302-
# Find path to baseline array
303-
if baseline_remote:
304-
baseline_file_ref = _download_file(reference_dir + filename)
305-
else:
306-
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
307-
308-
if not os.path.exists(baseline_file_ref):
309-
raise Exception("""File not found for comparison test
310-
Generated file:
311-
\t{test}
312-
This is expected for new tests.""".format(
313-
test=test_array))
314-
315-
# setuptools may put the baseline arrays in non-accessible places,
316-
# copy to our tmpdir to be sure to keep them in case of failure
317-
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
318-
shutil.copyfile(baseline_file_ref, baseline_file)
319-
320-
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
321-
322-
if identical:
323-
shutil.rmtree(result_dir)
324-
else:
325-
raise Exception(msg)
302+
# Run test and get array object
303+
wrap_array_interceptor(self, item)
304+
yield
305+
test_name = generate_test_name(item)
306+
if test_name not in self.return_value:
307+
# Test function did not complete successfully
308+
return
309+
array = self.return_value[test_name]
310+
311+
# Find test name to use as plot name
312+
filename = compare.kwargs.get('filename', None)
313+
if filename is None:
314+
filename = item.name + '.' + extension
315+
if not single_reference:
316+
filename = filename.replace('[', '_').replace(']', '_')
317+
filename = filename.replace('_.' + extension, '.' + extension)
318+
319+
# What we do now depends on whether we are generating the reference
320+
# files or simply running the test.
321+
if self.generate_dir is None:
322+
323+
# Save the figure
324+
result_dir = tempfile.mkdtemp()
325+
test_array = os.path.abspath(os.path.join(result_dir, filename))
326+
327+
FORMATS[file_format].write(test_array, array, **write_kwargs)
326328

329+
# Find path to baseline array
330+
if baseline_remote:
331+
baseline_file_ref = _download_file(reference_dir + filename)
327332
else:
333+
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
328334

329-
if not os.path.exists(self.generate_dir):
330-
os.makedirs(self.generate_dir)
335+
if not os.path.exists(baseline_file_ref):
336+
raise Exception("""File not found for comparison test
337+
Generated file:
338+
\t{test}
339+
This is expected for new tests.""".format(
340+
test=test_array))
331341

332-
FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
342+
# setuptools may put the baseline arrays in non-accessible places,
343+
# copy to our tmpdir to be sure to keep them in case of failure
344+
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
345+
shutil.copyfile(baseline_file_ref, baseline_file)
333346

334-
pytest.skip("Skipping test, since generating data")
347+
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
348+
349+
if identical:
350+
shutil.rmtree(result_dir)
351+
else:
352+
raise Exception(msg)
335353

336-
if item.cls is not None:
337-
setattr(item.cls, item.function.__name__, item_function_wrapper)
338354
else:
339-
item.obj = item_function_wrapper
355+
356+
if not os.path.exists(self.generate_dir):
357+
os.makedirs(self.generate_dir)
358+
359+
FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
360+
361+
pytest.skip("Skipping test, since generating data")

0 commit comments

Comments
 (0)