Skip to content

Commit

Permalink
openlineage, sftp: add OpenLineage support for sftp provider (#31360)
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski authored Jul 25, 2023
1 parent ca20251 commit 6b88084
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 1 deletion.
82 changes: 82 additions & 0 deletions airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
from __future__ import annotations

import os
import socket
import warnings
from pathlib import Path
from typing import Any, Sequence

import paramiko

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.sftp.hooks.sftp import SFTPHook
Expand Down Expand Up @@ -188,3 +191,82 @@ def execute(self, context: Any) -> str | list[str] | None:
raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}")

return self.local_filepath

def get_openlineage_facets_on_start(self):
"""
This returns OpenLineage datasets in format:
input: file://<local_host>/path
output: file://<remote_host>:<remote_port>/path.
"""
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

scheme = "file"
local_host = socket.gethostname()
try:
local_host = socket.gethostbyname(local_host)
except Exception as e:
self.log.warning(
f"Failed to resolve local hostname. Using the hostname got by socket.gethostbyname() without resolution. {e}", # noqa: E501
exc_info=True,
)

hook = self.sftp_hook or self.ssh_hook or SFTPHook(ssh_conn_id=self.ssh_conn_id)

if self.remote_host is not None:
remote_host = self.remote_host
else:
remote_host = hook.get_connection(hook.ssh_conn_id).host

try:
remote_host = socket.gethostbyname(remote_host)
except OSError as e:
self.log.warning(
f"Failed to resolve remote hostname. Using the provided hostname without resolution. {e}", # noqa: E501
exc_info=True,
)

if hasattr(hook, "port"):
remote_port = hook.port
elif hasattr(hook, "ssh_hook"):
remote_port = hook.ssh_hook.port

# Since v4.1.0, SFTPOperator accepts both a string (single file) and a list of
# strings (multiple files) as local_filepath and remote_filepath, and internally
# keeps them as list in both cases. But before 4.1.0, only single string is
# allowed. So we consider both cases here for backward compatibility.
if isinstance(self.local_filepath, str):
local_filepath = [self.local_filepath]
else:
local_filepath = self.local_filepath
if isinstance(self.remote_filepath, str):
remote_filepath = [self.remote_filepath]
else:
remote_filepath = self.remote_filepath

local_datasets = [
Dataset(namespace=self._get_namespace(scheme, local_host, None, path), name=path)
for path in local_filepath
]
remote_datasets = [
Dataset(namespace=self._get_namespace(scheme, remote_host, remote_port, path), name=path)
for path in remote_filepath
]

if self.operation.lower() == SFTPOperation.GET:
inputs = remote_datasets
outputs = local_datasets
else:
inputs = local_datasets
outputs = remote_datasets

return OperatorLineage(
inputs=inputs,
outputs=outputs,
)

def _get_namespace(self, scheme, host, port, path) -> str:
port = port or paramiko.config.SSH_PORT
authority = f"{host}:{port}"
return f"{scheme}://{authority}"
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@
"apache-airflow>=2.4.0"
],
"cross-providers-deps": [
"openlineage",
"ssh"
],
"excluded-python-versions": []
Expand Down
110 changes: 109 additions & 1 deletion tests/providers/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from __future__ import annotations

import os
import socket
from base64 import b64encode
from unittest import mock

import paramiko
import pytest
from openlineage.client.run import Dataset

from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.models import DAG, Connection
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
Expand All @@ -36,6 +39,18 @@
DEFAULT_DATE = datetime(2017, 1, 1)
TEST_CONN_ID = "conn_id_for_testing"

LOCAL_FILEPATH = "/path/local"
REMOTE_FILEPATH = "/path/remote"
LOCAL_DATASET = [
Dataset(namespace=f"file://{socket.gethostbyname(socket.gethostname())}:22", name=LOCAL_FILEPATH)
]
REMOTE_DATASET = [Dataset(namespace="file://remotehost:22", name=REMOTE_FILEPATH)]

TEST_GET_PUT_PARAMS = [
(SFTPOperation.GET, (REMOTE_DATASET, LOCAL_DATASET)),
(SFTPOperation.PUT, (LOCAL_DATASET, REMOTE_DATASET)),
]


class TestSFTPOperator:
def setup_method(self):
Expand Down Expand Up @@ -478,3 +493,96 @@ def test_return_str_when_local_filepath_was_str(self, mock_get):
return_value = sftp_op.execute(None)
assert isinstance(return_value, str)
assert return_value == local_filepath

@pytest.mark.parametrize(
"operation, expected",
TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_ssh_conn_id(self, get_connection, get_conn, operation, expected):
get_connection.return_value = Connection(
conn_id="sftp_conn_id",
conn_type="sftp",
host="remotehost",
port=22,
)

dag_id = "sftp_dag"
task_id = "sftp_task"

task = SFTPOperator(
task_id=task_id,
ssh_conn_id="sftp_conn_id",
dag=DAG(dag_id),
start_date=timezone.utcnow(),
local_filepath="/path/local",
remote_filepath="/path/remote",
operation=operation,
)
lineage = task.get_openlineage_facets_on_start()

assert lineage.inputs == expected[0]
assert lineage.outputs == expected[1]

@pytest.mark.parametrize(
"operation, expected",
TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):
get_connection.return_value = Connection(
conn_id="sftp_conn_id",
conn_type="sftp",
host="remotehost",
port=22,
)

dag_id = "sftp_dag"
task_id = "sftp_task"

task = SFTPOperator(
task_id=task_id,
sftp_hook=SFTPHook(ssh_conn_id="sftp_conn_id"),
dag=DAG(dag_id),
start_date=timezone.utcnow(),
local_filepath="/path/local",
remote_filepath="/path/remote",
operation=operation,
)
lineage = task.get_openlineage_facets_on_start()

assert lineage.inputs == expected[0]
assert lineage.outputs == expected[1]

@pytest.mark.parametrize(
"operation, expected",
TEST_GET_PUT_PARAMS,
)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn", spec=paramiko.SSHClient)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection", spec=Connection)
def test_extract_ssh_hook(self, get_connection, get_conn, operation, expected):
get_connection.return_value = Connection(
conn_id="sftp_conn_id",
conn_type="sftp",
host="remotehost",
port=22,
)

dag_id = "sftp_dag"
task_id = "sftp_task"

task = SFTPOperator(
task_id=task_id,
ssh_hook=SSHHook(ssh_conn_id="sftp_conn_id"),
dag=DAG(dag_id),
start_date=timezone.utcnow(),
local_filepath="/path/local",
remote_filepath="/path/remote",
operation=operation,
)
lineage = task.get_openlineage_facets_on_start()

assert lineage.inputs == expected[0]
assert lineage.outputs == expected[1]

0 comments on commit 6b88084

Please sign in to comment.