Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
LanderOtto committed Aug 5, 2024
1 parent 0c49622 commit 7d97e06
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 48 deletions.
100 changes: 54 additions & 46 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import base64
import logging
import os
from abc import ABC
Expand All @@ -19,7 +20,11 @@
from streamflow.deployment.connector.base import BaseConnector
from streamflow.deployment.stream import StreamReaderWrapper, StreamWriterWrapper
from streamflow.deployment.template import CommandTemplateMap
from streamflow.log_handler import logger
from streamflow.log_handler import logger, defaultStreamHandler

asyncssh.logging.logger.setLevel(logging.DEBUG)
asyncssh.logging.logger.set_debug_level(2)
asyncssh.logging.logger.logger.addHandler(defaultStreamHandler)


def _parse_hostname(hostname):
Expand All @@ -45,10 +50,9 @@ def __init__(
self._max_concurrent_sessions: int = max_concurrent_sessions
self._ssh_connection: asyncssh.SSHClientConnection | None = None
self._connecting = False
self._retry_delay = retry_delay
self._sleeping = False
self._connect_event: asyncio.Event = asyncio.Event()
self.retries = retries
self.retries: int = retries
self.retry_delay: int = retry_delay
self.ssh_attempts: int = 0

async def get_connection(self) -> asyncssh.SSHClientConnection:
Expand All @@ -57,15 +61,12 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
self._connecting = True
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, ConnectionLost) as e:
except (ConnectionError, ConnectionLost, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
self._connect_event.set()
raise
except asyncssh.Error:
self._connect_event.set()
self.close()
raise
self._connect_event.set()
Expand Down Expand Up @@ -129,21 +130,11 @@ def close(self):
self._connect_event.clear()

def full(self) -> bool:
# if it is sleeping then is full.
# if it has already a connection and there aren't free channels then is full
# otherwise is not full
return self._sleeping or (
return (
self._ssh_connection
and len(self._ssh_connection._channels) >= self._max_concurrent_sessions
)

async def sleep(self, cond):
self._sleeping = True
await asyncio.sleep(self._retry_delay)
self._sleeping = False
async with cond:
cond.notify_all()


class SSHContextManager:
def __init__(
Expand Down Expand Up @@ -172,10 +163,19 @@ def __init__(
async def __aenter__(self) -> asyncssh.SSHClientProcess:
async with self._condition:
while True:
for context in self._contexts:
if not context.full():
ssh_connection = await context.get_connection()
if (
len(free_contexts := [c for c in self._contexts if not c.full()])
== 0
):
await self._condition.wait()
else:
for context in free_contexts:
if context.ssh_attempts >= context.retries:
raise WorkflowExecutionException(
"No more connection attempts available"
)
try:
ssh_connection = await context.get_connection()
self._selected_context = context
self._proc = await ssh_connection.create_process(
self.command,
Expand All @@ -187,12 +187,26 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
)
await self._proc.__aenter__()
return self._proc
except ChannelOpenError as coe:
logger.warning(
f"Error opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: [{coe.code}] {coe.reason}"
)
await self._condition.wait()
except (
ConnectionError,
ConnectionLost,
ChannelOpenError,
) as coe:
if isinstance(coe, ChannelOpenError):
logger.warning(
f"Error opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: [{coe.code}] {coe.reason}"
)
self._selected_context = None
if self._proc:
await self._proc.__aexit__(None, None, None)
self._proc = None
context.close()
context.ssh_attempts += 1
logger.warning(f"Attempt {context.ssh_attempts}")
except Exception as e:
logger.warning(f"Fatal error: {e}")
raise

async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self._condition:
Expand Down Expand Up @@ -730,22 +744,14 @@ async def run(
workdir=workdir,
)
command = utils.encode_command(command)
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
else:
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
if result.returncode is None or result.returncode != 0:
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
if result.returncode is None or result.returncode not in (0, 1):
logger.info(
f"CMD: {command}"
f"\n\tcommand: {result.command}"
Expand All @@ -756,8 +762,10 @@ async def run(
f"\n\tstdout: {result.stdout}"
f"\n\treturncode: {result.returncode}"
)
if result.returncode is None:
result.returncode = 9999
if result.returncode is None:
raise WorkflowExecutionException(
f"Connection failed executing command {base64.b64decode(command)}"
)
return (result.stdout.strip(), result.returncode) if capture_output else None

async def undeploy(self, external: bool) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ async def get_ssh_deployment_config(_context: StreamFlowContext):
"hostname": "127.0.0.1:2222",
"sshKey": f.name,
"username": "linuxserver.io",
"retries": 0,
"retryDelay": 1,
"retries": 2,
"retryDelay": 5,
}
],
"maxConcurrentSessions": 10,
Expand Down

0 comments on commit 7d97e06

Please sign in to comment.