Skip to content

Commit

Permalink
Adjusted local-to-remote copy with folders (#308)
Browse files Browse the repository at this point in the history
The local-to-remote copy was failing when trying to copy a folder with
multiple files, as only the first file was actually transferred. This
commit solves this issue by modifying the `aiotarstream` module to deal
with `tell` and `seek` operations, allowing the stream pointer
to be positioned correctly.

In detail, two new classes have been implemented:

- `TellableStreamWrapper`, which keeps track of the actual stream
   pointer position;
- `SeekableStreamReaderWrapper`, which allows forward seek operations
  in a stream reader.

Note that these classes are internal to the `aiotarstream` module, so
nothing changes in the publicly exposed API.
  • Loading branch information
GlassOfWhiskey authored Dec 8, 2023
1 parent fd6274c commit a492ae2
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 21 deletions.
72 changes: 59 additions & 13 deletions streamflow/deployment/aiotarstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import time
from abc import ABC
from builtins import open as bltn_open
from typing import Any
from typing import Any, cast

from streamflow.core.data import StreamWrapper
from streamflow.deployment.stream import BaseStreamWrapper
Expand Down Expand Up @@ -189,23 +189,26 @@ def __init__(self, stream, mode, preset: int | None = None):


class FileStreamReaderWrapper(StreamWrapper):
def __init__(self, stream, size, blockinfo=None):
def __init__(self, stream, offset, size, blockinfo=None):
super().__init__(stream)
self.size = size
self.offset = offset
self.position = 0
self.closed = False
if blockinfo is None:
blockinfo = [(0, size)]
self.map_index = 0
self.map = []
lastpos = 0
realpos = self.offset
for offset, size in blockinfo:
if offset > lastpos:
self.map.append((False, lastpos, offset))
self.map.append((True, offset, offset + size))
self.map.append((False, lastpos, offset, None))
self.map.append((True, offset, offset + size, realpos))
realpos += size
lastpos = offset + size
if lastpos < self.size:
self.map.append((False, lastpos, self.size))
self.map.append((False, lastpos, self.size, None))

async def __aenter__(self):
return self
Expand All @@ -223,7 +226,7 @@ async def read(self, size: int | None = None):
else min(size, self.size - self.position)
)
while True:
data, start, stop = self.map[self.map_index]
data, start, stop, offset = self.map[self.map_index]
if start <= self.position < stop:
break
else:
Expand All @@ -232,6 +235,7 @@ async def read(self, size: int | None = None):
return b""
length = min(size, stop - self.position)
if data:
self.stream.seek(offset + (self.position - start))
buf = await self.stream.read(length)
self.position += len(buf)
return buf
Expand All @@ -243,16 +247,49 @@ async def write(self, data: Any):
raise NotImplementedError


class TellableStreamWrapper(BaseStreamWrapper):
    """Stream wrapper that records how many bytes have passed through it.

    The wrapped stream does not expose a ``tell`` operation, so this class
    counts the bytes read from (or written to) the stream and reports that
    count as the current stream position.
    """

    def __init__(self, stream):
        super().__init__(stream)
        # Number of bytes consumed or produced so far.
        self.position: int = 0

    async def read(self, size: int | None = None):
        """Read up to ``size`` bytes, advancing the tracked position."""
        chunk = await self.stream.read(size)
        self.position += len(chunk)
        return chunk

    def tell(self):
        """Return the current position in the stream."""
        return self.position

    async def write(self, data: Any):
        """Write ``data`` to the stream, advancing the tracked position."""
        await self.stream.write(data)
        self.position += len(data)


class SeekableStreamReaderWrapper(TellableStreamWrapper):
    """Stream reader wrapper that supports forward-only ``seek`` operations.

    Forward seeks are emulated by reading and discarding bytes from the
    underlying stream. Backward seeks are impossible on a non-buffered
    stream and raise ``tarfile.ReadError``.
    """

    async def seek(self, offset: int):
        """Advance the stream position to ``offset``.

        Args:
            offset: absolute target position; must be greater than or equal
                to the current position. Seeking to the current position is
                a no-op.

        Raises:
            tarfile.ReadError: if ``offset`` lies before the current
                position, or if the stream ends before ``offset`` is reached.
        """
        if offset < self.position:
            raise tarfile.ReadError("Cannot seek backward with streams")
        # `read` may return fewer bytes than requested (short read), so keep
        # discarding until the target position is actually reached instead of
        # assuming a single read consumed everything.
        while self.position < offset:
            discarded = await self.stream.read(offset - self.position)
            if not discarded:
                # EOF before the target position: the archive is truncated.
                raise tarfile.ReadError("Unexpected end of stream while seeking")
            self.position += len(discarded)

    async def write(self, data: Any):
        """Writing through a reader wrapper is not supported."""
        raise NotImplementedError


class AioTarInfo(tarfile.TarInfo):
@classmethod
async def fromtarfile(cls, tarstream):
buf = await tarstream.stream.read(tarfile.BLOCKSIZE)
obj = cls.frombuf(buf, tarstream.encoding, tarstream.errors)
obj.offset = tarstream.offset - tarfile.BLOCKSIZE
obj.offset = tarstream.stream.tell() - tarfile.BLOCKSIZE
return await obj._proc_member(tarstream)

def _proc_builtin(self, tarstream):
self.offset_data = tarstream.offset
self.offset_data = tarstream.stream.tell()
offset = self.offset_data
if self.isreg() or self.type not in tarfile.SUPPORTED_TYPES:
offset += self._block(self.size)
Expand Down Expand Up @@ -285,7 +322,7 @@ async def _proc_gnusparse_10(self, next, pax_headers, tarstream):
buf += await tarstream.stream.read(tarfile.BLOCKSIZE)
number, buf = buf.split(b"\n", 1)
sparse.append(int(number))
next.offset_data = tarstream.offset
next.offset_data = tarstream.stream.tell()
next.sparse = list(zip(sparse[::2], sparse[1::2]))

async def _proc_member(self, tarstream):
Expand Down Expand Up @@ -376,7 +413,7 @@ async def _proc_sparse(self, tarstream):
pos += 24
isextended = bool(buf[504])
self.sparse = structs
self.offset_data = tarstream.offset
self.offset_data = tarstream.stream.tell()
tarstream.offset = self.offset_data + self._block(self.size)
self.size = origsize
return self
Expand Down Expand Up @@ -412,9 +449,11 @@ def __init__(
if mode not in modes:
raise ValueError("mode must be 'r', 'a', 'w' or 'x'")
self.mode = mode
self._mode = stream.mode if hasattr(stream, "mode") else modes[mode]
self.stream = stream
self.offset = 0
if self.mode is not None and self.mode in "ra":
self.stream = SeekableStreamReaderWrapper(stream)
else:
self.stream = TellableStreamWrapper(stream)
self.index = 0
if format is not None:
self.format = format
Expand Down Expand Up @@ -765,7 +804,10 @@ async def extractfile(self, member):
tarinfo = await self.getmember(member) if isinstance(member, str) else member
if tarinfo.isreg() or tarinfo.type not in tarinfo.SUPPORTED_TYPES:
return self.fileobject(
stream=self.stream, size=tarinfo.size, blockinfo=tarinfo.sparse
stream=self.stream,
offset=tarinfo.offset_data,
size=tarinfo.size,
blockinfo=tarinfo.sparse,
)
elif tarinfo.islnk() or tarinfo.issym():
raise tarinfo.StreamError("cannot extract (sym)link as file object")
Expand Down Expand Up @@ -955,6 +997,10 @@ async def next(self):
m = self.firstmember
self.firstmember = None
return m
if self.offset != self.stream.tell():
if self.offset == 0:
return None
await cast(SeekableStreamReaderWrapper, self.stream).seek(self.offset)
tarinfo = None
while True:
try:
Expand Down
39 changes: 31 additions & 8 deletions tests/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,36 @@ async def test_directory_to_directory(
else:
dst_path = posixpath.join("/tmp", utils.random_name())
path_processor = get_path_processor(src_connector)
inner_file = utils.random_name()
inner_file_1 = utils.random_name()
inner_file_2 = utils.random_name()
try:
await remotepath.mkdir(src_connector, [src_location], src_path)
await remotepath.write(
src_connector,
src_location,
path_processor.join(src_path, inner_file),
path_processor.join(src_path, inner_file_1),
"Hello",
)
await remotepath.write(
src_connector,
src_location,
path_processor.join(src_path, inner_file_2),
"StreamFlow",
)
src_path = await remotepath.follow_symlink(
context, src_connector, src_location, src_path
)
src_digest = await remotepath.checksum(
src_digest_1 = await remotepath.checksum(
context,
src_connector,
src_location,
path_processor.join(src_path, inner_file_1),
)
src_digest_2 = await remotepath.checksum(
context,
src_connector,
src_location,
path_processor.join(src_path, inner_file),
path_processor.join(src_path, inner_file_1),
)
context.data_manager.register_path(
location=src_location,
Expand All @@ -85,15 +98,25 @@ async def test_directory_to_directory(
path_processor = get_path_processor(dst_connector)
assert await remotepath.exists(dst_connector, dst_location, dst_path)
assert await remotepath.exists(
dst_connector, dst_location, path_processor.join(dst_path, inner_file)
dst_connector, dst_location, path_processor.join(dst_path, inner_file_1)
)
dst_digest = await remotepath.checksum(
assert await remotepath.exists(
dst_connector, dst_location, path_processor.join(dst_path, inner_file_2)
)
dst_digest_1 = await remotepath.checksum(
context,
dst_connector,
dst_location,
path_processor.join(dst_path, inner_file),
path_processor.join(dst_path, inner_file_1),
)
assert src_digest == dst_digest
assert src_digest_1 == dst_digest_1
dst_digest_2 = await remotepath.checksum(
context,
dst_connector,
dst_location,
path_processor.join(dst_path, inner_file_1),
)
assert src_digest_2 == dst_digest_2
finally:
await remotepath.rm(src_connector, src_location, src_path)
await remotepath.rm(dst_connector, dst_location, dst_path)
Expand Down

0 comments on commit a492ae2

Please sign in to comment.