Skip to content
9 changes: 8 additions & 1 deletion providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,20 @@ def create_directory(self, path: str, mode: int = 0o777) -> None:
self.log.info("Creating %s", path)
conn.mkdir(path, mode=mode)

def delete_directory(self, path: str, include_files: bool = False) -> None:
    """
    Delete a directory on the remote system.

    :param path: full path to the remote directory to delete
    :param include_files: when true, every file and sub-directory that
        ``get_tree_map`` reports for ``path`` is removed before ``path``
        itself; when false (the default) only ``rmdir`` on ``path`` is issued
    """
    with self.get_conn() as conn:
        if include_files:
            files, dirs, _ = self.get_tree_map(path)
            for file_path in files:
                conn.remove(file_path)
            # Walk the directory list back to front so the deepest
            # directories are removed before their parents.
            for dir_path in reversed(dirs):
                conn.rmdir(dir_path)
        conn.rmdir(path)

def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
Expand Down
77 changes: 48 additions & 29 deletions providers/sftp/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class SFTPOperation:

PUT = "put"
GET = "get"
DELETE = "delete"


class SFTPOperator(BaseOperator):
Expand All @@ -53,8 +54,8 @@ class SFTPOperator(BaseOperator):
Nullable. If provided, it will replace the `remote_host` which was
defined in `sftp_hook` or predefined in the connection of `ssh_conn_id`.
:param local_filepath: local file path or list of local file paths to get or put. (templated)
:param remote_filepath: remote file path or list of remote file paths to get or put. (templated)
:param operation: specify operation 'get' or 'put', defaults to put
:param remote_filepath: remote file path or list of remote file paths to get, put, or delete. (templated)
:param operation: specify operation 'get', 'put', or 'delete', defaults to put
:param confirm: specify if the SFTP operation should be confirmed, defaults to True
:param create_intermediate_dirs: create missing intermediate directories when
copying from remote to local and vice-versa. Default is False.
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(
sftp_hook: SFTPHook | None = None,
ssh_conn_id: str | None = None,
remote_host: str | None = None,
local_filepath: str | list[str],
local_filepath: str | list[str] | None = None,
remote_filepath: str | list[str],
operation: str = SFTPOperation.PUT,
confirm: bool = True,
Expand All @@ -102,7 +103,9 @@ def __init__(
self.remote_filepath = remote_filepath

def execute(self, context: Any) -> str | list[str] | None:
if isinstance(self.local_filepath, str):
if self.local_filepath is None:
local_filepath_array = []
elif isinstance(self.local_filepath, str):
local_filepath_array = [self.local_filepath]
else:
local_filepath_array = self.local_filepath
Expand All @@ -112,16 +115,21 @@ def execute(self, context: Any) -> str | list[str] | None:
else:
remote_filepath_array = self.remote_filepath

if len(local_filepath_array) != len(remote_filepath_array):
if self.operation.lower() in (SFTPOperation.GET, SFTPOperation.PUT) and len(
local_filepath_array
) != len(remote_filepath_array):
raise ValueError(
f"{len(local_filepath_array)} paths in local_filepath "
f"!= {len(remote_filepath_array)} paths in remote_filepath"
)

if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
if self.operation.lower() == SFTPOperation.DELETE and local_filepath_array:
raise ValueError("local_filepath should not be provided for delete operation")

if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT, SFTPOperation.DELETE):
raise TypeError(
f"Unsupported operation value {self.operation}, "
f"expected {SFTPOperation.GET} or {SFTPOperation.PUT}."
f"expected {SFTPOperation.GET} or {SFTPOperation.PUT} or {SFTPOperation.DELETE}."
)

file_msg = None
Expand All @@ -144,32 +152,43 @@ def execute(self, context: Any) -> str | list[str] | None:
)
self.sftp_hook.remote_host = self.remote_host

for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array):
if self.operation.lower() == SFTPOperation.GET:
local_folder = os.path.dirname(_local_filepath)
if self.create_intermediate_dirs:
Path(local_folder).mkdir(parents=True, exist_ok=True)
file_msg = f"from {_remote_filepath} to {_local_filepath}"
self.log.info("Starting to transfer %s", file_msg)
if self.operation.lower() in (SFTPOperation.GET, SFTPOperation.PUT):
for _local_filepath, _remote_filepath in zip(local_filepath_array, remote_filepath_array):
if self.operation.lower() == SFTPOperation.GET:
local_folder = os.path.dirname(_local_filepath)
if self.create_intermediate_dirs:
Path(local_folder).mkdir(parents=True, exist_ok=True)
file_msg = f"from {_remote_filepath} to {_local_filepath}"
self.log.info("Starting to transfer %s", file_msg)
if self.sftp_hook.isdir(_remote_filepath):
self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath)
else:
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
elif self.operation.lower() == SFTPOperation.PUT:
remote_folder = os.path.dirname(_remote_filepath)
if self.create_intermediate_dirs:
self.sftp_hook.create_directory(remote_folder)
file_msg = f"from {_local_filepath} to {_remote_filepath}"
self.log.info("Starting to transfer file %s", file_msg)
if os.path.isdir(_local_filepath):
self.sftp_hook.store_directory(
_remote_filepath, _local_filepath, confirm=self.confirm
)
else:
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)
elif self.operation.lower() == SFTPOperation.DELETE:
for _remote_filepath in remote_filepath_array:
file_msg = f"{_remote_filepath}"
self.log.info("Starting to delete %s", file_msg)
if self.sftp_hook.isdir(_remote_filepath):
self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath)
self.sftp_hook.delete_directory(_remote_filepath, include_files=True)
else:
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
else:
remote_folder = os.path.dirname(_remote_filepath)
if self.create_intermediate_dirs:
self.sftp_hook.create_directory(remote_folder)
file_msg = f"from {_local_filepath} to {_remote_filepath}"
self.log.info("Starting to transfer file %s", file_msg)
if os.path.isdir(_local_filepath):
self.sftp_hook.store_directory(
_remote_filepath, _local_filepath, confirm=self.confirm
)
else:
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)
self.sftp_hook.delete_file(_remote_filepath)

except Exception as e:
raise AirflowException(f"Error while transferring {file_msg}, error: {e}")
raise AirflowException(
f"Error while processing {self.operation.upper()} operation {file_msg}, error: {e}"
)

return self.local_filepath

Expand Down
22 changes: 22 additions & 0 deletions providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,28 @@ def test_create_and_delete_directories(self):
assert new_dir_path not in output
assert base_dir not in output

def test_create_and_delete_directory_with_files(self):
    """Fill a remote directory with a sub-directory and a file, then delete it with include_files=True."""
    dir_name, nested_name, file_name = "new_dir", "sub_dir", "additional_file.txt"

    def remote(*parts):
        # Built on every call so the path is computed at the same point as in an inline join.
        return os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, *parts)

    # The freshly created directory must appear in the parent listing.
    self.hook.create_directory(remote(dir_name))
    parent_listing = self.hook.describe_directory(remote())
    assert dir_name in parent_listing

    # Add one nested directory and one uploaded file inside it.
    self.hook.create_directory(remote(dir_name, nested_name))
    self._create_additional_test_file(file_name=file_name)
    self.hook.store_file(
        remote_full_path=remote(dir_name, file_name),
        local_full_path=os.path.join(self.temp_dir, file_name),
    )
    dir_listing = self.hook.describe_directory(remote(dir_name))
    assert nested_name in dir_listing
    assert file_name in dir_listing

    # Deleting with include_files=True must remove the directory from the parent listing.
    self.hook.delete_directory(remote(dir_name), include_files=True)
    parent_listing = self.hook.describe_directory(remote())
    assert dir_name not in parent_listing

def test_store_retrieve_and_delete_file(self):
self.hook.store_file(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
Expand Down
54 changes: 54 additions & 0 deletions providers/sftp/tests/provider_tests/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,60 @@ def test_return_str_when_local_filepath_was_str(self, mock_get):
assert isinstance(return_value, str)
assert return_value == local_filepath

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file")
def test_str_filepaths_delete(self, mock_delete):
    """Check that a single string remote path reaches the patched delete_file exactly once."""
    target = "/tmp/test"
    operator = SFTPOperator(
        task_id="test_str_filepaths_delete",
        sftp_hook=self.sftp_hook,
        remote_filepath=target,
        operation=SFTPOperation.DELETE,
    )
    operator.execute(None)

    assert mock_delete.call_count == 1
    # Index 0 of a recorded call is its tuple of positional arguments.
    positional = mock_delete.call_args_list[0][0]
    assert positional == (target,)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file")
def test_multiple_filepaths_delete(self, mock_delete):
    """Check that a list of remote paths yields one patched delete_file call per path, in order."""
    targets = ["/tmp/rtest1", "/tmp/rtest2"]
    operator = SFTPOperator(
        task_id="test_multiple_filepaths_delete",
        sftp_hook=self.sftp_hook,
        remote_filepath=targets,
        operation=SFTPOperation.DELETE,
    )
    operator.execute(None)

    assert mock_delete.call_count == 2
    first_call = mock_delete.call_args_list[0]
    second_call = mock_delete.call_args_list[1]
    # Index 0 of a recorded call is its tuple of positional arguments.
    assert first_call[0] == (targets[0],)
    assert second_call[0] == (targets[1],)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_directory")
def test_str_dirpaths_delete(self, mock_delete):
    """Check that deleting a remote directory path calls the patched delete_directory recursively."""
    remote_filepath = "/tmp"
    SFTPOperator(
        task_id="test_str_dirpaths_delete",
        sftp_hook=self.sftp_hook,
        remote_filepath=remote_filepath,
        operation=SFTPOperation.DELETE,
    ).execute(None)
    assert mock_delete.call_count == 1
    args, kwargs = mock_delete.call_args_list[0]
    assert args == (remote_filepath,)
    # The operator asks the hook to remove the directory contents as well;
    # checking the keyword guards against a silent regression to a bare rmdir.
    assert kwargs == {"include_files": True}

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.delete_file")
def test_local_filepath_exists_error_delete(self, mock_delete):
    """Check that combining local_filepath with the delete operation is rejected with a ValueError."""
    operator_kwargs = {
        "task_id": "test_local_filepath_exists_error_delete",
        "sftp_hook": self.sftp_hook,
        "local_filepath": "/tmp",
        "remote_filepath": "/tmp_remote",
        "operation": SFTPOperation.DELETE,
    }
    expected_message = "local_filepath should not be provided for delete operation"
    with pytest.raises(ValueError, match=expected_message):
        SFTPOperator(**operator_kwargs).execute(None)

@pytest.mark.parametrize(
"operation, expected",
TEST_GET_PUT_PARAMS,
Expand Down