diff --git a/streamflow/deployment/connector/ssh.py b/streamflow/deployment/connector/ssh.py index c7933bc6..0e44d30f 100644 --- a/streamflow/deployment/connector/ssh.py +++ b/streamflow/deployment/connector/ssh.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import base64 import logging import os from abc import ABC @@ -19,7 +20,11 @@ from streamflow.deployment.connector.base import BaseConnector from streamflow.deployment.stream import StreamReaderWrapper, StreamWriterWrapper from streamflow.deployment.template import CommandTemplateMap -from streamflow.log_handler import logger +from streamflow.log_handler import logger, defaultStreamHandler + +asyncssh.logging.logger.setLevel(logging.DEBUG) +asyncssh.logging.logger.set_debug_level(2) +asyncssh.logging.logger.logger.addHandler(defaultStreamHandler) def _parse_hostname(hostname): @@ -45,10 +50,9 @@ def __init__( self._max_concurrent_sessions: int = max_concurrent_sessions self._ssh_connection: asyncssh.SSHClientConnection | None = None self._connecting = False - self._retry_delay = retry_delay - self._sleeping = False self._connect_event: asyncio.Event = asyncio.Event() - self.retries = retries + self.retries: int = retries + self.retry_delay: int = retry_delay self.ssh_attempts: int = 0 async def get_connection(self) -> asyncssh.SSHClientConnection: @@ -57,15 +61,12 @@ async def get_connection(self) -> asyncssh.SSHClientConnection: self._connecting = True try: self._ssh_connection = await self._get_connection(self._config) - except (ConnectionError, ConnectionLost) as e: + except (ConnectionError, ConnectionLost, asyncssh.Error) as e: if logger.isEnabledFor(logging.WARNING): logger.warning( f"Connection to {self._config.hostname} failed: {e}." ) self._connect_event.set() - raise - except asyncssh.Error: - self._connect_event.set() self.close() raise self._connect_event.set() @@ -129,21 +130,11 @@ def close(self): self._connect_event.clear() def full(self) -> bool: - # if it is sleeping then is full. - # if it has already a connection and there aren't free channels then is full - # otherwise is not full - return self._sleeping or ( + return ( self._ssh_connection and len(self._ssh_connection._channels) >= self._max_concurrent_sessions ) - async def sleep(self, cond): - self._sleeping = True - await asyncio.sleep(self._retry_delay) - self._sleeping = False - async with cond: - cond.notify_all() - class SSHContextManager: def __init__( @@ -172,10 +163,19 @@ def __init__( async def __aenter__(self) -> asyncssh.SSHClientProcess: async with self._condition: while True: - for context in self._contexts: - if not context.full(): - ssh_connection = await context.get_connection() + if ( + len(free_contexts := [c for c in self._contexts if not c.full()]) + == 0 + ): + await self._condition.wait() + else: + for context in free_contexts: + if context.ssh_attempts >= context.retries: + raise WorkflowExecutionException( + "No more connection attempts available" + ) try: + ssh_connection = await context.get_connection() self._selected_context = context self._proc = await ssh_connection.create_process( self.command, @@ -187,12 +187,26 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess: ) await self._proc.__aenter__() return self._proc - except ChannelOpenError as coe: - logger.warning( - f"Error opening SSH session to {context.get_hostname()} " - f"to execute command `{self.command}`: [{coe.code}] {coe.reason}" - ) - await self._condition.wait() + except ( + ConnectionError, + ConnectionLost, + ChannelOpenError, + ) as coe: + if isinstance(coe, ChannelOpenError): + logger.warning( + f"Error opening SSH session to {context.get_hostname()} " + f"to execute command `{self.command}`: [{coe.code}] {coe.reason}" + ) + self._selected_context = None + if self._proc: + await self._proc.__aexit__(None, None, None) + self._proc = None + context.close() + context.ssh_attempts += 1 + logger.warning(f"Attempt {context.ssh_attempts}") + except Exception as e: + logger.warning(f"Fatal error: {e}") + raise async def __aexit__(self, exc_type, exc_val, exc_tb): async with self._condition: @@ -730,22 +744,14 @@ async def run( workdir=workdir, ) command = utils.encode_command(command) - async with self._get_ssh_client_process( - location=location.name, - command=command, - stderr=asyncio.subprocess.STDOUT, - environment=environment, - ) as proc: - result = await proc.wait(timeout=timeout) - else: - async with self._get_ssh_client_process( - location=location.name, - command=command, - stderr=asyncio.subprocess.STDOUT, - environment=environment, - ) as proc: - result = await proc.wait(timeout=timeout) - if result.returncode is None or result.returncode != 0: + async with self._get_ssh_client_process( + location=location.name, + command=command, + stderr=asyncio.subprocess.STDOUT, + environment=environment, + ) as proc: + result = await proc.wait(timeout=timeout) + if result.returncode is None or result.returncode not in (0, 1): logger.info( f"CMD: {command}" f"\n\tcommand: {result.command}" @@ -756,8 +762,10 @@ async def run( f"\n\tstdout: {result.stdout}" f"\n\treturncode: {result.returncode}" ) - if result.returncode is None: - result.returncode = 9999 + if result.returncode is None: + raise WorkflowExecutionException( + f"Connection failed executing command {base64.b64decode(command).decode('utf-8')}" + ) return (result.stdout.strip(), result.returncode) if capture_output else None async def undeploy(self, external: bool) -> None: diff --git a/tests/utils/deployment.py b/tests/utils/deployment.py index 1c9fc2b2..7cccd26d 100644 --- a/tests/utils/deployment.py +++ b/tests/utils/deployment.py @@ -239,8 +239,8 @@ async def get_ssh_deployment_config(_context: StreamFlowContext): "hostname": "127.0.0.1:2222", "sshKey": f.name, "username": "linuxserver.io", - "retries": 0, - "retryDelay": 1, + "retries": 2, + "retryDelay": 5, } ], "maxConcurrentSessions": 10,