Concurrency in pipe() (#901)
martindurant authored Oct 20, 2024 · 1 parent dd75a1a · commit 7fea0f5
Showing 3 changed files with 39 additions and 25 deletions.
48 changes: 28 additions & 20 deletions s3fs/core.py
@@ -273,7 +273,7 @@ class S3FileSystem(AsyncFileSystem):
connect_timeout = 5
retries = 5
read_timeout = 15
default_block_size = 5 * 2**20
default_block_size = 50 * 2**20
protocol = ("s3", "s3a")
_extra_tokenize_attributes = ("default_block_size",)

@@ -299,7 +299,7 @@ def __init__(
cache_regions=False,
asynchronous=False,
loop=None,
max_concurrency=1,
max_concurrency=10,
fixed_upload_size: bool = False,
**kwargs,
):
@@ -1133,8 +1133,11 @@ async def _call_and_read():

return await _error_wrapper(_call_and_read, retries=self.retries)

async def _pipe_file(self, path, data, chunksize=50 * 2**20, **kwargs):
async def _pipe_file(
self, path, data, chunksize=50 * 2**20, max_concurrency=None, **kwargs
):
bucket, key, _ = self.split_path(path)
concurrency = max_concurrency or self.max_concurrency
size = len(data)
# 5 GB is the limit for an S3 PUT
if size < min(5 * 2**30, 2 * chunksize):
@@ -1146,23 +1149,27 @@ async def _pipe_file(self, path, data, chunksize=50 * 2**20, **kwargs):
mpu = await self._call_s3(
"create_multipart_upload", Bucket=bucket, Key=key, **kwargs
)

# TODO: cancel MPU if the following fails
out = [
await self._call_s3(
"upload_part",
Bucket=bucket,
PartNumber=i + 1,
UploadId=mpu["UploadId"],
Body=data[off : off + chunksize],
Key=key,
ranges = list(range(0, len(data), chunksize))
inds = list(range(0, len(ranges), concurrency)) + [len(ranges)]
parts = []
for start, stop in zip(inds[:-1], inds[1:]):
out = await asyncio.gather(
*[
self._call_s3(
"upload_part",
Bucket=bucket,
PartNumber=i + 1,
UploadId=mpu["UploadId"],
Body=data[ranges[i] : ranges[i] + chunksize],
Key=key,
)
for i in range(start, stop)
]
)
parts.extend(
{"PartNumber": i + 1, "ETag": o["ETag"]}
for i, o in zip(range(start, stop), out)
)
for i, off in enumerate(range(0, len(data), chunksize))
]

parts = [
{"PartNumber": i + 1, "ETag": o["ETag"]} for i, o in enumerate(out)
]
await self._call_s3(
"complete_multipart_upload",
Bucket=bucket,
@@ -2145,7 +2152,7 @@ def __init__(
s3,
path,
mode="rb",
block_size=5 * 2**20,
block_size=50 * 2**20,
acl=False,
version_id=None,
fill_cache=True,
@@ -2365,6 +2372,7 @@ def n_bytes_left() -> int:
return len(self.buffer.getbuffer()) - self.buffer.tell()

min_chunk = 1 if final else self.blocksize
# TODO: concurrent here
if self.fs.fixed_upload_size:
# all chunks have fixed size, exception: last one can be smaller
while n_bytes_left() >= min_chunk:
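The heart of the change is in _pipe_file: rather than awaiting each upload_part call in turn, part uploads are now issued in windows of at most max_concurrency and awaited together with asyncio.gather. Below is a minimal standalone sketch of that batching pattern; upload_part is a hypothetical coroutine standing in for the real self._call_s3("upload_part", ...) call and is assumed to return a dict containing an "ETag".

import asyncio

async def upload_parts_batched(upload_part, data, chunksize, concurrency):
    # Offsets of each part within `data`; S3 part numbers are 1-based.
    offsets = list(range(0, len(data), chunksize))
    # Window boundaries: at most `concurrency` parts are in flight at once.
    bounds = list(range(0, len(offsets), concurrency)) + [len(offsets)]
    parts = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        out = await asyncio.gather(
            *[
                upload_part(i + 1, data[offsets[i] : offsets[i] + chunksize])
                for i in range(start, stop)
            ]
        )
        # gather preserves submission order, so part numbers pair up positionally.
        parts.extend(
            {"PartNumber": i + 1, "ETag": o["ETag"]}
            for i, o in zip(range(start, stop), out)
        )
    # `parts` is what complete_multipart_upload expects, in part-number order.
    return parts

Each window blocks until all of its parts finish, so the number of in-flight requests stays bounded by concurrency rather than by the total number of parts.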
5 changes: 4 additions & 1 deletion s3fs/tests/derived/s3fs_fixtures.py
@@ -11,7 +11,7 @@
test_bucket_name = "test"
secure_bucket_name = "test-secure"
versioned_bucket_name = "test-versioned"
port = 5555
port = 5556
endpoint_uri = "http://127.0.0.1:%s/" % port


@@ -109,6 +109,9 @@ def _s3_base(self):
pass
timeout -= 0.1
time.sleep(0.1)
if proc.poll() is not None:
proc.terminate()
raise RuntimeError("Starting moto server failed")
print("server up")
yield
print("moto done")
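The fixture change makes the moto startup loop fail fast when the server process dies instead of coming up. A rough standalone sketch of that pattern, under the assumption that a moto_server console script is on PATH (the exact command line differs between moto releases):

import subprocess
import time

import requests

def start_moto_server(port=5556, timeout=5.0):
    # Hypothetical standalone version of the fixture's startup loop.
    endpoint = f"http://127.0.0.1:{port}/"
    proc = subprocess.Popen(["moto_server", "-p", str(port)])
    while timeout > 0:
        try:
            # Any HTTP response at all means the server socket is up.
            requests.get(endpoint, timeout=0.1)
            break
        except requests.exceptions.RequestException:
            pass
        timeout -= 0.1
        time.sleep(0.1)
    # The check added in this commit: if the subprocess already died,
    # fail loudly instead of letting every test hang against a dead endpoint.
    if proc.poll() is not None:
        proc.terminate()
        raise RuntimeError("Starting moto server failed")
    return proc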
11 changes: 7 additions & 4 deletions s3fs/tests/test_s3fs.py
@@ -90,7 +90,10 @@ def s3_base():
def reset_s3_fixture():
# We reuse the MotoServer for all tests
# But we do want a clean state for every test
requests.post(f"{endpoint_uri}/moto-api/reset")
try:
requests.post(f"{endpoint_uri}/moto-api/reset")
except:
pass


def get_boto3_client():
@@ -1253,7 +1256,7 @@ def test_write_fails(s3):


def test_write_blocks(s3):
with s3.open(test_bucket_name + "/temp", "wb") as f:
with s3.open(test_bucket_name + "/temp", "wb", block_size=5 * 2**20) as f:
f.write(b"a" * 2 * 2**20)
assert f.buffer.tell() == 2 * 2**20
assert not (f.parts)
@@ -1787,7 +1790,7 @@ def test_change_defaults_only_subsequent():
S3FileSystem.cachable = False # don't reuse instances with same pars

fs_default = S3FileSystem(client_kwargs={"endpoint_url": endpoint_uri})
assert fs_default.default_block_size == 5 * (1024**2)
assert fs_default.default_block_size == 50 * (1024**2)

fs_overridden = S3FileSystem(
default_block_size=64 * (1024**2),
@@ -1804,7 +1807,7 @@

# Test the other file systems created to see if their block sizes changed
assert fs_overridden.default_block_size == 64 * (1024**2)
assert fs_default.default_block_size == 5 * (1024**2)
assert fs_default.default_block_size == 50 * (1024**2)
finally:
S3FileSystem.default_block_size = 5 * (1024**2)
S3FileSystem.cachable = True
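Taken together, the changes let callers tune upload parallelism per call. A minimal usage sketch, assuming the synchronous pipe_file wrapper forwards chunksize and max_concurrency through to _pipe_file, and using a placeholder bucket/key:

import s3fs

fs = s3fs.S3FileSystem()  # max_concurrency now defaults to 10

# "my-bucket/big-object" is a placeholder; any payload larger than
# 2 * chunksize takes the multipart branch of _pipe_file.
payload = b"x" * (200 * 2**20)
fs.pipe_file(
    "my-bucket/big-object",
    payload,
    chunksize=50 * 2**20,   # size of each uploaded part
    max_concurrency=8,      # parts uploaded per asyncio.gather batch
)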
