Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-ssh/connections/ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Extra (optional)
* ``host_key`` - The base64 encoded ssh-rsa public key of the host or "ssh-<key type> <key data>" (as you would find in the ``known_hosts`` file). Specifying this allows making the connection if and only if the public key of the endpoint matches this value.
* ``disabled_algorithms`` - A dictionary mapping algorithm type to an iterable of algorithm identifiers, which will be disabled for the lifetime of the transport.
* ``ciphers`` - A list of ciphers to use in order of preference.
* ``host_proxy_cmd`` - A proxy command to be executed.

Example "extras" field:

Expand Down
19 changes: 17 additions & 2 deletions providers/src/airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def __init__(
if private_key:
self.pkey = self._pkey_from_private_key(private_key, passphrase=private_key_passphrase)

if "host_proxy_cmd" in extra_options:
self.host_proxy_cmd = extra_options.get("host_proxy_cmd")

if "timeout" in extra_options:
warnings.warn(
"Extra option `timeout` is deprecated."
"Please use `conn_timeout` instead."
"The old option `timeout` will be removed in a future version.",
category=AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.timeout = int(extra_options["timeout"])

if "conn_timeout" in extra_options and self.conn_timeout is None:
self.conn_timeout = int(extra_options["conn_timeout"])

Expand Down Expand Up @@ -247,8 +260,10 @@ def __init__(
with open(user_ssh_config_filename) as config_fd:
ssh_conf.parse(config_fd)
host_info = ssh_conf.lookup(self.remote_host)
if host_info and host_info.get("proxycommand") and not self.host_proxy_cmd:
self.host_proxy_cmd = host_info["proxycommand"]
# If the proxy command is already set via the extra options, it will not be overwritten"""
if not self.host_proxy_cmd:
if host_info and host_info.get("proxycommand"):
self.host_proxy_cmd = host_info["proxycommand"]

if not (self.password or self.key_file):
if host_info and host_info.get("identityfile"):
Expand Down
18 changes: 18 additions & 0 deletions providers/tests/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,24 @@ def test_ssh_with_extra_ciphers(self, ssh_mock):
transport = ssh_mock.return_value.get_transport.return_value
assert transport.get_security_options.return_value.ciphers == TEST_CIPHERS

def test_host_proxy_cmd_in_extra(self):
TEST_HOST_PROXY_CMD = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p"
session = settings.Session()
try:
conn = Connection(
conn_id="ssh_with_proxy_cmd",
host="localhost",
conn_type="ssh",
extra={"host_proxy_cmd": TEST_HOST_PROXY_CMD},
)
session.add(conn)
session.flush()
hook = SSHHook(ssh_conn_id=conn.conn_id)
assert hook.host_proxy_cmd == TEST_HOST_PROXY_CMD
finally:
session.delete(conn)
session.commit()

def test_openssh_private_key(self):
# Paramiko behaves differently with OpenSSH generated keys to paramiko
# generated keys, so we need a test one.
Expand Down
Loading