Skip to content

Commit 0259a2b

Browse files
committed
TST: Add a test for SSH function
1 parent 9708216 commit 0259a2b

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

nipype/interfaces/tests/test_io.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from builtins import str, zip, range, open
66
from future import standard_library
77
import os
8+
import copy
89
import simplejson
910
import glob
1011
import shutil
@@ -37,6 +38,12 @@
3738
except ImportError:
3839
noboto3 = True
3940

41+
try:
42+
import paramiko
43+
no_paramiko = False
44+
except ImportError:
45+
no_paramiko = True
46+
4047
# Check for fakes3
4148
standard_library.install_aliases()
4249
from subprocess import check_call, CalledProcessError
@@ -611,3 +618,45 @@ def test_bids_infields_outfields(tmpdir):
611618
bg = nio.BIDSDataGrabber()
612619
for outfield in ['anat', 'func']:
613620
assert outfield in bg._outputs().traits()
621+
622+
623+
@pytest.mark.skipif(no_paramiko, reason="paramiko library is not available")
624+
def test_SSHDataGrabber(tmpdir):
625+
"""Test SSHDataGrabber by connecting to localhost and finding this test
626+
file.
627+
"""
628+
old_cwd = tmpdir.chdir()
629+
630+
# ssh client that connects to localhost, current user, regardless of
631+
# ~/.ssh/config
632+
def _mock_get_ssh_client(self):
633+
proxy = None
634+
client = paramiko.SSHClient()
635+
client.load_system_host_keys()
636+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
637+
client.connect('localhost', username=os.getenv('USER'), sock=proxy)
638+
return client
639+
MockSSHDataGrabber = copy.copy(nio.SSHDataGrabber)
640+
MockSSHDataGrabber._get_ssh_client = _mock_get_ssh_client
641+
642+
this_dir = os.path.dirname(__file__)
643+
this_file = os.path.basename(__file__)
644+
this_test = this_file[:-3] # without .py
645+
646+
ssh_grabber = MockSSHDataGrabber(infields=['test'],
647+
outfields=['test_file'])
648+
# ssh_grabber.base_dir = str(tmpdir)
649+
ssh_grabber.inputs.base_directory = this_dir
650+
ssh_grabber.inputs.hostname = 'localhost'
651+
ssh_grabber.inputs.field_template = dict(test_file='%s.py')
652+
ssh_grabber.inputs.template = ''
653+
ssh_grabber.inputs.template_args = dict(test_file=[['test']])
654+
ssh_grabber.inputs.test = this_test
655+
ssh_grabber.inputs.sort_filelist = True
656+
657+
runtime = ssh_grabber.run()
658+
659+
# did we successfully get this file?
660+
assert runtime.outputs.test_file == str(tmpdir.join(this_file))
661+
662+
old_cwd.chdir()

0 commit comments

Comments
 (0)