Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def get_managed_conn(self) -> Generator[SFTPClient, None, None]:
self._sftp_conn = None
self._ssh_conn.close()
self._ssh_conn = None
if hasattr(self, "host_proxy"):
del self.host_proxy

def get_conn_count(self) -> int:
"""Get the number of open connections."""
Expand Down
42 changes: 42 additions & 0 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,48 @@ def test_get_managed_conn(self):
assert self.hook.get_conn_count() == 0
assert self.hook.conn is None

@patch("paramiko.SSHClient")
@patch("paramiko.ProxyCommand")
@patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection")
def test_proxy_command_cache_invalidated_after_connection_closed(
self, mock_get_connection, mock_proxy_command, mock_ssh_client
):
"""
Assert that the ProxyCommand gets invalidated after the connection is closed
"""

mock_connection = MagicMock()
mock_connection.login = "user"
mock_connection.password = None
mock_connection.host = "example.com"
mock_connection.port = 22
mock_connection.extra = None
mock_get_connection.return_value = mock_connection

mock_sftp_client = MagicMock(spec=SFTPClient)
mock_ssh_client.open_sftp.return_value = mock_sftp_client

mock_transport = MagicMock()
mock_ssh_client.return_value.get_transport.return_value = mock_transport
mock_proxy_command.return_value = MagicMock()

host_proxy_cmd = "ncat --proxy-auth proxy_user:**** --proxy proxy_host:port %h %p"
prev_proxy_command = None

hook = SFTPHook(
remote_host="example.com",
username="user",
host_proxy_cmd=host_proxy_cmd,
)

with hook.get_managed_conn() as _:
assert hasattr(self.hook, "host_proxy")
prev_proxy_command = hook.host_proxy

mock_proxy_command.return_value = MagicMock()

assert prev_proxy_command != hook.host_proxy

@patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn")
def test_get_close_conn(self, mock_get_conn):
mock_sftp_client = MagicMock(spec=SFTPClient)
Expand Down