Skip to content

Commit

Permalink
SSHHook: check if existing connection is still alive (apache#41061)
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus authored Aug 21, 2024
1 parent 79f6383 commit d404a14
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 81 deletions.
159 changes: 79 additions & 80 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,94 +286,93 @@ def host_proxy(self) -> paramiko.ProxyCommand | None:

def get_conn(self) -> paramiko.SSHClient:
    """Establish an SSH connection to the remote host.

    If a client was created by an earlier call and its transport is still
    active, that client is returned as-is. Otherwise (no client yet, no
    transport, or an inactive transport) a new client is configured,
    connected and stored on ``self.client``.

    :return: a connected ``paramiko.SSHClient``; the same object is kept in
        ``self.client`` for reuse by later calls.
    """
    if self.client:
        transport = self.client.get_transport()
        if transport and transport.is_active():
            # Return the existing connection
            return self.client

    self.log.debug("Creating SSH client for conn_id: %s", self.ssh_conn_id)
    client = paramiko.SSHClient()

    if self.allow_host_key_change:
        self.log.warning(
            "Remote Identification Change is not verified. "
            "This won't protect against Man-In-The-Middle attacks"
        )
        # to avoid BadHostKeyException, skip loading host keys
        client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
    else:
        client.load_system_host_keys()

    if self.no_host_key_check:
        self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # nosec B507
        # to avoid BadHostKeyException, skip loading and saving host keys
        known_hosts = os.path.expanduser("~/.ssh/known_hosts")
        if not self.allow_host_key_change and os.path.isfile(known_hosts):
            client.load_host_keys(known_hosts)

    elif self.host_key is not None:
        # Use the host key from the connection extra; if it is not set (None)
        # we fall back to the system host keys loaded above.
        client_host_keys = client.get_host_keys()
        if self.port == SSH_PORT:
            client_host_keys.add(self.remote_host, self.host_key.get_name(), self.host_key)
        else:
            # Non-default ports are registered under the "[host]:port" name.
            client_host_keys.add(
                f"[{self.remote_host}]:{self.port}", self.host_key.get_name(), self.host_key
            )

    connect_kwargs: dict[str, Any] = {
        "hostname": self.remote_host,
        "username": self.username,
        "timeout": self.conn_timeout,
        "compress": self.compress,
        "port": self.port,
        "sock": self.host_proxy,
        "look_for_keys": self.look_for_keys,
        "banner_timeout": self.banner_timeout,
    }

    # Optional credentials/settings are only passed when they are configured.
    if self.password:
        password = self.password.strip()
        connect_kwargs.update(password=password)

    if self.pkey:
        connect_kwargs.update(pkey=self.pkey)

    if self.key_file:
        connect_kwargs.update(key_filename=self.key_file)

    if self.disabled_algorithms:
        connect_kwargs.update(disabled_algorithms=self.disabled_algorithms)

    def log_before_sleep(retry_state):
        return self.log.info(
            "Failed to connect. Sleeping before retry attempt %d", retry_state.attempt_number
        )

    # Up to three connection attempts, pausing between them; with
    # ``reraise=True`` the failure of the final attempt is expected to
    # propagate to the caller as the original exception.
    for attempt in Retrying(
        reraise=True,
        wait=wait_fixed(3) + wait_random(0, 2),
        stop=stop_after_attempt(3),
        before_sleep=log_before_sleep,
    ):
        with attempt:
            client.connect(**connect_kwargs)

    if self.keepalive_interval:
        # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
        # type "Transport | None" and item "None" has no attribute "set_keepalive".
        client.get_transport().set_keepalive(self.keepalive_interval)  # type: ignore[union-attr]

    if self.ciphers:
        # MyPy check ignored because "paramiko" isn't well-typed. The `client.get_transport()` returns
        # type "Transport | None" and item "None" has no method `get_security_options`".
        client.get_transport().get_security_options().ciphers = self.ciphers  # type: ignore[union-attr]

    self.client = client
    return client

@deprecated(
reason=(
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/ssh/operators/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def get_hook(self) -> SSHHook:

def get_ssh_client(self) -> SSHClient:
    """Return the SSH client obtained from ``self.hook.get_conn()``.

    The hook may hand back an already-open client rather than create a new
    one, so no "creating" message is logged here.
    """
    # Remember to use context manager or call .close() on this when done
    return self.hook.get_conn()

@deprecated(
Expand Down
23 changes: 23 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,3 +1092,26 @@ def test_connection_failure(self):
status, msg = hook.test_connection()
assert status is False
assert msg == "Test failure case"

def test_ssh_connection_client_is_reused_if_open(self):
    """A second ``get_conn()`` call returns the same, still-active client."""
    ssh_hook = SSHHook(ssh_conn_id="ssh_default")
    first = ssh_hook.get_conn()
    second = ssh_hook.get_conn()
    assert first is second
    second_transport = second.get_transport()
    assert second_transport.is_active()

def test_ssh_connection_client_is_recreated_if_closed(self):
    """After the client is closed, ``get_conn()`` returns a new active client."""
    ssh_hook = SSHHook(ssh_conn_id="ssh_default")
    closed_client = ssh_hook.get_conn()
    closed_client.close()
    fresh_client = ssh_hook.get_conn()
    assert closed_client is not fresh_client
    fresh_transport = fresh_client.get_transport()
    assert fresh_transport.is_active()

def test_ssh_connection_client_is_recreated_if_transport_closed(self):
    """After only the transport is closed, ``get_conn()`` returns a new active client."""
    ssh_hook = SSHHook(ssh_conn_id="ssh_default")
    stale_client = ssh_hook.get_conn()
    stale_transport = stale_client.get_transport()
    stale_transport.close()
    fresh_client = ssh_hook.get_conn()
    assert stale_client is not fresh_client
    fresh_transport = fresh_client.get_transport()
    assert fresh_transport.is_active()

0 comments on commit d404a14

Please sign in to comment.