Skip to content

Commit 589659d

Browse files
authored
Merge pull request #1773 from djarecka/test_engine_clean
simplifying tests, removing some try/except blocks
2 parents 486775c + 6f5939f commit 589659d

File tree

1 file changed

+34
-72
lines changed

1 file changed

+34
-72
lines changed

nipype/pipeline/engine/tests/test_engine.py

Lines changed: 34 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _list_outputs(self):
4343

4444

4545
def test_init():
46-
with pytest.raises(Exception): pe.Workflow()
46+
with pytest.raises(TypeError): pe.Workflow()
4747
pipe = pe.Workflow(name='pipe')
4848
assert type(pipe._graph) == nx.DiGraph
4949

@@ -156,12 +156,8 @@ def test_expansion():
156156
pipe5.add_nodes([pipe4])
157157
pipe6 = pe.Workflow(name="pipe6")
158158
pipe6.connect([(pipe5, pipe3, [('pipe4.mod5.output1', 'pipe2.mod3.input1')])])
159-
error_raised = False
160-
try:
161-
pipe6._flatgraph = pipe6._create_flat_graph()
162-
except:
163-
error_raised = True
164-
assert not error_raised
159+
160+
pipe6._flatgraph = pipe6._create_flat_graph()
165161

166162

167163
def test_iterable_expansion():
@@ -330,11 +326,16 @@ def test_doubleconnect():
330326
flow1 = pe.Workflow(name='test')
331327
flow1.connect(a, 'a', b, 'a')
332328
x = lambda: flow1.connect(a, 'b', b, 'a')
333-
with pytest.raises(Exception): x()
329+
with pytest.raises(Exception) as excinfo:
330+
x()
331+
assert "Trying to connect" in str(excinfo.value)
332+
334333
c = pe.Node(IdentityInterface(fields=['a', 'b']), name='c')
335334
flow1 = pe.Workflow(name='test2')
336335
x = lambda: flow1.connect([(a, c, [('b', 'b')]), (b, c, [('a', 'b')])])
337-
with pytest.raises(Exception): x()
336+
with pytest.raises(Exception) as excinfo:
337+
x()
338+
assert "Trying to connect" in str(excinfo.value)
338339

339340

340341
'''
@@ -479,14 +480,10 @@ def func1(in1):
479480
nested=False,
480481
name='n1')
481482
n2.inputs.in1 = [[1, [2]], 3, [4, 5]]
482-
error_raised = False
483-
try:
483+
484+
with pytest.raises(Exception) as excinfo:
484485
n2.run()
485-
except Exception as e:
486-
from nipype.pipeline.engine.base import logger
487-
logger.info('Exception: %s' % str(e))
488-
error_raised = True
489-
assert error_raised
486+
assert "can only concatenate list" in str(excinfo.value)
490487

491488

492489
def test_node_hash(tmpdir):
@@ -518,35 +515,31 @@ def func2(a):
518515
w1.config['execution'] = {'stop_on_first_crash': 'true',
519516
'local_hash_check': 'false',
520517
'crashdump_dir': wd}
521-
error_raised = False
522518
# create dummy distributed plugin class
523519
from nipype.pipeline.plugins.base import DistributedPluginBase
524520

521+
# create a custom exception
522+
class EngineTestException(Exception):
523+
pass
524+
525525
class RaiseError(DistributedPluginBase):
526526
def _submit_job(self, node, updatehash=False):
527-
raise Exception('Submit called')
528-
try:
527+
raise EngineTestException('Submit called')
528+
529+
# check if a proper exception is raised
530+
with pytest.raises(EngineTestException) as excinfo:
529531
w1.run(plugin=RaiseError())
530-
except Exception as e:
531-
from nipype.pipeline.engine.base import logger
532-
logger.info('Exception: %s' % str(e))
533-
error_raised = True
534-
assert error_raised
532+
assert 'Submit called' == str(excinfo.value)
533+
535534
# rerun to ensure we have outputs
536535
w1.run(plugin='Linear')
537536
# set local check
538537
w1.config['execution'] = {'stop_on_first_crash': 'true',
539538
'local_hash_check': 'true',
540539
'crashdump_dir': wd}
541-
error_raised = False
542-
try:
543-
w1.run(plugin=RaiseError())
544-
except Exception as e:
545-
from nipype.pipeline.engine.base import logger
546-
logger.info('Exception: %s' % str(e))
547-
error_raised = True
548-
assert not error_raised
549540

541+
w1.run(plugin=RaiseError())
542+
550543

551544
def test_old_config(tmpdir):
552545
wd = str(tmpdir)
@@ -574,14 +567,8 @@ def func2(a):
574567

575568
w1.config['execution']['crashdump_dir'] = wd
576569
# generate outputs
577-
error_raised = False
578-
try:
579-
w1.run(plugin='Linear')
580-
except Exception as e:
581-
from nipype.pipeline.engine.base import logger
582-
logger.info('Exception: %s' % str(e))
583-
error_raised = True
584-
assert not error_raised
570+
571+
w1.run(plugin='Linear')
585572

586573

587574
def test_mapnode_json(tmpdir):
@@ -618,13 +605,9 @@ def func1(in1):
618605
with open(os.path.join(node.output_dir(), 'test.json'), 'wt') as fp:
619606
fp.write('dummy file')
620607
w1.config['execution'].update(**{'stop_on_first_rerun': True})
621-
error_raised = False
622-
try:
623-
w1.run()
624-
except:
625-
error_raised = True
626-
assert not error_raised
627608

609+
w1.run()
610+
628611

629612
def test_parameterize_dirs_false(tmpdir):
630613
from ....interfaces.utility import IdentityInterface
@@ -643,14 +626,8 @@ def test_parameterize_dirs_false(tmpdir):
643626
wf.config['execution']['parameterize_dirs'] = False
644627
wf.connect([(n1, n2, [('output1', 'in1')])])
645628

646-
error_raised = False
647-
try:
648-
wf.run()
649-
except TypeError as typerr:
650-
from nipype.pipeline.engine.base import logger
651-
logger.info('Exception: %s' % str(typerr))
652-
error_raised = True
653-
assert not error_raised
629+
630+
wf.run()
654631

655632

656633
def test_serial_input(tmpdir):
@@ -680,30 +657,15 @@ def func1(in1):
680657
assert n1.num_subnodes() == len(n1.inputs.in1)
681658

682659
# test running the workflow on default conditions
683-
error_raised = False
684-
try:
685-
w1.run(plugin='MultiProc')
686-
except Exception as e:
687-
from nipype.pipeline.engine.base import logger
688-
logger.info('Exception: %s' % str(e))
689-
error_raised = True
690-
assert not error_raised
660+
w1.run(plugin='MultiProc')
691661

692662
# test output of num_subnodes method when serial is True
693663
n1._serial = True
694664
assert n1.num_subnodes() == 1
695665

696666
# test running the workflow on serial conditions
697-
error_raised = False
698-
try:
699-
w1.run(plugin='MultiProc')
700-
except Exception as e:
701-
from nipype.pipeline.engine.base import logger
702-
logger.info('Exception: %s' % str(e))
703-
error_raised = True
704-
705-
assert not error_raised
706-
667+
w1.run(plugin='MultiProc')
668+
707669

708670
def test_write_graph_runs(tmpdir):
709671
os.chdir(str(tmpdir))

0 commit comments

Comments
 (0)