@@ -2,7 +2,17 @@
 import io
 import mimetypes
 import os
-from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional
+from typing import (
+    Any,
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Collection,
+    Dict,
+    Mapping,
+    Optional,
+    cast,
+)
 from urllib.parse import urlparse
 
 import httpx
@@ -59,7 +69,7 @@ def webhook_headers() -> "dict[str, str]":
 
 async def on_request_trace_context_hook(request: httpx.Request) -> None:
     ctx = current_trace_context() or {}
-    request.headers.update(ctx)
+    request.headers.update(cast(Mapping[str, str], ctx))
 
 
 def httpx_webhook_client() -> httpx.AsyncClient:
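Note: `cast` is a runtime no-op; it only narrows the trace context's type so that `request.headers.update` type-checks against `Mapping[str, str]`. For orientation, a minimal sketch of how such a request hook is typically attached to a client, assuming httpx's `event_hooks` API (the actual factory body is not part of this diff):

```python
import httpx

# Hypothetical wiring: each registered request hook runs before a request
# is sent, so the trace headers ride along on every webhook/file call.
def make_client_with_tracing() -> httpx.AsyncClient:
    return httpx.AsyncClient(
        event_hooks={"request": [on_request_trace_context_hook]},
    )
```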
@@ -109,6 +119,22 @@ def httpx_file_client() -> httpx.AsyncClient:
     )
 
 
+class ChunkFileReader:
+    def __init__(self, fh: io.IOBase) -> None:
+        self.fh = fh
+
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        self.fh.seek(0)
+        while True:
+            chunk = self.fh.read(1024 * 1024)
+            if isinstance(chunk, str):
+                chunk = chunk.encode("utf-8")
+            if not chunk:
+                log.info("finished reading file")
+                break
+            yield chunk
+
+
 # there's a case for splitting this apart or inlining parts of it
 # I'm somewhat sympathetic to separating webhooks and files, but they both have
 # the same semantics of holding a client for the lifetime of runner
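Hoisting the chunking generator into a class has one practical effect (and, plausibly, the motivation for the change): the upload body becomes restartable. A bare async generator is exhausted after one pass, while `ChunkFileReader.__aiter__` rewinds the file each time iteration starts, which is also why the `fh.seek(0)` call disappears from `upload_file` below. A sketch of that property, assuming this module's class and a made-up file:

```python
async def read_twice() -> None:
    with open("weights.bin", "rb") as fh:  # hypothetical file
        reader = ChunkFileReader(fh)
        first = [chunk async for chunk in reader]
        # A second pass yields the same bytes, since __aiter__ calls seek(0);
        # iterating a plain async generator again would yield nothing.
        second = [chunk async for chunk in reader]
        assert first == second
```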
@@ -159,10 +185,11 @@ async def sender(response: Any, event: WebhookEvent) -> None:
 
     # files
 
-    async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
+    async def upload_file(
+        self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
+    ) -> str:
         """put file to signed endpoint"""
         log.debug("upload_file")
-        fh.seek(0)
         # try to guess the filename of the given object
         name = getattr(fh, "name", "file")
         filename = os.path.basename(name) or "file"
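An aside on the double fallback above: in-memory streams carry no `.name` attribute, and `os.path.basename` can return an empty string (e.g. for names ending in a slash), so both defaults are needed. Illustration only:

```python
import io
import os

buf = io.BytesIO(b"data")                    # no .name attribute
name = getattr(buf, "name", "file")          # -> "file"
filename = os.path.basename(name) or "file"  # -> "file"
```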
@@ -180,17 +207,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
         # ensure trailing slash
         url_with_trailing_slash = url if url.endswith("/") else url + "/"
 
-        async def chunk_file_reader() -> AsyncIterator[bytes]:
-            while 1:
-                chunk = fh.read(1024 * 1024)
-                if isinstance(chunk, str):
-                    chunk = chunk.encode("utf-8")
-                if not chunk:
-                    log.info("finished reading file")
-                    break
-                yield chunk
-
         url = url_with_trailing_slash + filename
+
+        headers = {"Content-Type": content_type}
+        if prediction_id is not None:
+            headers["X-Prediction-ID"] = prediction_id
+
         # this is a somewhat unfortunate hack, but it works
         # and is critical for upload training/quantization outputs
         # if we get multipart uploads working or a separate API route
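On the hack being set up here: the first PUT in the next hunk sends an empty body with `follow_redirects=False`, so any 307 redirect is discovered before real bytes are streamed; only the second PUT carries the file. Both requests reuse this `headers` dict, which for a hypothetical prediction id would come out as:

```python
# Illustration with made-up values:
content_type = "application/octet-stream"  # as guessed from the filename above
prediction_id = "p-123"

headers = {"Content-Type": content_type}
if prediction_id is not None:
    headers["X-Prediction-ID"] = prediction_id
# -> {"Content-Type": "application/octet-stream", "X-Prediction-ID": "p-123"}
```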
@@ -200,29 +222,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
         resp1 = await self.file_client.put(
             url,
             content=b"",
-            headers={"Content-Type": content_type},
+            headers=headers,
             follow_redirects=False,
         )
         if resp1.status_code == 307 and resp1.headers["Location"]:
             log.info("got file upload redirect from api")
             url = resp1.headers["Location"]
+
         log.info("doing real upload to %s", url)
         resp = await self.file_client.put(
             url,
-            content=chunk_file_reader(),
-            headers={"Content-Type": content_type},
+            content=ChunkFileReader(fh),
+            headers=headers,
         )
         # TODO: if file size is >1MB, show upload throughput
         resp.raise_for_status()
 
-        # strip any signing gubbins from the URL
-        final_url = urlparse(str(resp.url))._replace(query="").geturl()
+        # Try to extract the final asset URL from the `Location` header
+        # otherwise fallback to the URL of the final request.
+        final_url = str(resp.url)
+        if "location" in resp.headers:
+            final_url = resp.headers.get("location")
 
-        return final_url
+        # strip any signing gubbins from the URL
+        return urlparse(final_url)._replace(query="").geturl()
 
     # this previously lived in json.upload_files, but it's clearer here
     # this is a great pattern that should be adopted for input files
-    async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
+    async def upload_files(
+        self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
+    ) -> Any:
         """
         Iterates through an object from make_encodeable and uploads any files.
         When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
@@ -234,15 +263,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
         # TODO: upload concurrently
         if isinstance(obj, dict):
             return {
-                key: await self.upload_files(value, url) for key, value in obj.items()
+                key: await self.upload_files(
+                    value, url=url, prediction_id=prediction_id
+                )
+                for key, value in obj.items()
             }
         if isinstance(obj, list):
-            return [await self.upload_files(value, url) for value in obj]
+            return [
+                await self.upload_files(value, url=url, prediction_id=prediction_id)
+                for value in obj
+            ]
         if isinstance(obj, Path):
             with obj.open("rb") as f:
-                return await self.upload_file(f, url)
+                return await self.upload_file(f, url=url, prediction_id=prediction_id)
         if isinstance(obj, io.IOBase):
-            return await self.upload_file(obj, url)
+            return await self.upload_file(obj, url=url, prediction_id=prediction_id)
         return obj
 
     # we could also handle inputs here, with a convert_prediction_input function
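For reference, a hypothetical call site under the reworked API; `client` stands in for whatever object exposes these methods, and the URL and prediction id are made up:

```python
from pathlib import Path

async def handle_output(client) -> None:
    output = {"image": Path("/tmp/output.png"), "scores": [0.9, 0.1]}
    encoded = await client.upload_files(
        output,
        url="https://example.com/upload",  # assumed signed upload base URL
        prediction_id="p-123",             # forwarded as the X-Prediction-ID header
    )
    # `encoded` mirrors `output`, with the Path replaced by the uploaded
    # file's URL (query/signing parameters stripped by upload_file).
```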