import multiprocessing
import time
import warnings
from pathlib import Path
from queue import Empty
from typing import Generator, List, Tuple, Union

import boto3

from .file import File, Status
10
8
11
9
12
class Fetcher :
@@ -15,70 +18,96 @@ def __init__(
15
18
aws_secret_access_key : str ,
16
19
region_name : str ,
17
20
bucket_name : str ,
18
- ordered = True ,
19
- buffer_size = 1024 ,
21
+ buffer_size : int = 1000 ,
20
22
n_workers = 32 ,
21
- ** transfer_manager_kwargs ,
23
+ worker_batch_size = 100 ,
24
+ callback = lambda x : x ,
25
+ ordered : bool = False ,
22
26
):
23
- self .paths = paths
24
- self .ordered = ordered
25
- self .buffer_size = buffer_size
26
- self .transfer_manager = transfer_manager (
27
- endpoint_url = endpoint_url ,
28
- aws_access_key_id = aws_access_key_id ,
29
- aws_secret_access_key = aws_secret_access_key ,
30
- region_name = region_name ,
31
- n_workers = n_workers ,
32
- ** transfer_manager_kwargs ,
33
- )
27
+ self .paths = multiprocessing .Manager ().list (list (enumerate (paths ))[::- 1 ])
34
28
self .bucket_name = bucket_name
35
- self .files : List [ File ] = []
36
- self .current_path_index = 0
37
-
38
- def __len__ ( self ):
39
- return len ( self .paths )
40
-
41
- def __iter__ ( self ):
42
- for _ in range ( self .buffer_size ):
43
- self .queue_download_ ()
29
+ self .endpoint_url = endpoint_url
30
+ self .aws_access_key_id = aws_access_key_id
31
+ self . aws_secret_access_key = aws_secret_access_key
32
+ self . region_name = region_name
33
+ self .n_workers = n_workers
34
+ self . buffer_size = min ( buffer_size , len ( paths ))
35
+ self . worker_batch_size = worker_batch_size
36
+ self .ordered = ordered
37
+ self .callback = callback
44
38
45
- if self .ordered :
46
- for _ in range (len (self )):
47
- yield self .process_index (0 )
39
+ if ordered :
40
+ # TODO: fix this issue
41
+ warnings .warn (
42
+ "buffer_size is ignored when ordered=True which can cause out of memory"
43
+ )
44
+ self .results = multiprocessing .Manager ().dict ()
45
+ self .result_order = multiprocessing .Manager ().list (range (len (paths )))
48
46
else :
49
- for _ in range (len (self )):
50
- for index , file in enumerate (self .files ):
51
- if file .future .done ():
52
- break
53
- else :
54
- index = 0
55
- yield self .process_index (index )
47
+ self .file_queue = multiprocessing .Queue (maxsize = buffer_size )
56
48
57
- def process_index (self , index ):
58
- file = self . files . pop ( index )
59
- self . queue_download_ ()
60
- try :
61
- file . future . result ()
62
- return file . with_status ( Status . done )
63
- except Exception as e :
64
- return file . with_status ( Status . error , exception = e )
49
+ def _create_s3_client (self ):
50
+ return boto3 . client (
51
+ "s3" ,
52
+ endpoint_url = self . endpoint_url ,
53
+ aws_access_key_id = self . aws_access_key_id ,
54
+ aws_secret_access_key = self . aws_secret_access_key ,
55
+ region_name = self . region_name ,
56
+ )
65
57
66
- def queue_download_ (self ):
67
- if self .current_path_index < len (self ):
68
- buffer = io .BytesIO ()
69
- path = self .paths [self .current_path_index ]
70
- self .files .append (
71
- File (
72
- buffer = buffer ,
73
- future = self .transfer_manager .download (
74
- fileobj = buffer ,
75
- bucket = self .bucket_name ,
76
- key = str (path ),
58
+ def download_batch (self , batch : List [Tuple [int , Union [Path , str ]]]):
59
+ client = self ._create_s3_client ()
60
+ for index , path in batch :
61
+ try :
62
+ file = File (
63
+ content = self .callback (
64
+ client .get_object (Bucket = self .bucket_name , Key = str (path ))[
65
+ "Body"
66
+ ].read ()
77
67
),
78
68
path = path ,
69
+ status = Status .succeeded ,
79
70
)
80
- )
81
- self .current_path_index += 1
71
+ except Exception as e :
72
+ file = File (content = None , path = path , status = Status .failed , exception = e )
73
+ if self .ordered :
74
+ self .results [index ] = file
75
+ else :
76
+ self .file_queue .put (file )
77
+
78
+ def _worker (self ):
79
+ while len (self .paths ) > 0 :
80
+ batch = []
81
+ for _ in range (min (self .worker_batch_size , len (self .paths ))):
82
+ try :
83
+ index , path = self .paths .pop ()
84
+ batch .append ((index , path ))
85
+ except IndexError :
86
+ break
87
+ if len (batch ) > 0 :
88
+ self .download_batch (batch )
89
+
90
+ def __iter__ (self ) -> Generator [File , None , None ]:
91
+ workers = []
92
+ for _ in range (self .n_workers ):
93
+ worker_process = multiprocessing .Process (target = self ._worker )
94
+ worker_process .start ()
95
+ workers .append (worker_process )
96
+
97
+ if self .ordered :
98
+ for i in self .result_order :
99
+ while any (p .is_alive () for p in workers ) and i not in self .results :
100
+ continue # wait for the item to appear
101
+ yield self .results .pop (i )
102
+ else :
103
+ while any (p .is_alive () for p in workers ) or not self .file_queue .empty ():
104
+ try :
105
+ yield self .file_queue .get (timeout = 1 )
106
+ except Empty :
107
+ pass
108
+
109
+ for worker in workers :
110
+ worker .join ()
82
111
83
- def close (self ):
84
- self .transfer_manager . shutdown ( )
112
+ def __len__ (self ):
113
+ return len ( self .paths )
0 commit comments