From 37ae15c1a5e52d43e26e0f999609ff20d5cfa30f Mon Sep 17 00:00:00 2001 From: LanderOtto <48457093+LanderOtto@users.noreply.github.com> Date: Sun, 10 Dec 2023 09:36:33 +0100 Subject: [PATCH] Fixed remote-to-local copy of nested directories (#309) This commit fixes the remote-to-local copy of nested directories. When transferring nested directory to a remote location, an empty directory was created with the same name of her parent. This depended on a wrong handling on paths in the `base Connector`. To avoid regressions, a new unit test was added. Furthermore, this commit fixes the `seek` coroutine in the `aiotarstream` module, which was never awaited, and correctly handles the case when a `seek` is asked to move the stream pointer to its current position, which resulted in a wrongly-raised exception. --- streamflow/deployment/aiotarstream.py | 4 +- streamflow/deployment/connector/base.py | 10 +- tests/test_transfer.py | 229 +++++++++++++++++------- 3 files changed, 180 insertions(+), 63 deletions(-) diff --git a/streamflow/deployment/aiotarstream.py b/streamflow/deployment/aiotarstream.py index 4d6dcad6..0f08d4ce 100644 --- a/streamflow/deployment/aiotarstream.py +++ b/streamflow/deployment/aiotarstream.py @@ -235,7 +235,7 @@ async def read(self, size: int | None = None): return b"" length = min(size, stop - self.position) if data: - self.stream.seek(offset + (self.position - start)) + await self.stream.seek(offset + (self.position - start)) buf = await self.stream.read(length) self.position += len(buf) return buf @@ -273,7 +273,7 @@ async def seek(self, offset: int): if offset > self.position: await self.stream.read(offset - self.position) self.position = offset - else: + elif offset < self.position: raise tarfile.ReadError("Cannot seek backward with streams") async def write(self, data: Any): diff --git a/streamflow/deployment/connector/base.py b/streamflow/deployment/connector/base.py index 42b02a1b..0d0a35d2 100644 --- a/streamflow/deployment/connector/base.py +++ b/streamflow/deployment/connector/base.py @@ -38,8 +38,11 @@ async def extract_tar_stream( transferBufferSize: int | None = None, ) -> None: async for member in tar: + # If `dst` is a directory, copy the content of `src` inside `dst` if os.path.isdir(dst) and member.path == posixpath.basename(src): await tar.extract(member, dst) + + # Otherwise, if copying a file, simply move it inside `dst` elif member.isfile(): async with await tar.extractfile(member) as inputfile: path = os.path.normpath( @@ -50,9 +53,14 @@ async def extract_tar_stream( with open(path, "wb") as outputfile: while content := await inputfile.read(transferBufferSize): outputfile.write(content) + + # Otherwise, if copying a directory, modify the member path to + # move all the file hierarchy inside `dst` else: member.path = posixpath.relpath(member.path, posixpath.basename(src)) - await tar.extract(member, os.path.normpath(os.path.join(dst, member.path))) + await tar.extract( + member, os.path.normpath(os.path.join(dst, os.path.curdir)) + ) class BaseConnector(Connector, FutureAware): diff --git a/tests/test_transfer.py b/tests/test_transfer.py index eabeaf6b..6ad43200 100644 --- a/tests/test_transfer.py +++ b/tests/test_transfer.py @@ -7,7 +7,7 @@ import pytest_asyncio from streamflow.core import utils -from streamflow.core.data import DataType +from streamflow.core.data import DataType, FileType from streamflow.core.deployment import Connector, Location from streamflow.data import remotepath from streamflow.deployment.connector import LocalConnector @@ -15,6 +15,109 @@ from tests.utils.deployment import get_location +async def _compare_remote_dirs( + context, + src_connector, + src_location, + src_path, + dst_connector, + dst_location, + dst_path, +): + assert await remotepath.exists(dst_connector, dst_location, dst_path) + src_path_processor = get_path_processor(src_connector) + dst_path_processor = get_path_processor(dst_connector) + + # the two dirs must have the same elements order + src_files, dst_files = await asyncio.gather( + asyncio.create_task( + remotepath.listdir(src_connector, src_location, src_path, FileType.FILE) + ), + asyncio.create_task( + remotepath.listdir(dst_connector, dst_location, dst_path, FileType.FILE) + ), + ) + assert len(src_files) == len(dst_files) + for src_file, dst_file in zip(sorted(src_files), sorted(dst_files)): + checksums = await asyncio.gather( + asyncio.create_task( + remotepath.checksum( + context, + src_connector, + src_location, + src_path_processor.join(src_path, src_file), + ) + ), + asyncio.create_task( + remotepath.checksum( + context, + dst_connector, + dst_location, + dst_path_processor.join(dst_path, dst_file), + ) + ), + ) + assert checksums[0] == checksums[1] + + src_dirs, dst_dirs = await asyncio.gather( + asyncio.create_task( + remotepath.listdir( + src_connector, src_location, src_path, FileType.DIRECTORY + ) + ), + asyncio.create_task( + remotepath.listdir( + dst_connector, dst_location, dst_path, FileType.DIRECTORY + ) + ), + ) + assert len(src_dirs) == len(dst_dirs) + tasks = [] + for src_dir, dst_dir in zip(sorted(src_dirs), sorted(dst_dirs)): + assert os.path.basename(src_dir) == os.path.basename(dst_dir) + tasks.append( + asyncio.create_task( + _compare_remote_dirs( + context, + src_connector, + src_location, + src_dir, + dst_connector, + dst_location, + dst_dir, + ) + ) + ) + await asyncio.gather(*tasks) + + +async def _create_tmp_dir(context, connector, location, root=None, lvl=None, n_files=0): + path_processor = get_path_processor(connector) + dir_lvl = f"-{lvl}" if lvl else "" + if isinstance(src_connector, LocalConnector): + dir_path = os.path.join( + root if root else tempfile.gettempdir(), + f"dir{dir_lvl}-{utils.random_name()}", + ) + else: + dir_path = os.path.join( + root if root else "/tmp", f"dir{dir_lvl}-{utils.random_name()}" + ) + await remotepath.mkdir(connector, [location], dir_path) + + dir_path = await remotepath.follow_symlink(context, connector, location, dir_path) + file_lvl = f"-{lvl}" if lvl else "" + for i in range(n_files): + file_name = f"file{file_lvl}-{i}-{utils.random_name()}" + await remotepath.write( + connector, + location, + path_processor.join(dir_path, file_name), + f"Hello from {file_name}", + ) + return dir_path + + @pytest_asyncio.fixture(scope="module") async def src_location(context, deployment_src) -> Location: return await get_location(context, deployment_src) @@ -39,55 +142,67 @@ def dst_connector(context, dst_location) -> Connector: async def test_directory_to_directory( context, src_connector, src_location, dst_connector, dst_location ): - """Test transferring a directory and its content from one location to another.""" - if isinstance(src_connector, LocalConnector): - src_path = os.path.join(tempfile.gettempdir(), utils.random_name()) - else: - src_path = posixpath.join("/tmp", utils.random_name()) - if isinstance(dst_connector, LocalConnector): - dst_path = os.path.join(tempfile.gettempdir(), utils.random_name()) - else: - dst_path = posixpath.join("/tmp", utils.random_name()) - path_processor = get_path_processor(src_connector) - inner_file_1 = utils.random_name() - inner_file_2 = utils.random_name() + src_path = None + dst_path = None + # dir_0 + # |- file_0 + # |- file_1 + # |- file_2 + # |- file_3 + # |- dir_0_0 + # | |- file_0_0_0 + # | |- file_0_0_1 + # | |- dir_0_0_0 + # | | |- file_0_0_0_1 + # | | |- file_0_0_0_2 + # |- dir_0_1 + # | |- file_0_1_0 + # | |- file_0_1_1 + # | |- file_0_1_2 + # |- dir_0_2 + # | | empty try: - await remotepath.mkdir(src_connector, [src_location], src_path) - await remotepath.write( - src_connector, - src_location, - path_processor.join(src_path, inner_file_1), - "Hello", - ) - await remotepath.write( - src_connector, - src_location, - path_processor.join(src_path, inner_file_2), - "StreamFlow", + # create src structure + src_path = await _create_tmp_dir( + context, src_connector, src_location, n_files=4 ) + for i in range(2): + inner_dir = await _create_tmp_dir( + context, + src_connector, + src_location, + root=src_path, + n_files=2 + i if i < 2 else 0, + lvl=f"{i}", + ) + if i == 0: + await _create_tmp_dir( + context, + src_connector, + src_location, + root=inner_dir, + n_files=2, + lvl=f"{i}-0", + ) src_path = await remotepath.follow_symlink( context, src_connector, src_location, src_path ) - src_digest_1 = await remotepath.checksum( - context, - src_connector, - src_location, - path_processor.join(src_path, inner_file_1), - ) - src_digest_2 = await remotepath.checksum( - context, - src_connector, - src_location, - path_processor.join(src_path, inner_file_1), - ) + + # dst init + if isinstance(dst_connector, LocalConnector): + dst_path = os.path.join(tempfile.gettempdir(), utils.random_name()) + else: + dst_path = posixpath.join("/tmp", utils.random_name()) + + # save src_path into StreamFlow context.data_manager.register_path( location=src_location, - path=await remotepath.follow_symlink( - context, src_connector, src_location, src_path - ), + path=src_path, relpath=src_path, data_type=DataType.PRIMARY, ) + + # transfer src_path to dst_path await context.data_manager.transfer_data( src_location=src_location, src_path=src_path, @@ -95,31 +210,25 @@ async def test_directory_to_directory( dst_path=dst_path, writable=False, ) - path_processor = get_path_processor(dst_connector) - assert await remotepath.exists(dst_connector, dst_location, dst_path) - assert await remotepath.exists( - dst_connector, dst_location, path_processor.join(dst_path, inner_file_1) - ) - assert await remotepath.exists( - dst_connector, dst_location, path_processor.join(dst_path, inner_file_2) - ) - dst_digest_1 = await remotepath.checksum( - context, - dst_connector, - dst_location, - path_processor.join(dst_path, inner_file_1), - ) - assert src_digest_1 == dst_digest_1 - dst_digest_2 = await remotepath.checksum( + + # check if dst exists + await remotepath.exists(dst_connector, dst_location, dst_path) + + # check that src and dst have the same sub dirs and files + await _compare_remote_dirs( context, + src_connector, + src_location, + src_path, dst_connector, dst_location, - path_processor.join(dst_path, inner_file_1), + dst_path, ) - assert src_digest_2 == dst_digest_2 finally: - await remotepath.rm(src_connector, src_location, src_path) - await remotepath.rm(dst_connector, dst_location, dst_path) + if src_path: + await remotepath.rm(src_connector, src_location, src_path) + if dst_path: + await remotepath.rm(dst_connector, dst_location, dst_path) @pytest.mark.asyncio