Skip to content

Commit

Permalink
fixing SSHHook bug when using allow_host_key_change param (#24116)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkruc authored Jun 2, 2022
1 parent 3120576 commit ddb2a4f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
9 changes: 6 additions & 3 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,16 @@ def get_conn(self) -> paramiko.SSHClient:
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
client = paramiko.SSHClient()

if not self.allow_host_key_change:
if self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. "
"This won't protect against Man-In-The-Middle attacks"
)
else:
client.load_system_host_keys()

if self.no_host_key_check:
self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
else:
if self.host_key is not None:
client_host_keys = client.get_host_keys()
Expand All @@ -289,6 +288,10 @@ def get_conn(self) -> paramiko.SSHClient:
else:
pass # will fallback to system host keys if none explicitly specified in conn extra

if self.no_host_key_check or self.allow_host_key_change:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

connect_kwargs: Dict[str, Any] = dict(
hostname=self.remote_host,
username=self.username,
Expand Down
48 changes: 48 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false'
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true'
CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false'
CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_no_host_key_and_no_host_key_check_true'
CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE = (
'ssh_with_host_key_and_allow_host_key_changes_true'
)

@classmethod
def tearDownClass(cls) -> None:
Expand All @@ -110,6 +114,7 @@ def tearDownClass(cls) -> None:
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
]
connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
Expand Down Expand Up @@ -236,6 +241,28 @@ def setUpClass(cls) -> None:
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
host='remote_host',
conn_type='ssh',
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": True}),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE,
host='remote_host',
conn_type='ssh',
extra=json.dumps(
{
"private_key": TEST_PRIVATE_KEY,
"host_key": TEST_HOST_KEY,
"allow_host_key_change": True,
}
),
)
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
Expand Down Expand Up @@ -522,6 +549,27 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self,
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.get_host_keys.return_value.add.called is False

def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self):
with pytest.raises(ValueError):
SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_true(self, ssh_client):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)
assert hook.host_key is None
with hook.get_conn():
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.set_missing_host_key_policy.called is True

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
assert hook.host_key is not None
with hook.get_conn():
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.load_system_host_keys.called is False
assert ssh_client.return_value.set_missing_host_key_policy.called is True

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_conn_timeout(self, ssh_mock):
hook = SSHHook(
Expand Down

0 comments on commit ddb2a4f

Please sign in to comment.