23
23
from graphql import DocumentNode , ExecutionResult , print_ast
24
24
from multidict import CIMultiDictProxy
25
25
26
- from ..utils import extract_files
27
26
from .appsync_auth import AppSyncAuthentication
28
27
from .async_transport import AsyncTransport
29
28
from .common .aiohttp_closed_event import create_aiohttp_closed_event
33
32
TransportProtocolError ,
34
33
TransportServerError ,
35
34
)
35
+ from .file_upload import FileVar , close_files , extract_files , open_files
36
36
37
37
log = logging .getLogger (__name__ )
38
38
@@ -207,6 +207,10 @@ async def execute(
207
207
file_classes = self .file_classes ,
208
208
)
209
209
210
+ # Opening the files using the FileVar parameters
211
+ open_files (list (files .values ()), transport_supports_streaming = True )
212
+ self .files = files
213
+
210
214
# Save the nulled variable values in the payload
211
215
payload ["variables" ] = nulled_variable_values
212
216
@@ -220,8 +224,8 @@ async def execute(
220
224
file_map = {str (i ): [path ] for i , path in enumerate (files )}
221
225
222
226
# Enumerate the file streams
223
- # Will generate something like {'0': <_io.BufferedReader ...> }
224
- file_streams = {str (i ): files [path ] for i , path in enumerate (files )}
227
+ # Will generate something like {'0': FileVar object }
228
+ file_vars = {str (i ): files [path ] for i , path in enumerate (files )}
225
229
226
230
# Add the payload to the operations field
227
231
operations_str = self .json_serialize (payload )
@@ -235,12 +239,15 @@ async def execute(
235
239
log .debug ("file_map %s" , file_map_str )
236
240
data .add_field ("map" , file_map_str , content_type = "application/json" )
237
241
238
- # Add the extracted files as remaining fields
239
- for k , f in file_streams .items ():
240
- name = getattr (f , "name" , k )
241
- content_type = getattr (f , "content_type" , None )
242
+ for k , file_var in file_vars .items ():
243
+ assert isinstance (file_var , FileVar )
242
244
243
- data .add_field (k , f , filename = name , content_type = content_type )
245
+ data .add_field (
246
+ k ,
247
+ file_var .f ,
248
+ filename = file_var .filename ,
249
+ content_type = file_var .content_type ,
250
+ )
244
251
245
252
post_args : Dict [str , Any ] = {"data" : data }
246
253
@@ -267,51 +274,59 @@ async def execute(
267
274
if self .session is None :
268
275
raise TransportClosed ("Transport is not connected" )
269
276
270
- async with self .session .post (self .url , ssl = self .ssl , ** post_args ) as resp :
271
-
272
- # Saving latest response headers in the transport
273
- self .response_headers = resp .headers
277
+ try :
278
+ async with self .session .post (self .url , ssl = self .ssl , ** post_args ) as resp :
274
279
275
- async def raise_response_error (
276
- resp : aiohttp .ClientResponse , reason : str
277
- ) -> NoReturn :
278
- # We raise a TransportServerError if the status code is 400 or higher
279
- # We raise a TransportProtocolError in the other cases
280
+ # Saving latest response headers in the transport
281
+ self .response_headers = resp .headers
280
282
281
- try :
282
- # Raise a ClientResponseError if response status is 400 or higher
283
- resp .raise_for_status ()
284
- except ClientResponseError as e :
285
- raise TransportServerError (str (e ), e .status ) from e
286
-
287
- result_text = await resp .text ()
288
- raise TransportProtocolError (
289
- f"Server did not return a GraphQL result: "
290
- f"{ reason } : "
291
- f"{ result_text } "
292
- )
283
+ async def raise_response_error (
284
+ resp : aiohttp .ClientResponse , reason : str
285
+ ) -> NoReturn :
286
+ # We raise a TransportServerError if status code is 400 or higher
287
+ # We raise a TransportProtocolError in the other cases
293
288
294
- try :
295
- result = await resp .json (loads = self .json_deserialize , content_type = None )
289
+ try :
290
+ # Raise ClientResponseError if response status is 400 or higher
291
+ resp .raise_for_status ()
292
+ except ClientResponseError as e :
293
+ raise TransportServerError (str (e ), e .status ) from e
296
294
297
- if log .isEnabledFor (logging .INFO ):
298
295
result_text = await resp .text ()
299
- log .info ("<<< %s" , result_text )
296
+ raise TransportProtocolError (
297
+ f"Server did not return a GraphQL result: "
298
+ f"{ reason } : "
299
+ f"{ result_text } "
300
+ )
300
301
301
- except Exception :
302
- await raise_response_error (resp , "Not a JSON answer" )
302
+ try :
303
+ result = await resp .json (
304
+ loads = self .json_deserialize , content_type = None
305
+ )
303
306
304
- if result is None :
305
- await raise_response_error (resp , "Not a JSON answer" )
307
+ if log .isEnabledFor (logging .INFO ):
308
+ result_text = await resp .text ()
309
+ log .info ("<<< %s" , result_text )
306
310
307
- if "errors" not in result and "data" not in result :
308
- await raise_response_error (resp , 'No "data" or "errors" keys in answer' )
311
+ except Exception :
312
+ await raise_response_error (resp , "Not a JSON answer" )
309
313
310
- return ExecutionResult (
311
- errors = result .get ("errors" ),
312
- data = result .get ("data" ),
313
- extensions = result .get ("extensions" ),
314
- )
314
+ if result is None :
315
+ await raise_response_error (resp , "Not a JSON answer" )
316
+
317
+ if "errors" not in result and "data" not in result :
318
+ await raise_response_error (
319
+ resp , 'No "data" or "errors" keys in answer'
320
+ )
321
+
322
+ return ExecutionResult (
323
+ errors = result .get ("errors" ),
324
+ data = result .get ("data" ),
325
+ extensions = result .get ("extensions" ),
326
+ )
327
+ finally :
328
+ if upload_files :
329
+ close_files (list (self .files .values ()))
315
330
316
331
def subscribe (
317
332
self ,
0 commit comments