def get_openlineage_facets_on_start(self):
    """
    Return the OpenLineage datasets for the files this operator transfers.

    Every dataset uses the file path as its name and a namespace built by
    ``_get_namespace``:

    * local files:  ``file://<local_host>:<default SSH port>``
    * remote files: ``file://<remote_host>:<remote_port>``

    Host names are replaced by the address ``socket.gethostbyname`` returns;
    when resolution fails a warning is logged and the unresolved name is used.

    For a ``GET`` operation the remote files are the inputs and the local files
    the outputs; for any other operation the roles are swapped.

    :return: ``OperatorLineage`` holding the input and output datasets.
    """
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    scheme = "file"
    local_host = socket.gethostname()
    try:
        local_host = socket.gethostbyname(local_host)
    except Exception as e:
        # Best effort only: keep the value returned by socket.gethostname().
        self.log.warning(
            "Failed to resolve local hostname. "
            "Using the hostname got by socket.gethostname() without resolution. %s",
            e,
            exc_info=True,
        )

    hook = self.sftp_hook or self.ssh_hook or SFTPHook(ssh_conn_id=self.ssh_conn_id)

    if self.remote_host is not None:
        remote_host = self.remote_host
    else:
        remote_host = hook.get_connection(hook.ssh_conn_id).host

    try:
        remote_host = socket.gethostbyname(remote_host)
    except OSError as e:
        self.log.warning(
            "Failed to resolve remote hostname. Using the provided hostname without resolution. %s",
            e,
            exc_info=True,
        )

    # Start from None so a hook exposing neither ``port`` nor ``ssh_hook`` does not
    # leave the name unbound; ``_get_namespace`` substitutes the default SSH port.
    remote_port = None
    if hasattr(hook, "port"):
        remote_port = hook.port
    elif hasattr(hook, "ssh_hook"):
        remote_port = hook.ssh_hook.port

    # Since v4.1.0, SFTPOperator accepts both a string (single file) and a list of
    # strings (multiple files) as local_filepath and remote_filepath, and internally
    # keeps them as list in both cases. But before 4.1.0, only single string is
    # allowed. So we consider both cases here for backward compatibility.
    if isinstance(self.local_filepath, str):
        local_filepath = [self.local_filepath]
    else:
        local_filepath = self.local_filepath
    if isinstance(self.remote_filepath, str):
        remote_filepath = [self.remote_filepath]
    else:
        remote_filepath = self.remote_filepath

    local_datasets = [
        Dataset(namespace=self._get_namespace(scheme, local_host, None, path), name=path)
        for path in local_filepath
    ]
    remote_datasets = [
        Dataset(namespace=self._get_namespace(scheme, remote_host, remote_port, path), name=path)
        for path in remote_filepath
    ]

    if self.operation.lower() == SFTPOperation.GET:
        inputs = remote_datasets
        outputs = local_datasets
    else:
        inputs = local_datasets
        outputs = remote_datasets

    return OperatorLineage(
        inputs=inputs,
        outputs=outputs,
    )


def _get_namespace(self, scheme, host, port, path) -> str:
    """
    Build the dataset namespace ``<scheme>://<host>:<port>``.

    :param scheme: URI scheme placed before ``://``.
    :param host: host part of the authority.
    :param port: port part of the authority; a falsy value (``None`` or ``0``)
        is replaced by ``paramiko.config.SSH_PORT``.
    :param path: accepted for the callers' convenience but not used in the result.
    :return: the namespace string.
    """
    port = port or paramiko.config.SSH_PORT
    authority = f"{host}:{port}"
    return f"{scheme}://{authority}"
LOCAL_FILEPATH = "/path/local"
REMOTE_FILEPATH = "/path/remote"

# Fixed answers served by the patched ``socket`` functions, so the expected
# datasets do not depend on the DNS configuration of the machine running the tests.
LOCAL_HOSTNAME = "localhost"
LOCAL_ADDRESS = "127.0.0.1"
REMOTE_HOSTNAME = "remotehost"

LOCAL_DATASET = [Dataset(namespace=f"file://{LOCAL_ADDRESS}:22", name=LOCAL_FILEPATH)]
REMOTE_DATASET = [Dataset(namespace=f"file://{REMOTE_HOSTNAME}:22", name=REMOTE_FILEPATH)]

TEST_GET_PUT_PARAMS = [
    (SFTPOperation.GET, (REMOTE_DATASET, LOCAL_DATASET)),
    (SFTPOperation.PUT, (LOCAL_DATASET, REMOTE_DATASET)),
]


def _fake_gethostbyname(hostname):
    """Resolve only the fake local hostname; every other name fails like an unknown host."""
    if hostname == LOCAL_HOSTNAME:
        return LOCAL_ADDRESS
    raise socket.gaierror(f"Name or service not known: {hostname}")


def _sftp_connection():
    """Return the connection handed back by the patched ``SSHHook.get_connection``."""
    return Connection(
        conn_id="sftp_conn_id",
        conn_type="sftp",
        host=REMOTE_HOSTNAME,
        port=22,
    )


def _lineage_on_start(operation, **hook_kwargs):
    """
    Build an ``SFTPOperator`` and return its ``get_openlineage_facets_on_start()`` result.

    :param operation: value passed as the operator's ``operation``.
    :param hook_kwargs: how the operator reaches the server (``ssh_conn_id``,
        ``sftp_hook`` or ``ssh_hook``), forwarded to the constructor unchanged.
    """
    task = SFTPOperator(
        task_id="sftp_task",
        dag=DAG("sftp_dag"),
        start_date=timezone.utcnow(),
        local_filepath=LOCAL_FILEPATH,
        remote_filepath=REMOTE_FILEPATH,
        operation=operation,
        **hook_kwargs,
    )
    # Patch name resolution only around the call under test: the local name resolves,
    # the remote one does not, so both the resolved and the fallback paths are exercised.
    with mock.patch("socket.gethostname", return_value=LOCAL_HOSTNAME), mock.patch(
        "socket.gethostbyname", side_effect=_fake_gethostbyname
    ):
        return task.get_openlineage_facets_on_start()


@pytest.mark.parametrize(
    "operation, expected",
    TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_ssh_conn_id(self, get_connection, get_conn, operation, expected):
    """Lineage is extracted when the operator is configured with ``ssh_conn_id`` only."""
    get_connection.return_value = _sftp_connection()

    lineage = _lineage_on_start(operation, ssh_conn_id="sftp_conn_id")

    assert lineage.inputs == expected[0]
    assert lineage.outputs == expected[1]


@pytest.mark.parametrize(
    "operation, expected",
    TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):
    """Lineage is extracted when the operator is given an ``SFTPHook``."""
    get_connection.return_value = _sftp_connection()

    lineage = _lineage_on_start(operation, sftp_hook=SFTPHook(ssh_conn_id="sftp_conn_id"))

    assert lineage.inputs == expected[0]
    assert lineage.outputs == expected[1]


@pytest.mark.parametrize(
    "operation, expected",
    TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_ssh_hook(self, get_connection, get_conn, operation, expected):
    """Lineage is extracted when the operator is given an ``SSHHook``."""
    get_connection.return_value = _sftp_connection()

    lineage = _lineage_on_start(operation, ssh_hook=SSHHook(ssh_conn_id="sftp_conn_id"))

    assert lineage.inputs == expected[0]
    assert lineage.outputs == expected[1]