
Improved SSH retry mechanism #501

Open · wants to merge 16 commits into base: master
2 changes: 1 addition & 1 deletion streamflow/deployment/connector/container.py
@@ -1303,7 +1303,7 @@ async def get_available_locations(
(json_end := output.rfind("}")) != -1
):
if json_start != 0 and logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Docker Compose log: {output[:json_start]}")
logger.debug(f"Docker Compose log: {output[:json_start].strip()}")
locations = json.loads(output[json_start : json_end + 1])
else:
raise WorkflowExecutionException(
157 changes: 86 additions & 71 deletions streamflow/deployment/connector/ssh.py
@@ -36,44 +36,28 @@ def __init__(
streamflow_config_dir: str,
config: SSHConfig,
max_concurrent_sessions: int,
retries: int,
retry_delay: int,
):
self._streamflow_config_dir: str = streamflow_config_dir
self._config: SSHConfig = config
self._max_concurrent_sessions: int = max_concurrent_sessions
self._ssh_connection: asyncssh.SSHClientConnection | None = None
self._connecting = False
self._retries = retries
self._retry_delay = retry_delay
self._connect_event: asyncio.Event = asyncio.Event()
self.ssh_attempts: int = 0

async def get_connection(self) -> asyncssh.SSHClientConnection:
if self._ssh_connection is None:
if not self._connecting:
self._connecting = True
for i in range(1, self._retries + 1):
try:
self._ssh_connection = await self._get_connection(self._config)
break
except (ConnectionError, ConnectionLost) as e:
if i == self._retries:
logger.exception(
f"Impossible to connect to {self._config.hostname}: {e}"
)
self._connect_event.set()
self.close()
raise
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}. "
f"Waiting {self._retry_delay} seconds for the next attempt."
)
except asyncssh.Error:
self._connect_event.set()
self.close()
raise
await asyncio.sleep(self._retry_delay)
try:
self._ssh_connection = await self._get_connection(self._config)
except (ConnectionError, ConnectionLost, asyncssh.Error) as e:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {self._config.hostname} failed: {e}."
)
await self.close()
raise
self._connect_event.set()
else:
await self._connect_event.wait()
@@ -126,19 +110,20 @@ def _get_param_from_file(self, file_path: str):
with open(file_path) as f:
return f.read().strip()

def close(self):
self._connecting = False
async def close(self):
if self._ssh_connection is not None:
self._ssh_connection.close()
await self._ssh_connection.wait_closed()
self._ssh_connection = None
if self._connect_event.is_set():
self._connect_event.clear()
self._connect_event.set()  # wake any tasks blocked on wait() and avoid deadlocks
self._connect_event.clear()
self._connecting = False
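
The set-then-clear sequence above works because asyncio.Event.set() wakes every coroutine already blocked in wait() before the subsequent clear() re-arms the event. A minimal, self-contained illustration of that guarantee (not part of the PR):

    import asyncio

    async def main():
        event = asyncio.Event()
        waiter = asyncio.create_task(event.wait())
        await asyncio.sleep(0)  # let the waiter block on the event
        event.set()    # wakes the already-blocked waiter...
        event.clear()  # ...and immediately re-arms the event
        await waiter   # completes even though the event is clear again
        assert not event.is_set()

    asyncio.run(main())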

def full(self) -> bool:
if self._ssh_connection:
return len(self._ssh_connection._channels) >= self._max_concurrent_sessions
else:
return False
return (
self._ssh_connection is not None
and len(self._ssh_connection._channels) >= self._max_concurrent_sessions
)


class SSHContextManager:
@@ -148,6 +133,8 @@ def __init__(
contexts: MutableSequence[SSHContext],
command: str,
environment: MutableMapping[str, str] | None,
retries: int,
retry_delay: int,
stdin: int = asyncio.subprocess.PIPE,
stdout: int = asyncio.subprocess.PIPE,
stderr: int = asyncio.subprocess.PIPE,
Expand All @@ -161,16 +148,30 @@ def __init__(
self.encoding: str | None = encoding
self._condition: asyncio.Condition = condition
self._contexts: MutableSequence[SSHContext] = contexts
self._retries: int = retries
self._retry_delay: int = retry_delay
self._selected_context: SSHContext | None = None
self._proc: asyncssh.SSHClientProcess | None = None

async def __aenter__(self) -> asyncssh.SSHClientProcess:
async with self._condition:
while True:
for context in self._contexts:
if not context.full():
ssh_connection = await context.get_connection()
if all(c.ssh_attempts > self._retries for c in self._contexts):
raise WorkflowExecutionException(
"No more contexts available: terminating."
)
if (
len(free_contexts := [c for c in self._contexts if not c.full()])
== 0
):
await self._condition.wait()
else:
for context in free_contexts:
if context.ssh_attempts > self._retries:
# this context has exhausted its retries
continue
try:
ssh_connection = await context.get_connection()
self._selected_context = context
self._proc = await ssh_connection.create_process(
self.command,
Expand All @@ -181,17 +182,30 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
encoding=self.encoding,
)
await self._proc.__aenter__()
context.ssh_attempts = 0 # reset attempts
return self._proc
except ChannelOpenError as coe:
logger.warning(
f"Error opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: [{coe.code}] {coe.reason}"
)
await self._condition.wait()
except (
ConnectionError,
ConnectionLost,
asyncssh.Error,
ChannelOpenError,
) as coe:
if isinstance(coe, ChannelOpenError):
logger.warning(
f"Error opening SSH session to {context.get_hostname()} "
f"to execute command `{self.command}`: [{coe.code}] {coe.reason}"
)
context.ssh_attempts += 1
await context.close()
self._selected_context = None
if self._proc is not None:
await self._proc.__aexit__(None, None, None)
self._proc = None
logger.warning(f"Connection attempt {context.ssh_attempts}")
await asyncio.sleep(self._retry_delay)
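
Taken together, the loop above gives each context its own attempt budget: a failed attempt closes only that context, increments its counter, and sleeps for retry_delay, while acquisition fails outright only once every context has exceeded retries. A simplified model of this policy under hypothetical names (connect stands in for the get_connection/create_process pair, and the session-count backpressure via the condition is omitted):

    import asyncio

    async def acquire(contexts, retries: int, retry_delay: int, connect):
        # Illustrative sketch of SSHContextManager.__aenter__, not the real API.
        while True:
            # Fail permanently once every context has used up its attempts.
            if all(c.ssh_attempts > retries for c in contexts):
                raise RuntimeError("No more contexts available: terminating.")
            for context in (c for c in contexts if c.ssh_attempts <= retries):
                try:
                    proc = await connect(context)
                    context.ssh_attempts = 0  # a success resets the counter
                    return proc
                except OSError:
                    context.ssh_attempts += 1
                    await asyncio.sleep(retry_delay)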

async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self._condition:
if self._selected_context:
await self._selected_context.close()
if self._proc:
await self._proc.__aexit__(exc_type, exc_val, exc_tb)
self._condition.notify_all()
@@ -213,15 +227,14 @@ def __init__(
streamflow_config_dir=streamflow_config_dir,
config=config,
max_concurrent_sessions=max_concurrent_sessions,
retries=retries,
retry_delay=retry_delay,
)
for _ in range(max_connections)
]
self._retries = retries
self._retry_delay = retry_delay

def close(self):
for c in self._contexts:
c.close()
async def close(self):
await asyncio.gather(*(asyncio.create_task(c.close()) for c in self._contexts))

def get(
self,
Expand All @@ -241,6 +254,8 @@ def get(
stdout=stdout,
stderr=stderr,
encoding=encoding,
retries=self._retries,
retry_delay=self._retry_delay,
)
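
Downstream callers go through the factory, which now threads the retry settings into every SSHContextManager it hands out. A hypothetical usage sketch (the get() signature and close() semantics are inferred from this diff and may not match the actual API):

    async def run_hostname(factory: SSHContextFactory) -> None:
        # Connection retries now happen inside __aenter__, transparently
        # to the caller; a failure only surfaces once all contexts are spent.
        async with factory.get(command="hostname") as proc:
            result = await proc.wait()
            print(result.stdout)
        await factory.close()  # close() is awaitable after this change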


@@ -469,9 +484,9 @@ def _get_command(
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
f"EXECUTING command {command} on {location}" f" for job {job_name}"
if job_name
else ""
"EXECUTING command {} on {}{}".format(
command, location, f" for job {job_name}" if job_name else ""
)
)
return utils.encode_command(command)
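
Note what the old conditional did: it applied to the entire message, so when job_name was falsy the logger emitted an empty string instead of the command line. The str.format() version makes only the job suffix conditional.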

@@ -645,27 +660,27 @@ async def run(
workdir=workdir,
)
command = utils.encode_command(command)
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
else:
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
return result.stdout.strip(), result.returncode if capture_output else None
async with self._get_ssh_client_process(
location=location.name,
command=command,
stderr=asyncio.subprocess.STDOUT,
environment=environment,
) as proc:
result = await proc.wait(timeout=timeout)
return (result.stdout.strip(), result.returncode) if capture_output else None
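
The added parentheses are the substance of this hunk: in the old return line the conditional bound only to result.returncode, so the method returned a (stdout, None) tuple even when capture_output was false. A minimal illustration of the precedence difference:

    stdout, returncode, capture_output = "out", 0, False
    old = stdout.strip(), returncode if capture_output else None    # ("out", None): tuple is always built
    new = (stdout.strip(), returncode) if capture_output else None  # None when output is not captured
    assert old == ("out", None) and new is None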

async def undeploy(self, external: bool) -> None:
for ssh_context in self.ssh_context_factories.values():
ssh_context.close()
await asyncio.gather(
*(
asyncio.create_task(ssh_context.close())
for ssh_context in self.ssh_context_factories.values()
)
)
self.ssh_context_factories = {}
for ssh_context in self.data_transfer_context_factories.values():
ssh_context.close()
await asyncio.gather(
*(
asyncio.create_task(ssh_context.close())
for ssh_context in self.data_transfer_context_factories.values()
)
)
self.data_transfer_context_factories = {}
1 change: 1 addition & 0 deletions streamflow/deployment/stream.py
@@ -38,6 +38,7 @@ async def write(self, data: Any):

class StreamWriterWrapper(StreamWrapper):
async def close(self):
self.stream.write_eof()
self.stream.close()
await self.stream.wait_closed()
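
Presumably the explicit write_eof() is needed because asyncssh distinguishes sending end-of-file on the channel from closing it: without the EOF, a remote command that reads stdin until exhaustion may hang instead of terminating cleanly.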

2 changes: 1 addition & 1 deletion streamflow/workflow/step.py
@@ -100,7 +100,7 @@ async def _get_inputs(self, input_ports: MutableMapping[str, Port]):
if logger.isEnabledFor(logging.DEBUG):
if check_termination(inputs.values()):
logger.debug(
f"Step {self.name} received termination token with Status {_reduce_statuses([t.value for t in inputs.values()]).name.lower()}"
f"Step {self.name} received termination token with Status {_reduce_statuses([t.value for t in inputs.values()]).name}"
)
else:
logger.debug(
2 changes: 1 addition & 1 deletion tests/test_connector.py
@@ -148,5 +148,5 @@ async def test_ssh_connector_multiple_request_fail(context: StreamFlowContext) -
):
assert isinstance(result, (ConnectionError, asyncssh.Error)) or (
isinstance(result, WorkflowExecutionException)
and result.args[0] == "Impossible to connect to .*"
and result.args[0] == "No more contexts available: terminating."
)
2 changes: 2 additions & 0 deletions tests/utils/deployment.py
@@ -248,6 +248,8 @@ async def get_ssh_deployment_config(_context: StreamFlowContext):
"hostname": "127.0.0.1:2222",
"sshKey": f.name,
"username": "linuxserver.io",
"retries": 2,
"retryDelay": 5,
}
],
"maxConcurrentSessions": 10,