
Commit 9dfaba1

improve!: use multiprocessing in fetcher
BREAKING CHANGE
1 parent b78bd3e commit 9dfaba1

2 files changed: +93 −68 lines


fast_s3/fetcher.py

Lines changed: 88 additions & 59 deletions
@@ -1,9 +1,12 @@
-import io
+import multiprocessing
+import warnings
 from pathlib import Path
-from typing import List, Union
+from queue import Empty
+from typing import Generator, List, Tuple, Union
+
+import boto3
 
 from .file import File, Status
-from .transfer_manager import transfer_manager
 
 
 class Fetcher:
@@ -15,70 +18,96 @@ def __init__(
         aws_secret_access_key: str,
         region_name: str,
         bucket_name: str,
-        ordered=True,
-        buffer_size=1024,
+        buffer_size: int = 1000,
         n_workers=32,
-        **transfer_manager_kwargs,
+        worker_batch_size=100,
+        callback=lambda x: x,
+        ordered: bool = False,
     ):
-        self.paths = paths
-        self.ordered = ordered
-        self.buffer_size = buffer_size
-        self.transfer_manager = transfer_manager(
-            endpoint_url=endpoint_url,
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            region_name=region_name,
-            n_workers=n_workers,
-            **transfer_manager_kwargs,
-        )
+        self.paths = multiprocessing.Manager().list(list(enumerate(paths))[::-1])
         self.bucket_name = bucket_name
-        self.files: List[File] = []
-        self.current_path_index = 0
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __iter__(self):
-        for _ in range(self.buffer_size):
-            self.queue_download_()
+        self.endpoint_url = endpoint_url
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.region_name = region_name
+        self.n_workers = n_workers
+        self.buffer_size = min(buffer_size, len(paths))
+        self.worker_batch_size = worker_batch_size
+        self.ordered = ordered
+        self.callback = callback
 
-        if self.ordered:
-            for _ in range(len(self)):
-                yield self.process_index(0)
+        if ordered:
+            # TODO: fix this issue
+            warnings.warn(
+                "buffer_size is ignored when ordered=True which can cause out of memory"
+            )
+            self.results = multiprocessing.Manager().dict()
+            self.result_order = multiprocessing.Manager().list(range(len(paths)))
         else:
-            for _ in range(len(self)):
-                for index, file in enumerate(self.files):
-                    if file.future.done():
-                        break
-                else:
-                    index = 0
-                yield self.process_index(index)
+            self.file_queue = multiprocessing.Queue(maxsize=buffer_size)
 
-    def process_index(self, index):
-        file = self.files.pop(index)
-        self.queue_download_()
-        try:
-            file.future.result()
-            return file.with_status(Status.done)
-        except Exception as e:
-            return file.with_status(Status.error, exception=e)
+    def _create_s3_client(self):
+        return boto3.client(
+            "s3",
+            endpoint_url=self.endpoint_url,
+            aws_access_key_id=self.aws_access_key_id,
+            aws_secret_access_key=self.aws_secret_access_key,
+            region_name=self.region_name,
+        )
 
-    def queue_download_(self):
-        if self.current_path_index < len(self):
-            buffer = io.BytesIO()
-            path = self.paths[self.current_path_index]
-            self.files.append(
-                File(
-                    buffer=buffer,
-                    future=self.transfer_manager.download(
-                        fileobj=buffer,
-                        bucket=self.bucket_name,
-                        key=str(path),
+    def download_batch(self, batch: List[Tuple[int, Union[Path, str]]]):
+        client = self._create_s3_client()
+        for index, path in batch:
+            try:
+                file = File(
+                    content=self.callback(
+                        client.get_object(Bucket=self.bucket_name, Key=str(path))[
+                            "Body"
+                        ].read()
                     ),
                     path=path,
+                    status=Status.succeeded,
                 )
-            )
-            self.current_path_index += 1
+            except Exception as e:
+                file = File(content=None, path=path, status=Status.failed, exception=e)
+            if self.ordered:
+                self.results[index] = file
+            else:
+                self.file_queue.put(file)
+
+    def _worker(self):
+        while len(self.paths) > 0:
+            batch = []
+            for _ in range(min(self.worker_batch_size, len(self.paths))):
+                try:
+                    index, path = self.paths.pop()
+                    batch.append((index, path))
+                except IndexError:
+                    break
+            if len(batch) > 0:
+                self.download_batch(batch)
+
+    def __iter__(self) -> Generator[File, None, None]:
+        workers = []
+        for _ in range(self.n_workers):
+            worker_process = multiprocessing.Process(target=self._worker)
+            worker_process.start()
+            workers.append(worker_process)
+
+        if self.ordered:
+            for i in self.result_order:
+                while any(p.is_alive() for p in workers) and i not in self.results:
+                    continue  # wait for the item to appear
+                yield self.results.pop(i)
+        else:
+            while any(p.is_alive() for p in workers) or not self.file_queue.empty():
+                try:
+                    yield self.file_queue.get(timeout=1)
+                except Empty:
+                    pass
+
+        for worker in workers:
+            worker.join()
 
-    def close(self):
-        self.transfer_manager.shutdown()
+    def __len__(self):
+        return len(self.paths)
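
For readers tracking the breaking change: the transfer-manager based API (lazy buffers, an explicit close()) is gone, and a Fetcher is now consumed purely by iteration. A minimal usage sketch follows; the credentials, endpoint, bucket, and keys are placeholders, and only constructor arguments visible in the diff above are assumed (the top-level import of Fetcher is also an assumption).

from fast_s3 import Fetcher      # assuming the package re-exports Fetcher
from fast_s3.file import Status

fetcher = Fetcher(
    paths=["images/0001.jpg", "images/0002.jpg"],  # placeholder keys
    endpoint_url="https://s3.example.com",          # placeholder endpoint
    aws_access_key_id="AKIA...",                    # placeholder credentials
    aws_secret_access_key="...",
    region_name="us-east-1",
    bucket_name="my-bucket",                        # placeholder bucket
    n_workers=8,
    worker_batch_size=16,
    ordered=False,  # ordered=True keeps input order but buffers every result in memory
)

# Iteration starts the worker processes and yields File objects as downloads finish;
# there is no close() to call any more.
for file in fetcher:
    if file.status is Status.succeeded:
        data = file.content  # raw bytes, since the default callback is the identity
    else:
        print(f"failed: {file.path} ({file.exception})")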

fast_s3/file.py

Lines changed: 5 additions & 9 deletions
@@ -1,23 +1,19 @@
-import io
 from enum import Enum
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel
-from s3transfer.futures import TransferFuture
 
 
 class Status(str, Enum):
-    pending = "pending"
-    done = "done"
-    error = "error"
+    succeeded = "succeeded"
+    failed = "failed"
 
 
 class File(BaseModel, arbitrary_types_allowed=True):
-    buffer: io.BytesIO
-    future: TransferFuture
+    content: Any
     path: Union[str, Path]
-    status: Status = Status.pending
+    status: Status
     exception: Optional[Exception] = None
 
     def with_status(self, status: Status, exception: Optional[Exception] = None):
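
The File model is now eager: content holds the downloaded bytes (or whatever the constructor's callback returned), and status/exception replace the old pending/done/error lifecycle. A small sketch of consuming it; the module path comes from this diff, while process_record is a purely illustrative helper.

from fast_s3.file import File, Status  # module path taken from this commit

def process_record(file: File) -> None:
    # Failed downloads carry the original exception instead of raising in the worker.
    if file.status is Status.failed:
        print(f"could not fetch {file.path}: {file.exception}")
        return
    data = file.content  # bytes when the default (identity) callback is used
    print(f"{file.path}: {len(data)} bytes")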
