Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SFTPFileTransfer #136

Merged
merged 21 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Pass connect function in init
  • Loading branch information
jl-wynen committed Aug 25, 2023
commit 46a9b710e63d4c7b18f850316084c29e2aeccc64
53 changes: 18 additions & 35 deletions src/scitacean/transfer/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
host: str,
port: Optional[int] = None,
source_folder: Optional[Union[str, RemotePath]] = None,
connect: Optional[Callable[..., Connection]] = None,
) -> None:
"""Construct a new SFTP file transfer.

Expand All @@ -258,6 +259,17 @@ def __init__(
Upload files to this folder if set.
Otherwise, upload to the dataset's source_folder.
Ignored when downloading files.
connect:
A function that creates and returns a :class:`fabric.connection.Connection`
object.
Will first be called with only ``host`` and ``port``.
If this fails (by raising
:class:`paramiko.ssh_exception.AuthenticationException`), the function is
called with ``host``, ``port``, and, optionally, ``user`` and
``connection_kwargs`` depending on the authentication method.
Raising :class:`paramiko.ssh_exception.AuthenticationException` in the 2nd
call or any other exception in the 1st signals failure of
``connect_for_download``.
"""
self._host = host
self._port = port
Expand All @@ -266,62 +278,33 @@ def __init__(
if isinstance(source_folder, str)
else source_folder
)
self._connect = connect

def source_folder_for(self, dataset: Dataset) -> RemotePath:
"""Return the source folder used for the given dataset."""
return source_folder_for(dataset, self._source_folder_pattern)

@contextmanager
def connect_for_download(
self, connect: Optional[Callable[..., Connection]] = None
) -> Iterator[SFTPDownloadConnection]:
"""Create a connection for downloads, use as a context manager.

Parameters
----------
connect:
A function that creates and returns a :class:`fabric.connection.Connection`
object.
Will first be called with only ``host`` and ``port``.
If this fails (by raising
:class:`paramiko.ssh_exception.AuthenticationException`), the function is
called with ``host``, ``port``, and, optionally, ``user`` and
``connection_kwargs`` depending on the authentication method.
Raising :class:`paramiko.ssh_exception.AuthenticationException` in the 2nd
call or any other exception in the 1st signals failure of
``connect_for_download``.
"""
con = _connect(self._host, self._port, connect=connect)
def connect_for_download(self) -> Iterator[SFTPDownloadConnection]:
"""Create a connection for downloads, use as a context manager."""
con = _connect(self._host, self._port, connect=self._connect)
try:
yield SFTPDownloadConnection(connection=con)
finally:
con.close()

@contextmanager
def connect_for_upload(
self, dataset: Dataset, connect: Optional[Callable[..., Connection]] = None
) -> Iterator[SFTPUploadConnection]:
def connect_for_upload(self, dataset: Dataset) -> Iterator[SFTPUploadConnection]:
"""Create a connection for uploads, use as a context manager.

Parameters
----------
dataset:
The connection will be used to upload files of this dataset.
Used to determine the target folder.
connect:
A function that creates and returns a :class:`fabric.connection.Connection`
object.
Will first be called with only ``host`` and ``port``.
If this fails (by raising
:class:`paramiko.ssh_exception.AuthenticationException`), the function is
called with ``host``, ``port``, and, optionally, ``user`` and
``connection_kwargs`` depending on the authentication method.
Raising :class:`paramiko.ssh_exception.AuthenticationException` in the 2nd
call or any other exception in the 1st signals failure of
``connect_for_upload``.
"""
source_folder = self.source_folder_for(dataset)
con = _connect(self._host, self._port, connect=connect)
con = _connect(self._host, self._port, connect=self._connect)
try:
yield SFTPUploadConnection(
connection=con,
Expand Down
133 changes: 76 additions & 57 deletions tests/transfer/sftp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable, Iterator, Optional
from typing import Iterator

import fabric
import paramiko
import pytest
from fabric import Connection

from scitacean import Dataset, File, FileUploadError, RemotePath
from scitacean.testing.client import FakeClient
Expand All @@ -32,8 +31,12 @@ def server(request, sftp_fileserver):


def test_download_one_file(sftp_access, sftp_connect_with_username_password, tmp_path):
sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_download(connect=sftp_connect_with_username_password) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_download() as con:
con.download_files(
remote=[RemotePath("/data/seed/text.txt")], local=[tmp_path / "text.txt"]
)
Expand All @@ -43,8 +46,12 @@ def test_download_one_file(sftp_access, sftp_connect_with_username_password, tmp


def test_download_two_files(sftp_access, sftp_connect_with_username_password, tmp_path):
sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_download(connect=sftp_connect_with_username_password) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_download() as con:
con.download_files(
remote=[
RemotePath("/data/seed/table.csv"),
Expand All @@ -64,10 +71,12 @@ def test_upload_one_file_source_folder_in_dataset(
ds = Dataset(type="raw", source_folder=RemotePath("/data/upload"))
tmp_path.joinpath("file0.txt").write_text("File to test upload123")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
assert con.source_folder == RemotePath("/data/upload")
con.upload_files(
File.from_local(path=tmp_path / "file0.txt", remote_path="upload_0.txt")
Expand All @@ -89,10 +98,9 @@ def test_upload_one_file_source_folder_in_transfer(
host=sftp_access.host,
port=sftp_access.port,
source_folder="/data/upload/{owner}",
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
with sftp.connect_for_upload(dataset=ds) as con:
assert con.source_folder == RemotePath("/data/upload/librarian")
con.upload_files(
File.from_local(
Expand All @@ -113,10 +121,12 @@ def test_upload_two_files(
tmp_path.joinpath("file2.1.md").write_text("First part of file 2")
tmp_path.joinpath("file2.2.md").write_text("Second part of file 2")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
assert con.source_folder == RemotePath("/data/upload2")
con.upload_files(
File.from_local(path=tmp_path / "file2.1.md", base_path=tmp_path),
Expand All @@ -141,20 +151,24 @@ def test_upload_one_file_existing_source_folder(
tmp_path.joinpath("file3.2.md").write_text("Second part of file 3")

# First upload to ensure the folder exists.
sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
assert con.source_folder == RemotePath("/data/upload-multiple")
con.upload_files(
File.from_local(path=tmp_path / "file3.1.md", base_path=tmp_path),
)

# Second upload to test uploading to existing folder.
sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
assert con.source_folder == RemotePath("/data/upload-multiple")
con.upload_files(
File.from_local(path=tmp_path / "file3.2.md", base_path=tmp_path),
Expand All @@ -176,10 +190,12 @@ def test_revert_all_uploaded_files_single(
ds = Dataset(type="raw", source_folder=RemotePath("/data/revert-all-test-1"))
tmp_path.joinpath("file3.txt").write_text("File that should get reverted")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
file = File.from_local(path=tmp_path / "file3.txt", base_path=tmp_path)
con.upload_files(file)
con.revert_upload(file)
Expand All @@ -194,10 +210,12 @@ def test_revert_all_uploaded_files_two(
tmp_path.joinpath("file3.1.txt").write_text("File that should get reverted 1")
tmp_path.joinpath("file3.2.txt").write_text("File that should get reverted 2")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
file1 = File.from_local(path=tmp_path / "file3.1.txt", base_path=tmp_path)
file2 = File.from_local(path=tmp_path / "file3.2.txt", base_path=tmp_path)
con.upload_files(file1, file2)
Expand All @@ -213,10 +231,12 @@ def test_revert_one_uploaded_file(
tmp_path.joinpath("file4.txt").write_text("File that should get reverted")
tmp_path.joinpath("file5.txt").write_text("File that should be kept")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
file4 = File.from_local(path=tmp_path / "file4.txt", base_path=tmp_path)
file5 = File.from_local(path=tmp_path / "file5.txt", base_path=tmp_path)
con.upload_files(file4, file5)
Expand All @@ -237,10 +257,12 @@ def test_stat_uploaded_file(
ds = Dataset(type="raw", source_folder=RemotePath("/data/upload6"))
tmp_path.joinpath("file6.txt").write_text("File to test upload no 6")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(
dataset=ds, connect=sftp_connect_with_username_password
) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host,
port=sftp_access.port,
connect=sftp_connect_with_username_password,
)
with sftp.connect_for_upload(dataset=ds) as con:
[uploaded] = con.upload_files(
File.from_local(path=tmp_path / "file6.txt", remote_path="upload_6.txt")
)
Expand Down Expand Up @@ -312,8 +334,10 @@ def test_upload_file_detects_checksum_mismatch(
)
tmp_path.joinpath("file7.txt").write_text("File to test upload no 7")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(dataset=ds, connect=sftp_corrupting_connect) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host, port=sftp_access.port, connect=sftp_corrupting_connect
)
with sftp.connect_for_upload(dataset=ds) as con:
with pytest.raises(FileUploadError):
con.upload_files(
dataclasses.replace(
Expand Down Expand Up @@ -369,8 +393,10 @@ def test_upload_file_reverts_if_upload_fails(
ds = Dataset(type="raw", source_folder=RemotePath("/data/upload8"))
tmp_path.joinpath("file8.txt").write_text("File to test upload no 8")

sftp = SFTPFileTransfer(host=sftp_access.host, port=sftp_access.port)
with sftp.connect_for_upload(dataset=ds, connect=sftp_raising_connect) as con:
sftp = SFTPFileTransfer(
host=sftp_access.host, port=sftp_access.port, connect=sftp_raising_connect
)
with sftp.connect_for_upload(dataset=ds) as con:
with pytest.raises(RuntimeError):
con.upload_files(
File.from_local(
Expand All @@ -383,24 +409,17 @@ def test_upload_file_reverts_if_upload_fails(


class SFTPTestFileTransfer(SFTPFileTransfer):
def __init__(self, connect, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect = connect

@contextmanager
def connect_for_download(
self, connect: Optional[Callable[..., Connection]] = None
) -> Iterator[SFTPDownloadConnection]:
connect = connect if connect is not None else self.connect
with super().connect_for_download(connect=connect) as connection:
def connect_for_download(self) -> Iterator[SFTPDownloadConnection]:
with super().connect_for_download() as connection:
yield connection

@contextmanager
def connect_for_upload(
self, dataset: Dataset, connect: Optional[Callable[..., Connection]] = None
) -> Iterator[SFTPUploadConnection]:
connect = connect if connect is not None else self.connect
with super().connect_for_upload(dataset=dataset, connect=connect) as connection:
def connect_for_upload(self, dataset: Dataset) -> Iterator[SFTPUploadConnection]:
with super().connect_for_upload(dataset=dataset) as connection:
yield connection


Expand Down