diff --git a/Changelog.rst b/Changelog.rst index 8f1b93dc..cdc42951 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -107,6 +107,7 @@ Fixes Changes ------- +* Performance and scaling improvements for all clients. Allow ``ssh-python`` (``libssh``) client to use multiple cores for authentication. * ``user`` keyword argument no longer required on Windows - exception is raised if user cannot be identified. * Removed deprecated since ``2.0.0`` functions and parameters. @@ -117,6 +118,11 @@ Fixes * Reconnecting to the same proxy host when proxy is configured would sometimes cause segfauls - ##304 +Fixes +----- + +* ``ParallelSSHClient.connect_auth`` would not honour client pool size. + 2.5.4 +++++ diff --git a/examples/parallel_commands.py b/examples/parallel_commands.py index 72b96902..c0a48137 100644 --- a/examples/parallel_commands.py +++ b/examples/parallel_commands.py @@ -10,7 +10,7 @@ cmds = ['sleep 5; uname' for _ in range(10)] start = datetime.datetime.now() for cmd in cmds: - output.append(client.run_command(cmd, stop_on_errors=False, return_list=True)) + output.append(client.run_command(cmd, stop_on_errors=False)) end = datetime.datetime.now() print("Started %s 'sleep 5' commands on %s host(s) in %s" % ( len(cmds), len(hosts), end-start,)) diff --git a/examples/quickstart.py b/examples/quickstart.py index a759ae62..412466cb 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -10,4 +10,4 @@ for host_out in output: for line in host_out.stdout: print(line) -print("Host %s: exit code %s" % (host_out.host, host_out.exit_code)) + print("Host %s: exit code %s" % (host_out.host, host_out.exit_code)) diff --git a/pssh/__init__.py b/pssh/__init__.py index bce2c66a..29bebb72 100644 --- a/pssh/__init__.py +++ b/pssh/__init__.py @@ -26,10 +26,10 @@ See also `pssh.clients.ParallelSSHClient` and pssh.clients.SSHClient` for class documentation. """ - - from logging import getLogger, NullHandler + from ._version import get_versions + __version__ = get_versions()['version'] del get_versions diff --git a/pssh/clients/base/parallel.py b/pssh/clients/base/parallel.py index 43f05db1..2ed327dd 100644 --- a/pssh/clients/base/parallel.py +++ b/pssh/clients/base/parallel.py @@ -23,7 +23,7 @@ from gevent import joinall, spawn, Timeout as GTimeout from gevent.hub import Hub -from ..common import _validate_pkey_path +from ..common import _validate_pkey_path, _validate_pkey from ...config import HostConfig from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ...exceptions import HostArgumentError, Timeout, ShellError, HostConfigError @@ -64,7 +64,8 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None, self.user = user self.password = password self.port = port - self.pkey = pkey + self.pkey = _validate_pkey(pkey) + self.__pkey_data = self._load_pkey_data(pkey) if pkey is not None else None self.num_retries = num_retries self.timeout = timeout self._host_clients = {} @@ -113,9 +114,26 @@ def hosts(self, _hosts): self._host_clients.pop((i, host), None) self._hosts = _hosts + def __del__(self): + self.disconnect() + + def disconnect(self): + if not hasattr(self, '_host_clients'): + return + for s_client in self._host_clients.values(): + try: + s_client.disconnect() + except Exception as ex: + logger.debug("Client disconnect failed with %s", ex) + pass + del s_client + def _check_host_config(self): if self.host_config is None: return + if not isinstance(self.host_config, list): + raise HostConfigError("Host configuration of type %s is invalid - valid types are list[HostConfig]", + type(self.host_config)) host_len = len(self.hosts) if host_len != len(self.host_config): raise ValueError( @@ -256,7 +274,7 @@ def get_last_output(self, cmds=None): return self._get_output_from_cmds( cmds, raise_error=False) - def _get_host_config(self, host_i, host): + def _get_host_config(self, host_i): if self.host_config is None: config = HostConfig( user=self.user, port=self.port, password=self.password, private_key=self.pkey, @@ -274,9 +292,6 @@ def _get_host_config(self, host_i, host): gssapi_delegate_credentials=self.gssapi_delegate_credentials, ) return config - elif not isinstance(self.host_config, list): - raise HostConfigError("Host configuration of type %s is invalid - valid types are list[HostConfig]", - type(self.host_config)) config = self.host_config[host_i] return config @@ -284,7 +299,6 @@ def _run_command(self, host_i, host, command, sudo=False, user=None, shell=None, use_pty=False, encoding='utf-8', read_timeout=None): """Make SSHClient if needed, run command on host""" - logger.debug("_run_command with read timeout %s", read_timeout) try: _client = self._get_ssh_client(host_i, host) host_out = _client.run_command( @@ -310,13 +324,13 @@ def connect_auth(self): :returns: list of greenlets to ``joinall`` with. :rtype: list(:py:mod:`gevent.greenlet.Greenlet`) """ - cmds = [spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)] + cmds = [self.pool.spawn(self._get_ssh_client, i, host) for i, host in enumerate(self.hosts)] return cmds def _consume_output(self, stdout, stderr): - for line in stdout: + for _ in stdout: pass - for line in stderr: + for _ in stderr: pass def join(self, output=None, consume_output=False, timeout=None): @@ -543,32 +557,25 @@ def _copy_remote_file(self, host_i, host, remote_file, local_file, recurse, return client.copy_remote_file( remote_file, local_file, recurse=recurse, **kwargs) - def _handle_greenlet_exc(self, func, host, *args, **kwargs): - try: - return func(*args, **kwargs) - except Exception as ex: - raise ex - def _get_ssh_client(self, host_i, host): logger.debug("Make client request for host %s, (host_i, host) in clients: %s", host, (host_i, host) in self._host_clients) _client = self._host_clients.get((host_i, host)) if _client is not None: return _client - cfg = self._get_host_config(host_i, host) - _pkey = self.pkey if cfg.private_key is None else cfg.private_key - _pkey_data = self._load_pkey_data(_pkey) + cfg = self._get_host_config(host_i) + _pkey_data = self.__pkey_data if cfg.private_key is None else self._load_pkey_data(cfg.private_key) _client = self._make_ssh_client(host, cfg, _pkey_data) self._host_clients[(host_i, host)] = _client return _client def _load_pkey_data(self, _pkey): - if isinstance(_pkey, str): - _validate_pkey_path(_pkey) - with open(_pkey, 'rb') as fh: - _pkey_data = fh.read() - return _pkey_data - return _pkey + if not isinstance(_pkey, str): + return _pkey + _pkey = _validate_pkey_path(_pkey) + with open(_pkey, 'rb') as fh: + _pkey_data = fh.read() + return _pkey_data def _make_ssh_client(self, host, cfg, _pkey_data): raise NotImplementedError diff --git a/pssh/clients/base/single.py b/pssh/clients/base/single.py index 4255a14f..ff1d42b3 100644 --- a/pssh/clients/base/single.py +++ b/pssh/clients/base/single.py @@ -22,20 +22,19 @@ from gevent import sleep, socket, Timeout as GTimeout from gevent.hub import Hub +from gevent.lock import RLock from gevent.select import poll, POLLIN, POLLOUT - -from ssh2.utils import find_eol from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \ AgentAuthenticationError, AgentGetIdentityError +from ssh2.utils import find_eol from ..common import _validate_pkey -from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ..reader import ConcurrentRWBuffer +from ...constants import DEFAULT_RETRIES, RETRY_DELAY from ...exceptions import UnknownHostError, AuthenticationError, \ ConnectionError, Timeout, NoIPv6AddressFoundError from ...output import HostOutput, HostOutputBuffers, BufferData - Hub.NOT_ERROR = (Exception,) host_logger = logging.getLogger('pssh.host_logger') logger = logging.getLogger(__name__) @@ -186,6 +185,7 @@ def __init__(self, host, self.identity_auth = identity_auth self._keepalive_greenlet = None self.ipv6_only = ipv6_only + self._sess_lock = RLock() self._init() def _pkey_from_memory(self, pkey_data): @@ -286,7 +286,7 @@ def _connect(self, host, port, retries=1): raise unknown_ex from ex for i, (family, _type, proto, _, sock_addr) in enumerate(addr_info): try: - return self._connect_socket(family, _type, proto, sock_addr, host, port, retries) + return self._connect_socket(family, _type, sock_addr, host, port, retries) except ConnectionRefusedError as ex: if i+1 == len(addr_info): logger.error("No available addresses from %s", [addr[4] for addr in addr_info]) @@ -294,7 +294,7 @@ def _connect(self, host, port, retries=1): raise continue - def _connect_socket(self, family, _type, proto, sock_addr, host, port, retries): + def _connect_socket(self, family, _type, sock_addr, host, port, retries): self.sock = socket.socket(family, _type) if self.timeout: self.sock.settimeout(self.timeout) @@ -427,6 +427,9 @@ def read_stderr(self, stderr_buffer, timeout=None): :param stderr_buffer: Buffer to read from. :type stderr_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` + :param timeout: Timeout in seconds - defaults to no timeout. + :type timeout: int or float + :rtype: generator """ logger.debug("Reading from stderr buffer, timeout=%s", timeout) @@ -438,6 +441,9 @@ def read_output(self, stdout_buffer, timeout=None): :param stdout_buffer: Buffer to read from. :type stdout_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` + :param timeout: Timeout in seconds - defaults to no timeout. + :type timeout: int or float + :rtype: generator """ logger.debug("Reading from stdout buffer, timeout=%s", timeout) @@ -473,7 +479,7 @@ def _read_output_buffer(self, _buffer, timeout=None): finally: timer.close() - def _read_output_to_buffer(self, read_func, _buffer): + def _read_output_to_buffer(self, read_func, _buffer, is_stderr=False): raise NotImplementedError def wait_finished(self, host_output, timeout=None): @@ -495,6 +501,8 @@ def read_output_buffer(self, output_buffer, prefix=None, :type output_buffer: iterator :param prefix: String to prefix log output to ``host_logger`` with :type prefix: str + :param encoding: Output encoding to use for host logger. + :type encoding: str :param callback: Function to call back once buffer is depleted: :type callback: function :param callback_args: Arguments for call back function @@ -569,6 +577,7 @@ def _eagain_errcode(self, func, eagain, *args, **kwargs): while ret == eagain: self.poll() ret = func(*args, **kwargs) + sleep() return ret def _eagain_write(self, write_func, data, timeout=None): diff --git a/pssh/clients/native/parallel.py b/pssh/clients/native/parallel.py index 0b95a51e..d6839f3c 100644 --- a/pssh/clients/native/parallel.py +++ b/pssh/clients/native/parallel.py @@ -127,7 +127,6 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None, identity_auth=identity_auth, ipv6_only=ipv6_only, ) - self.pkey = _validate_pkey(pkey) self.proxy_host = proxy_host self.proxy_port = proxy_port self.proxy_pkey = _validate_pkey(proxy_pkey) @@ -216,17 +215,6 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True, read_timeout=read_timeout, ) - def __del__(self): - if not hasattr(self, '_host_clients'): - return - for s_client in self._host_clients.values(): - try: - s_client.disconnect() - except Exception as ex: - logger.debug("Client disconnect failed with %s", ex) - pass - del s_client - def _make_ssh_client(self, host, cfg, _pkey_data): _client = SSHClient( host, user=cfg.user or self.user, password=cfg.password or self.password, port=cfg.port or self.port, @@ -370,16 +358,12 @@ def copy_remote_file(self, remote_file, local_file, recurse=False, encoding=encoding) def _scp_send(self, host_i, host, local_file, remote_file, recurse=False): - self._get_ssh_client(host_i, host) - return self._handle_greenlet_exc( - self._host_clients[(host_i, host)].scp_send, host, - local_file, remote_file, recurse=recurse) + _client = self._get_ssh_client(host_i, host) + return _client.scp_send(local_file, remote_file, recurse=recurse) def _scp_recv(self, host_i, host, remote_file, local_file, recurse=False): - self._get_ssh_client(host_i, host) - return self._handle_greenlet_exc( - self._host_clients[(host_i, host)].scp_recv, host, - remote_file, local_file, recurse=recurse) + _client = self._get_ssh_client(host_i, host) + return _client.scp_recv(remote_file, local_file, recurse=recurse) def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): """Copy local file to remote file in parallel via SCP. @@ -404,6 +388,11 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): :type local_file: str :param remote_file: Remote filepath on remote host to copy file to :type remote_file: str + :param copy_args: (Optional) format local_file and remote_file strings + with per-host arguments in ``copy_args``. ``copy_args`` length must + equal length of host list - + :py:class:`pssh.exceptions.HostArgumentError` is raised otherwise + :type copy_args: tuple or list :param recurse: Whether or not to descend into directories recursively. :type recurse: bool @@ -415,7 +404,7 @@ def scp_send(self, local_file, remote_file, recurse=False, copy_args=None): """ copy_args = [{'local_file': local_file, 'remote_file': remote_file} - for i, host in enumerate(self.hosts)] \ + for _ in self.hosts] \ if copy_args is None else copy_args local_file = "%(local_file)s" remote_file = "%(remote_file)s" diff --git a/pssh/clients/native/single.py b/pssh/clients/native/single.py index 2dc4e7cf..1c150841 100644 --- a/pssh/clients/native/single.py +++ b/pssh/clients/native/single.py @@ -131,7 +131,8 @@ def __init__(self, host, identity_auth=identity_auth, ) proxy_host = '127.0.0.1' - self._chan_lock = RLock() + self._chan_stdout_lock = RLock() + self._chan_stderr_lock = RLock() super(SSHClient, self).__init__( host, user=user, password=password, port=port, pkey=pkey, num_retries=num_retries, retry_delay=retry_delay, @@ -219,7 +220,8 @@ def _init_session(self, retries=1): self.session.set_timeout(self.timeout * 1000) try: if self._auth_thread_pool: - THREAD_POOL.apply(self.session.handshake, (self.sock,)) + with self._sess_lock: + THREAD_POOL.apply(self.session.handshake, (self.sock,)) else: self.session.handshake(self.sock) except Exception as ex: @@ -228,6 +230,8 @@ def _init_session(self, retries=1): return self._connect_init_session_retry(retries=retries+1) msg = "Error connecting to host %s:%s - %s" logger.error(msg, self.host, self.port, ex) + if not self.sock.closed: + self.sock.close() if isinstance(ex, SSH2Timeout): raise Timeout(msg, self.host, self.port, ex) raise @@ -238,23 +242,30 @@ def _keepalive(self): self._keepalive_greenlet = self.spawn_send_keepalive() def _agent_auth(self): - self.session.agent_auth(self.user) + with self._sess_lock: + THREAD_POOL.apply(self.session.agent_auth, args=(self.user,)) def _pkey_file_auth(self, pkey_file, password=None): - self.session.userauth_publickey_fromfile( - self.user, - pkey_file, - passphrase=password if password is not None else b'') + passphrase = password if password is not None else b'' + with self._sess_lock: + THREAD_POOL.apply( + self.session.userauth_publickey_fromfile, + args=(self.user, pkey_file), + kwds={'passphrase': passphrase}, + ) def _pkey_from_memory(self, pkey_data): - self.session.userauth_publickey_frommemory( - self.user, - pkey_data, - passphrase=self.password if self.password is not None else b'', - ) + passphrase = self.password if self.password is not None else b'' + with self._sess_lock: + THREAD_POOL.apply( + self.session.userauth_publickey_frommemory, + args=(self.user, pkey_data), + kwds={'passphrase': passphrase}, + ) def _password_auth(self): - self.session.userauth_password(self.user, self.password) + with self._sess_lock: + THREAD_POOL.apply(self.session.userauth_password, args=(self.user, self.password)) def _open_session(self): chan = self._eagain(self.session.open_session) @@ -300,18 +311,19 @@ def execute(self, cmd, use_pty=False, channel=None): self._eagain(channel.execute, cmd) return channel - def _read_output_to_buffer(self, read_func, _buffer): + def _read_output_to_buffer(self, read_func, _buffer, is_stderr=False): + _lock = self._chan_stderr_lock if is_stderr else self._chan_stdout_lock try: while True: - with self._chan_lock: + with _lock: size, data = read_func() - while size == LIBSSH2_ERROR_EAGAIN: + if size == LIBSSH2_ERROR_EAGAIN: self.poll() - with self._chan_lock: - size, data = read_func() + continue if size <= 0: break _buffer.write(data) + sleep() finally: _buffer.eof.set() @@ -335,11 +347,10 @@ def wait_finished(self, host_output, timeout=None): if channel is None: return self._eagain(channel.wait_eof, timeout=timeout) - # Close channel to indicate no more commands will be sent over it self.close_channel(channel) def close_channel(self, channel): - with self._chan_lock: + with self._chan_stdout_lock, self._chan_stderr_lock: logger.debug("Closing channel") self._eagain(channel.close) @@ -435,12 +446,15 @@ def sftp_put(self, sftp, local_file, remote_file): f_flags = LIBSSH2_FXF_CREAT | LIBSSH2_FXF_WRITE | LIBSSH2_FXF_TRUNC with self._sftp_openfh( sftp.open, remote_file, f_flags, mode) as remote_fh: + self._sess_lock.acquire() try: - self._sftp_put(remote_fh, local_file) + THREAD_POOL.apply(self._sftp_put, args=(remote_fh, local_file)) except SFTPProtocolError as ex: msg = "Error writing to remote file %s - %s" logger.error(msg, remote_file, ex) raise SFTPIOError(msg, remote_file, ex) + finally: + self._sess_lock.release() def mkdir(self, sftp, directory): """Make directory via SFTP channel. @@ -557,6 +571,9 @@ def scp_recv(self, remote_file, local_file, recurse=False, sftp=None, :type local_file: str :param recurse: Whether or not to recursively copy directories :type recurse: bool + :param sftp: The SFTP channel to use instead of creating a new one. + Only used when ``recurse`` is ``True``. + :type sftp: :py:class:`ssh2.sftp.SFTP` :param encoding: Encoding to use for file paths when recursion is enabled. :type encoding: str @@ -616,6 +633,9 @@ def scp_send(self, local_file, remote_file, recurse=False, sftp=None): :type remote_file: str :param recurse: Whether or not to descend into directories recursively. :type recurse: bool + :param sftp: The SFTP channel to use instead of creating a new one. + Only used when ``recurse`` is ``True``. + :type sftp: :py:class:`ssh2.sftp.SFTP` :raises: :py:class:`ValueError` when a directory is supplied to ``local_file`` and ``recurse`` is not set @@ -693,12 +713,15 @@ def sftp_get(self, sftp, remote_file, local_file): with self._sftp_openfh( sftp.open, remote_file, LIBSSH2_FXF_READ, LIBSSH2_SFTP_S_IRUSR) as remote_fh: + self._sess_lock.acquire() try: - self._sftp_get(remote_fh, local_file) + THREAD_POOL.apply(self._sftp_get, args=(remote_fh, local_file)) except SFTPProtocolError as ex: msg = "Error reading from remote file %s - %s" logger.error(msg, remote_file, ex) raise SFTPIOError(msg, remote_file, ex) + finally: + self._sess_lock.release() def get_exit_status(self, channel): """Get exit status code for channel or ``None`` if not ready. diff --git a/pssh/clients/native/tunnel.py b/pssh/clients/native/tunnel.py index 5748a3c2..ed937d8a 100644 --- a/pssh/clients/native/tunnel.py +++ b/pssh/clients/native/tunnel.py @@ -16,12 +16,8 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA import logging - +from queue import Queue from threading import Thread, Event -try: - from queue import Queue -except ImportError: - from Queue import Queue from gevent import spawn, joinall, get_hub, sleep from gevent.server import StreamServer @@ -29,7 +25,6 @@ from ...constants import DEFAULT_RETRIES - logger = logging.getLogger(__name__) @@ -193,7 +188,7 @@ def _read_forward_sock(self, forward_sock, channel): sleep(.01) continue try: - self._client._eagain_write(channel.write, data) + self._client.eagain_write(channel.write, data) except Exception as ex: logger.error("Error writing data to channel - %s", ex) raise diff --git a/pssh/clients/reader.py b/pssh/clients/reader.py index 2fb19094..c6b69b2a 100644 --- a/pssh/clients/reader.py +++ b/pssh/clients/reader.py @@ -17,31 +17,41 @@ from io import BytesIO -from gevent import sleep from gevent.event import Event from gevent.lock import RLock +class _Eof(Event): + def __init__(self, unread_data): + self._unread_data = unread_data + Event.__init__(self) + + def set(self): + self._unread_data.set() + Event.set(self) + + class ConcurrentRWBuffer(object): """Concurrent reader/writer of bytes for use from multiple greenlets. Supports both concurrent reading and writing. - Iterate on buffer object to read data, yielding greenlet if no data exists + Iterate on buffer object to read data, yielding event loop if no data exists until self.eof has been set. - Writers should ``eof.set()`` when finished writing data via ``write``. + Writers should call ``ConcurrentRWBuffer.eof.set()`` when finished writing data via ``write``. Readers can use ``read()`` to get any available data or ``None``. """ - __slots__ = ('_buffer', '_read_pos', '_write_pos', 'eof', '_lock') + __slots__ = ('_buffer', '_read_pos', '_write_pos', 'eof', '_lock', '_unread_data') def __init__(self): self._buffer = BytesIO() self._read_pos = 0 self._write_pos = 0 - self.eof = Event() self._lock = RLock() + self._unread_data = Event() + self.eof = _Eof(self._unread_data) def write(self, data): """Write data to buffer. @@ -53,14 +63,17 @@ def write(self, data): if not self._buffer.tell() == self._write_pos: self._buffer.seek(self._write_pos) self._write_pos += self._buffer.write(data) + if not self._unread_data.is_set() and self._read_pos < self._write_pos: + self._unread_data.set() def read(self): - """Read available data, or return None + """Read available data, or return None. :rtype: bytes """ with self._lock: if self._write_pos == 0 or self._read_pos == self._write_pos: + self._unread_data.clear() return elif not self._buffer.tell() == self._read_pos: self._buffer.seek(self._read_pos) @@ -73,5 +86,5 @@ def __iter__(self): data = self.read() if data: yield data - elif self._read_pos == self._write_pos: - sleep(.1) + else: + self._unread_data.wait() diff --git a/pssh/clients/ssh/parallel.py b/pssh/clients/ssh/parallel.py index bd7a11a8..a9193a4c 100644 --- a/pssh/clients/ssh/parallel.py +++ b/pssh/clients/ssh/parallel.py @@ -18,11 +18,10 @@ import logging from .single import SSHClient -from ..common import _validate_pkey_path, _validate_pkey from ..base.parallel import BaseParallelSSHClient +from ..common import _validate_pkey_path from ...constants import DEFAULT_RETRIES, RETRY_DELAY - logger = logging.getLogger(__name__) @@ -126,7 +125,6 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None, identity_auth=identity_auth, ipv6_only=ipv6_only, ) - self.pkey = _validate_pkey(pkey) self.cert_file = _validate_pkey_path(cert_file) self.forward_ssh_agent = forward_ssh_agent self.gssapi_auth = gssapi_auth diff --git a/pssh/clients/ssh/single.py b/pssh/clients/ssh/single.py index 855ada98..a3c5b287 100644 --- a/pssh/clients/ssh/single.py +++ b/pssh/clients/ssh/single.py @@ -17,22 +17,22 @@ import logging -from gevent import sleep, spawn, Timeout as GTimeout, joinall +from gevent import sleep, spawn, Timeout as GTimeout, joinall, get_hub from ssh import options -from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING -from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey,\ - import_privkey_base64 -from ssh.exceptions import EOF from ssh.error_codes import SSH_AGAIN +from ssh.exceptions import EOF +from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey, \ + import_privkey_base64 +from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING from ..base.single import BaseSSHClient from ..common import _validate_pkey_path -from ...output import HostOutput -from ...exceptions import SessionError, Timeout from ...constants import DEFAULT_RETRIES, RETRY_DELAY - +from ...exceptions import SessionError, Timeout +from ...output import HostOutput logger = logging.getLogger(__name__) +THREAD_POOL = get_hub().threadpool class SSHClient(BaseSSHClient): @@ -129,7 +129,8 @@ def disconnect(self): self.sock.close() def _agent_auth(self): - self.session.userauth_agent(self.user) + with self._sess_lock: + THREAD_POOL.apply(self.session.userauth_agent, args=(self.user,)) def _keepalive(self): pass @@ -159,54 +160,67 @@ def _init_session(self, retries=1): return self._connect_init_session_retry(retries=retries+1) msg = "Error connecting to host %s:%s - %s" logger.error(msg, self.host, self.port, ex) + if not self.sock.closed: + self.sock.close() raise ex def _session_connect(self): - self.session.connect() + with self._sess_lock: + THREAD_POOL.apply(self.session.connect) def auth(self): if self.gssapi_auth or (self.gssapi_server_identity or self.gssapi_client_identity): try: - return self.session.userauth_gssapi() + with self._sess_lock: + return THREAD_POOL.apply(self.session.userauth_gssapi) except Exception as ex: logger.error( "GSSAPI authentication with server id %s and client id %s failed - %s", self.gssapi_server_identity, self.gssapi_client_identity, ex) + raise return super(SSHClient, self).auth() def _password_auth(self): - self.session.userauth_password(self.user, self.password) + with self._sess_lock: + THREAD_POOL.apply(self.session.userauth_password, args=(self.user, self.password)) def _pkey_file_auth(self, pkey_file, password=None): - pkey = import_privkey_file(pkey_file, passphrase=password if password is not None else '') + passphrase = password if password is not None else '' + pkey = THREAD_POOL.apply( + import_privkey_file, args=(pkey_file,), kwds={'passphrase': passphrase}) return self._pkey_obj_auth(pkey) def _pkey_obj_auth(self, pkey): if self.cert_file is not None: logger.debug("Certificate file set - trying certificate authentication") - self._import_cert_file(pkey) - self.session.userauth_publickey(pkey) + THREAD_POOL.apply(self._import_cert_file, args=(pkey,)) + with self._sess_lock: + THREAD_POOL.apply(self.session.userauth_publickey, args=(pkey,)) def _pkey_from_memory(self, pkey_data): - _pkey = import_privkey_base64( - pkey_data, - passphrase=self.password if self.password is not None else b'') + passphrase = self.password if self.password is not None else b'' + _pkey = THREAD_POOL.apply( + import_privkey_base64, + args=(pkey_data,), + kwds={'passphrase': passphrase}, + ) return self._pkey_obj_auth(_pkey) def _import_cert_file(self, pkey): cert_key = import_cert_file(self.cert_file) - self.session.userauth_try_publickey(cert_key) + with self._sess_lock: + self.session.userauth_try_publickey(cert_key) copy_cert_to_privkey(cert_key, pkey) logger.debug("Imported certificate file %s for pkey %s", self.cert_file, self.pkey) - def _shell(self, channel): - return self._eagain(channel.request_shell) + def _shell(self, chan): + return self._eagain(chan.request_shell) def _open_session(self): - channel = self.session.channel_new() - channel.set_blocking(0) - self._eagain(channel.open_session) - return channel + chan = self.session.channel_new() + chan.set_blocking(0) + self._eagain(chan.open_session) + return chan def open_session(self): """Open new channel from session.""" @@ -238,25 +252,22 @@ def execute(self, cmd, use_pty=False, channel=None): if use_pty: self._eagain(channel.request_pty, timeout=self.timeout) logger.debug("Executing command '%s'", cmd) - self._eagain(channel.request_exec, cmd, timeout=self.timeout) + self._eagain(channel.request_exec, cmd) return channel def _read_output_to_buffer(self, channel, _buffer, is_stderr=False): - while True: - self.poll() - try: - size, data = channel.read_nonblocking(is_stderr=is_stderr) - except EOF: - _buffer.eof.set() - sleep(.1) - return - if size > 0: - _buffer.write(data) - else: - # Yield event loop to other greenlets if we have no data to - # send back, meaning the generator does not yield and can there - # for block other generators/greenlets from running. - sleep(.1) + try: + while True: + self.poll() + try: + size, data = channel.read_nonblocking(is_stderr=is_stderr) + except EOF: + return + if size > 0: + _buffer.write(data) + sleep() + finally: + _buffer.eof.set() def wait_finished(self, host_output, timeout=None): """Wait for EOF from channel and close channel. @@ -313,7 +324,7 @@ def close_channel(self, channel): :type channel: :py:class:`ssh.channel.Channel` """ logger.debug("Closing channel") - self._eagain(channel.close, timeout=self.timeout) + self._eagain(channel.close) def poll(self, timeout=None): """ssh-python based co-operative gevent poll on session socket. diff --git a/pssh/output.py b/pssh/output.py index c7e9375e..0baa7f76 100644 --- a/pssh/output.py +++ b/pssh/output.py @@ -44,7 +44,7 @@ def __init__(self, reader, rw_buffer): """ :param reader: Greenlet reading data from channel and writing to rw_buffer :type reader: :py:class:`gevent.Greenlet` - :param rw_bufffer: Read/write buffer + :param rw_buffer: Read/write buffer :type rw_buffer: :py:class:`pssh.clients.reader.ConcurrentRWBuffer` """ self.reader = reader diff --git a/pssh/utils.py b/pssh/utils.py index 3ee6ef2b..230f47a1 100644 --- a/pssh/utils.py +++ b/pssh/utils.py @@ -46,5 +46,5 @@ def enable_host_logger(): def enable_debug_logger(): - """Enable debug logging for the library to sdout.""" + """Enable debug logging for the library to stdout.""" return enable_logger(logger, level=logging.DEBUG) diff --git a/tests/native/test_parallel_client.py b/tests/native/test_parallel_client.py index 5b760985..38bac397 100644 --- a/tests/native/test_parallel_client.py +++ b/tests/native/test_parallel_client.py @@ -36,7 +36,7 @@ AuthenticationException, ConnectionErrorException, \ HostArgumentException, SFTPError, SFTPIOError, Timeout, SCPError, \ PKeyFileError, ShellError, HostArgumentError, NoIPv6AddressFoundError, \ - AuthenticationError + AuthenticationError, HostConfigError from pssh.output import HostOutput from .base_ssh2_case import PKEY_FILENAME, PUB_FILE @@ -44,6 +44,8 @@ class ParallelSSHClientTest(unittest.TestCase): + server = None + client = None @classmethod def setUpClass(cls): @@ -146,12 +148,7 @@ def test_client_shells_join_timeout(self): """ self.client.run_shell_commands(shells, cmd) self.assertRaises(Timeout, self.client.join_shells, shells, timeout=.1) - try: - self.client.join_shells(shells, timeout=.1) - except Timeout: - pass - else: - raise AssertionError + self.assertRaises(Timeout, self.client.join_shells, shells, timeout=.1) self.client.join_shells(shells, timeout=1) stdout = list(shells[0].stdout) self.assertListEqual(stdout, [self.resp, self.resp]) @@ -963,6 +960,11 @@ def test_host_config_bad_entries(self): self.assertRaises(ValueError, ParallelSSHClient, hosts, host_config=host_config) self.assertRaises(ValueError, ParallelSSHClient, iter(hosts), host_config=host_config) + def test_invalid_host_config(self): + hosts = ['localhost', 'localhost'] + host_config = {'localhost': HostConfig(), 'localhost2': HostConfig()} + self.assertRaises(HostConfigError, ParallelSSHClient, hosts, host_config=host_config) + def test_pssh_client_override_allow_agent_authentication(self): """Test running command with allow_agent set to False""" client = ParallelSSHClient([self.host], @@ -1038,9 +1040,10 @@ def test_per_host_dict_args(self): server.start_server() hosts = [self.host, host2, host3] hosts_gen = (h for h in hosts) - host_args = [dict(zip(('host_arg1', 'host_arg2',), - ('arg1-%s' % (i,), 'arg2-%s' % (i,),))) - for i, _ in enumerate(hosts)] + host_args = [{'host_arg1': 'arg1-%s' % (i,), + 'host_arg2': 'arg2-%s' % (i,), + } + for i in range(len(hosts))] cmd = 'echo %(host_arg1)s %(host_arg2)s' client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key, @@ -1116,7 +1119,7 @@ def test_pty(self): expected_stdout = [] # With a PTY, stdout and stderr are combined into stdout self.assertEqual(expected_stderr, stdout) - self.assertEqual([], stderr) + self.assertEqual(expected_stdout, stderr) self.assertTrue(exit_code == 0) def test_output_attributes(self): @@ -1167,7 +1170,7 @@ def test_retries(self): client = ParallelSSHClient(['127.0.0.100'], port=self.port, num_retries=2, retry_delay=.1) self.assertRaises(ConnectionErrorException, client.run_command, self.cmd) - host = ''.join([random.choice(string.ascii_letters) for n in range(8)]) + host = ''.join([random.choice(string.ascii_letters) for _ in range(8)]) client.hosts = [host] self.assertRaises(UnknownHostException, client.run_command, self.cmd) @@ -1234,7 +1237,7 @@ def test_setting_hosts(self): def test_unknown_host_failure(self): """Test connection error failure case - ConnectionErrorException""" - host = ''.join([random.choice(string.ascii_letters) for n in range(8)]) + host = ''.join([random.choice(string.ascii_letters) for _ in range(8)]) client = ParallelSSHClient([host], port=self.port, num_retries=1) self.assertRaises(UnknownHostException, client.run_command, self.cmd) @@ -1266,13 +1269,14 @@ def test_join_timeout_subset_read(self): hosts = [self.host, self.host] cmd = 'sleep %(i)s; echo %(i)s' host_args = [{'i': '0.1'}, - {'i': '0.25'}, + {'i': '0.35'}, ] client = ParallelSSHClient(hosts, port=self.port, pkey=self.user_key) + joinall(client.connect_auth(), raise_error=True) output = client.run_command(cmd, host_args=host_args) try: - client.join(output, timeout=.2) + client.join(output, timeout=.3) except Timeout as ex: finished_output = ex.args[2] unfinished_output = ex.args[3] @@ -1285,7 +1289,7 @@ def test_join_timeout_subset_read(self): # Should not timeout client.join(unfinished_output, timeout=2) rest_stdout = list(unfinished_output[0].stdout) - self.assertEqual(rest_stdout, ['0.25']) + self.assertEqual(rest_stdout, ['0.35']) def test_join_timeout_set_no_timeout(self): client = ParallelSSHClient([self.host], port=self.port, @@ -1303,7 +1307,7 @@ def test_read_timeout(self): self.assertFalse(client.finished(output)) client.join(output) # import ipdb; ipdb.set_trace() - for host_out in output: + for _ in output: stdout = list(output[0].stdout) self.assertEqual(len(stdout), 3) self.assertTrue(client.finished(output)) @@ -1320,6 +1324,7 @@ def test_partial_read_timeout_close_cmd(self): except Timeout: pass self.assertTrue(len(stdout) > 0) + sleep(.15) output[0].client.close_channel(output[0].channel) self.client.join(output) # Should not timeout @@ -1389,7 +1394,7 @@ def test_timeout_file_read(self): self.assertRaises(Timeout, self.client.join, output, timeout=.1) for host_out in output: try: - for line in host_out.stdout: + for _ in host_out.stdout: pass except Timeout: pass @@ -1425,7 +1430,6 @@ def test_scp_send_dir(self): with open(local_filename, 'w') as file_h: file_h.writelines([test_file_data + os.linesep]) remote_filename = os.path.sep.join([remote_test_dir, remote_filepath]) - remote_file_abspath = os.path.expanduser('~/' + remote_filename) remote_test_dir_abspath = os.path.expanduser('~/' + remote_test_dir) try: cmds = self.client.scp_send(local_filename, remote_filename) @@ -1540,21 +1544,20 @@ def test_scp_bad_copy_args(self): def test_scp_send_exc(self): client = ParallelSSHClient([self.host], pkey=self.user_key, num_retries=1) - def _scp_send(*args): + + def _scp_send(*_): raise Exception - def _client_send(*args): - return client._handle_greenlet_exc(_scp_send, 'fake') - client._scp_send = _client_send + client._scp_send = _scp_send cmds = client.scp_send('local_file', 'remote_file') self.assertRaises(Exception, joinall, cmds, raise_error=True) def test_scp_recv_exc(self): client = ParallelSSHClient([self.host], pkey=self.user_key, num_retries=1) - def _scp_recv(*args): + + def _scp_recv(*_): raise Exception - def _client_recv(*args): - return client._handle_greenlet_exc(_scp_recv, 'fake') - client._scp_recv = _client_recv + + client._scp_recv = _scp_recv cmds = client.scp_recv('remote_file', 'local_file') self.assertRaises(Exception, joinall, cmds, raise_error=True) @@ -1758,6 +1761,7 @@ def test_client_disconnect(self): client.join(consume_output=True) single_client = list(client._host_clients.values())[0] del client + # client.disconnect() self.assertEqual(single_client.session, None) def test_client_disconnect_error(self): @@ -1804,7 +1808,7 @@ def read_stream_dt(self, host_out, stream, read_timeout): now = datetime.now() timed_out = False try: - for line in stream: + for _ in stream: pass except Timeout: timed_out = True @@ -1875,7 +1879,6 @@ def test_read_stdout_timeout_stderr_no_timeout(self): self.assertTrue(dt.total_seconds() < read_timeout) def test_read_multi_same_hosts(self): - hosts = [self.host, self.host] outputs = [ self.client.run_command(self.cmd), self.client.run_command(self.cmd), diff --git a/tests/native/test_single_client.py b/tests/native/test_single_client.py index 61df944e..0d6228bb 100644 --- a/tests/native/test_single_client.py +++ b/tests/native/test_single_client.py @@ -22,7 +22,7 @@ from datetime import datetime from hashlib import sha256 from tempfile import NamedTemporaryFile -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, call, patch, call import pytest from gevent import sleep, spawn, Timeout as GTimeout, socket @@ -101,6 +101,9 @@ def test_ipv6(self, gsocket): host = '::1' addr_info = ('::1', self.port, 0, 0) gsocket.IPPROTO_TCP = socket.IPPROTO_TCP + gsocket.AF_INET6 = socket.AF_INET6 + gsocket.AF_INET = socket.AF_INET + gsocket.SocketKind = socket.SocketKind gsocket.socket = MagicMock() _sock = MagicMock() gsocket.socket.return_value = _sock @@ -112,16 +115,20 @@ def test_ipv6(self, gsocket): getaddrinfo.return_value = [( socket.AF_INET6, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info)] with raises(ConnectionError): - client = SSHClient(host, port=self.port, pkey=self.user_key, - num_retries=1) + SSHClient(host, port=self.port, pkey=self.user_key, + num_retries=1) getaddrinfo.assert_called_once_with(host, self.port, proto=socket.IPPROTO_TCP) sock_con.assert_called_once_with(addr_info) @patch('pssh.clients.base.single.socket') def test_multiple_available_addr(self, gsocket): - host = '127.0.0.1' - addr_info = (host, self.port) + host = 'localhost' + ipv6_addr_info = ('::1', self.port, 0, 0) + ipv4_addr_info = ('127.0.0.1', self.port) gsocket.IPPROTO_TCP = socket.IPPROTO_TCP + gsocket.AF_INET6 = socket.AF_INET6 + gsocket.AF_INET = socket.AF_INET + gsocket.SocketKind = socket.SocketKind gsocket.socket = MagicMock() _sock = MagicMock() gsocket.socket.return_value = _sock @@ -131,14 +138,44 @@ def test_multiple_available_addr(self, gsocket): getaddrinfo = MagicMock() gsocket.getaddrinfo = getaddrinfo getaddrinfo.return_value = [ - (socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info), - (socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', addr_info), + (socket.AF_INET6, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', ipv6_addr_info), + (socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', ipv4_addr_info), ] with raises(ConnectionError): - client = SSHClient(host, port=self.port, pkey=self.user_key, - num_retries=1) + SSHClient(host, port=self.port, pkey=self.user_key, + num_retries=1) + expected_calls = [call(ipv6_addr_info), call(ipv4_addr_info)] getaddrinfo.assert_called_with(host, self.port, proto=socket.IPPROTO_TCP) assert sock_con.call_count == len(getaddrinfo.return_value) + assert sock_con.call_args_list == expected_calls + + @patch('pssh.clients.base.single.socket') + def test_multiple_available_addr_ipv6(self, gsocket): + host = 'localhost' + ipv6_addr_info = ('::1', self.port, 0, 0) + ipv4_addr_info = ('127.0.0.1', self.port) + gsocket.IPPROTO_TCP = socket.IPPROTO_TCP + gsocket.AF_INET6 = socket.AF_INET6 + gsocket.AF_INET = socket.AF_INET + gsocket.SocketKind = socket.SocketKind + gsocket.socket = MagicMock() + _sock = MagicMock() + gsocket.socket.return_value = _sock + sock_con = MagicMock() + sock_con.side_effect = ConnectionRefusedError + _sock.connect = sock_con + getaddrinfo = MagicMock() + gsocket.getaddrinfo = getaddrinfo + getaddrinfo.return_value = [ + (socket.AF_INET6, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', ipv6_addr_info), + (socket.AF_INET, socket.SocketKind.SOCK_STREAM, socket.IPPROTO_TCP, '', ipv4_addr_info), + ] + with raises(ConnectionError): + SSHClient(host, port=self.port, pkey=self.user_key, + num_retries=1, + ipv6_only=True) + getaddrinfo.assert_called_once_with(host, self.port, proto=socket.IPPROTO_TCP) + sock_con.assert_called_once_with(ipv6_addr_info) def test_no_ipv6(self): try: diff --git a/tests/native/test_tunnel.py b/tests/native/test_tunnel.py index 8137ddd9..da4c7d4a 100644 --- a/tests/native/test_tunnel.py +++ b/tests/native/test_tunnel.py @@ -15,14 +15,13 @@ # License along with this library; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA -import gc import os -import time import unittest from datetime import datetime from getpass import getuser -from sys import version_info +import gc +import time from gevent import sleep, spawn, Timeout as GTimeout from ssh2.exceptions import SocketSendError, SocketRecvError @@ -39,7 +38,7 @@ class TunnelTest(unittest.TestCase): @classmethod def setUpClass(cls): - _mask = int('0600') if version_info <= (2,) else 0o600 + _mask = 0o600 os.chmod(PKEY_FILENAME, _mask) cls.port = 2225 cls.cmd = 'echo me' diff --git a/tests/ssh/test_parallel_client.py b/tests/ssh/test_parallel_client.py index 401fd791..b922dee3 100644 --- a/tests/ssh/test_parallel_client.py +++ b/tests/ssh/test_parallel_client.py @@ -33,6 +33,8 @@ class LibSSHParallelTest(unittest.TestCase): + server = None + client = None @classmethod def setUpClass(cls): @@ -237,12 +239,7 @@ def test_pssh_client_hosts_list_part_failure(self): self.assertTrue(output[1].exception is not None) self.assertEqual(output[1].host, hosts[1]) self.assertEqual(output[1].exception.args[-2], hosts[1]) - try: - raise output[1].exception - except ConnectionErrorException: - pass - else: - raise Exception("Expected ConnectionError, got %s instead" % (output[1].exception,)) + self.assertIsInstance(output[1].exception, ConnectionErrorException) def test_pssh_client_timeout(self): # 1ms timeout @@ -316,13 +313,12 @@ def test_connection_error_exception(self): client.join(output) self.assertIsInstance(output[0].exception, ConnectionErrorException) self.assertEqual(output[0].host, host) + self.assertIsInstance(output[0].exception, ConnectionErrorException) try: raise output[0].exception except ConnectionErrorException as ex: self.assertEqual(ex.args[-2], host) self.assertEqual(ex.args[-1], port) - else: - raise Exception("Expected ConnectionErrorException") def test_bad_pkey_path(self): self.assertRaises(PKeyFileError, ParallelSSHClient, [self.host], port=self.port, @@ -334,12 +330,9 @@ def test_multiple_single_quotes_in_cmd(self): output = self.client.run_command("echo 'me' 'and me'") stdout = list(output[0].stdout) expected = 'me and me' - self.assertTrue(len(stdout)==1, - msg="Got incorrect number of lines in output - %s" % (stdout,)) + self.assertTrue(len(stdout) == 1) self.assertEqual(output[0].exit_code, 0) - self.assertEqual(expected, stdout[0], - msg="Got unexpected output. Expected %s, got %s" % ( - expected, stdout[0],)) + self.assertEqual(expected, stdout[0]) def test_backtics_in_cmd(self): """Test running command with backtics in it""" @@ -353,9 +346,7 @@ def test_multiple_shell_commands(self): stdout = list(output[0].stdout) expected = ["me", "and", "me"] self.assertEqual(output[0].exit_code, 0) - self.assertEqual(expected, stdout, - msg="Got unexpected output. Expected %s, got %s" % ( - expected, stdout,)) + self.assertEqual(expected, stdout) def test_escaped_quotes(self): """Test escaped quotes in shell variable are handled correctly""" @@ -363,9 +354,7 @@ def test_escaped_quotes(self): stdout = list(output[0].stdout) expected = ['--flags="this"'] self.assertEqual(output[0].exit_code, 0) - self.assertEqual(expected, stdout, - msg="Got unexpected output. Expected %s, got %s" % ( - expected, stdout,)) + self.assertEqual(expected, stdout) def test_read_timeout(self): client = ParallelSSHClient([self.host], port=self.port, @@ -392,7 +381,7 @@ def test_timeout_file_read(self): self.assertRaises(Timeout, self.client.join, output, timeout=.1) for host_out in output: try: - for line in host_out.stdout: + for _ in host_out.stdout: pass except Timeout: pass diff --git a/tests/ssh/test_single_client.py b/tests/ssh/test_single_client.py index b109ba1f..9dda6006 100644 --- a/tests/ssh/test_single_client.py +++ b/tests/ssh/test_single_client.py @@ -51,6 +51,7 @@ def test_execute(self): stderr = list(host_out.stderr) expected = [self.resp] self.assertEqual(expected, output) + self.assertEqual(len(stderr), 0) exit_code = host_out.channel.get_exit_status() self.assertEqual(exit_code, 0) @@ -186,11 +187,11 @@ def test_identity_auth_failure(self): def test_password_auth_failure(self): try: - client = SSHClient(self.host, port=self.port, num_retries=1, - allow_agent=False, - identity_auth=False, - password='blah blah blah', - ) + SSHClient(self.host, port=self.port, num_retries=1, + allow_agent=False, + identity_auth=False, + password='blah blah blah', + ) except AuthenticationException as ex: self.assertIsInstance(ex.args[3], AuthenticationDenied) else: @@ -229,7 +230,7 @@ def test_open_session_timeout(self): num_retries=2, timeout=.1) - def _session(timeout=None): + def _session(_=None): sleep(.2) client.open_session = _session self.assertRaises(GTimeout, client.run_command, self.cmd) @@ -244,6 +245,7 @@ def test_client_read_timeout(self): def test_open_session_exc(self): class Error(Exception): pass + def _session(): raise Error client = SSHClient(self.host, port=self.port, @@ -255,6 +257,7 @@ def _session(): def test_session_connect_exc(self): class Error(Exception): pass + def _con(): raise Error client = SSHClient(self.host, port=self.port, @@ -282,8 +285,10 @@ def test_no_auth(self): def test_agent_auth_failure(self): class UnknownError(Exception): pass + def _agent_auth_unk(): raise UnknownError + def _agent_auth_agent_err(): raise AuthenticationDenied client = SSHClient(self.host, port=self.port, @@ -316,6 +321,7 @@ def _agent_auth(): def test_disconnect_exc(self): class DiscError(Exception): pass + def _disc(): raise DiscError client = SSHClient(self.host, port=self.port, diff --git a/tests/test_reader.py b/tests/test_reader.py index b353f80c..8e4c643a 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -16,12 +16,12 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA import unittest - -from random import random, randint, randrange +from random import randrange from string import ascii_letters -from gevent.queue import Queue from gevent import spawn, sleep +from gevent.queue import Queue + from pssh.clients.reader import ConcurrentRWBuffer @@ -55,24 +55,25 @@ def test_multi_write_read(self): def test_concurrent_rw(self): written_data = Queue() + def _writer(_buffer): while True: data = b"".join([ascii_letters[m].encode() for m in [randrange(0, 8) for _ in range(8)]]) _buffer.write(data) written_data.put(data) - sleep(0.2) + sleep(0.05) writer = spawn(_writer, self.buffer) writer.start() - sleep(0.5) + sleep(0.1) data = self.buffer.read() _data = b"" - while written_data.qsize() !=0 : + while written_data.qsize() != 0: _data += written_data.get() self.assertEqual(data, _data) - sleep(0.5) + sleep(0.1) data = self.buffer.read() _data = b"" - while written_data.qsize() !=0 : + while written_data.qsize() != 0: _data += written_data.get() self.assertEqual(data, _data) writer.kill()