Skip to content

Commit 92c8d70

Browse files
authored
File upload support (#145)
Implementation of File Upload feature. Using the graphql-multipart-request-spec
1 parent dad37ac commit 92c8d70

File tree

8 files changed

+567
-10
lines changed

8 files changed

+567
-10
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ The main features of GQL are:
3838
* Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query
3939
* Supports GraphQL queries, mutations and subscriptions
4040
* Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage)
41+
* Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html)
4142

4243
## Installation
4344

docs/transports/aiohttp.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _aiohttp_transport:
2+
13
AIOHTTPTransport
24
================
35

docs/usage/file_upload.rst

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
File uploads
2+
============
3+
4+
GQL supports file uploads with the :ref:`aiohttp transport <aiohttp_transport>`
5+
using the `GraphQL multipart request spec`_.
6+
7+
.. _GraphQL multipart request spec: https://github.com/jaydenseric/graphql-multipart-request-spec
8+
9+
Single File
10+
-----------
11+
12+
In order to upload a single file, you need to:
13+
14+
* set the file as a variable value in the mutation
15+
* provide the opened file to the `variable_values` argument of `execute`
16+
* set the `upload_files` argument to True
17+
18+
.. code-block:: python
19+
20+
transport = AIOHTTPTransport(url='YOUR_URL')
21+
22+
client = Client(transport=sample_transport)
23+
24+
query = gql('''
25+
mutation($file: Upload!) {
26+
singleUpload(file: $file) {
27+
id
28+
}
29+
}
30+
''')
31+
32+
with open("YOUR_FILE_PATH", "rb") as f:
33+
34+
params = {"file": f}
35+
36+
result = client.execute(
37+
query, variable_values=params, upload_files=True
38+
)
39+
40+
File list
41+
---------
42+
43+
It is also possible to upload multiple files using a list.
44+
45+
.. code-block:: python
46+
47+
transport = AIOHTTPTransport(url='YOUR_URL')
48+
49+
client = Client(transport=sample_transport)
50+
51+
query = gql('''
52+
mutation($files: [Upload!]!) {
53+
multipleUpload(files: $files) {
54+
id
55+
}
56+
}
57+
''')
58+
59+
f1 = open("YOUR_FILE_PATH_1", "rb")
60+
f2 = open("YOUR_FILE_PATH_1", "rb")
61+
62+
params = {"files": [f1, f2]}
63+
64+
result = client.execute(
65+
query, variable_values=params, upload_files=True
66+
)
67+
68+
f1.close()
69+
f2.close()

docs/usage/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ Usage
99
subscriptions
1010
variables
1111
headers
12+
file_upload

gql/transport/aiohttp.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import logging
13
from ssl import SSLContext
24
from typing import Any, AsyncGenerator, Dict, Optional, Union
35

@@ -8,6 +10,7 @@
810
from aiohttp.typedefs import LooseCookies, LooseHeaders
911
from graphql import DocumentNode, ExecutionResult, print_ast
1012

13+
from ..utils import extract_files
1114
from .async_transport import AsyncTransport
1215
from .exceptions import (
1316
TransportAlreadyConnected,
@@ -16,6 +19,8 @@
1619
TransportServerError,
1720
)
1821

22+
log = logging.getLogger(__name__)
23+
1924

2025
class AIOHTTPTransport(AsyncTransport):
2126
""":ref:`Async Transport <async_transports>` to execute GraphQL queries
@@ -32,7 +37,7 @@ def __init__(
3237
auth: Optional[BasicAuth] = None,
3338
ssl: Union[SSLContext, bool, Fingerprint] = False,
3439
timeout: Optional[int] = None,
35-
client_session_args: Dict[str, Any] = {},
40+
client_session_args: Optional[Dict[str, Any]] = None,
3641
) -> None:
3742
"""Initialize the transport with the given aiohttp parameters.
3843
@@ -54,7 +59,6 @@ def __init__(
5459
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl
5560
self.timeout: Optional[int] = timeout
5661
self.client_session_args = client_session_args
57-
5862
self.session: Optional[aiohttp.ClientSession] = None
5963

6064
async def connect(self) -> None:
@@ -81,7 +85,8 @@ async def connect(self) -> None:
8185
)
8286

8387
# Adding custom parameters passed from init
84-
client_session_args.update(self.client_session_args)
88+
if self.client_session_args:
89+
client_session_args.update(self.client_session_args) # type: ignore
8590

8691
self.session = aiohttp.ClientSession(**client_session_args)
8792

@@ -104,7 +109,8 @@ async def execute(
104109
document: DocumentNode,
105110
variable_values: Optional[Dict[str, str]] = None,
106111
operation_name: Optional[str] = None,
107-
extra_args: Dict[str, Any] = {},
112+
extra_args: Dict[str, Any] = None,
113+
upload_files: bool = False,
108114
) -> ExecutionResult:
109115
"""Execute the provided document AST against the configured remote server
110116
using the current session.
@@ -118,25 +124,70 @@ async def execute(
118124
:param variables_values: An optional Dict of variable values
119125
:param operation_name: An optional Operation name for the request
120126
:param extra_args: additional arguments to send to the aiohttp post method
127+
:param upload_files: Set to True if you want to put files in the variable values
121128
:returns: an ExecutionResult object.
122129
"""
123130

124131
query_str = print_ast(document)
132+
125133
payload: Dict[str, Any] = {
126134
"query": query_str,
127135
}
128136

129-
if variable_values:
130-
payload["variables"] = variable_values
131137
if operation_name:
132138
payload["operationName"] = operation_name
133139

134-
post_args = {
135-
"json": payload,
136-
}
140+
if upload_files:
141+
142+
# If the upload_files flag is set, then we need variable_values
143+
assert variable_values is not None
144+
145+
# If we upload files, we will extract the files present in the
146+
# variable_values dict and replace them by null values
147+
nulled_variable_values, files = extract_files(variable_values)
148+
149+
# Save the nulled variable values in the payload
150+
payload["variables"] = nulled_variable_values
151+
152+
# Prepare aiohttp to send multipart-encoded data
153+
data = aiohttp.FormData()
154+
155+
# Generate the file map
156+
# path is nested in a list because the spec allows multiple pointers
157+
# to the same file. But we don't support that.
158+
# Will generate something like {"0": ["variables.file"]}
159+
file_map = {str(i): [path] for i, path in enumerate(files)}
160+
161+
# Enumerate the file streams
162+
# Will generate something like {'0': <_io.BufferedReader ...>}
163+
file_streams = {str(i): files[path] for i, path in enumerate(files)}
164+
165+
# Add the payload to the operations field
166+
operations_str = json.dumps(payload)
167+
log.debug("operations %s", operations_str)
168+
data.add_field(
169+
"operations", operations_str, content_type="application/json"
170+
)
171+
172+
# Add the file map field
173+
file_map_str = json.dumps(file_map)
174+
log.debug("file_map %s", file_map_str)
175+
data.add_field("map", file_map_str, content_type="application/json")
176+
177+
# Add the extracted files as remaining fields
178+
data.add_fields(*file_streams.items())
179+
180+
post_args: Dict[str, Any] = {"data": data}
181+
182+
else:
183+
if variable_values:
184+
payload["variables"] = variable_values
185+
186+
post_args = {"json": payload}
137187

138188
# Pass post_args to aiohttp post method
139-
post_args.update(extra_args)
189+
if extra_args:
190+
post_args.update(extra_args)
140191

141192
if self.session is None:
142193
raise TransportClosed("Transport is not connected")

gql/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Utilities to manipulate several python objects."""
22

3+
import io
4+
from typing import Any, Dict, Tuple
5+
36

47
# From this response in Stackoverflow
58
# http://stackoverflow.com/a/19053800/1072990
@@ -8,3 +11,43 @@ def to_camel_case(snake_str):
811
# We capitalize the first letter of each component except the first one
912
# with the 'title' method and join them together.
1013
return components[0] + "".join(x.title() if x else "_" for x in components[1:])
14+
15+
16+
def is_file_like(value: Any) -> bool:
17+
"""Check if a value represents a file like object"""
18+
return isinstance(value, io.IOBase)
19+
20+
21+
def extract_files(variables: Dict) -> Tuple[Dict, Dict]:
22+
files = {}
23+
24+
def recurse_extract(path, obj):
25+
"""
26+
recursively traverse obj, doing a deepcopy, but
27+
replacing any file-like objects with nulls and
28+
shunting the originals off to the side.
29+
"""
30+
nonlocal files
31+
if isinstance(obj, list):
32+
nulled_obj = []
33+
for key, value in enumerate(obj):
34+
value = recurse_extract(f"{path}.{key}", value)
35+
nulled_obj.append(value)
36+
return nulled_obj
37+
elif isinstance(obj, dict):
38+
nulled_obj = {}
39+
for key, value in obj.items():
40+
value = recurse_extract(f"{path}.{key}", value)
41+
nulled_obj[key] = value
42+
return nulled_obj
43+
elif is_file_like(obj):
44+
# extract obj from its parent and put it into files instead.
45+
files[path] = obj
46+
return None
47+
else:
48+
# base case: pass through unchanged
49+
return obj
50+
51+
nulled_variables = recurse_extract("variables", variables)
52+
53+
return nulled_variables, files

tests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import os
55
import pathlib
66
import ssl
7+
import tempfile
78
import types
89
from concurrent.futures import ThreadPoolExecutor
10+
from typing import Union
911

1012
import pytest
1113
import websockets
@@ -187,6 +189,35 @@ async def send_connection_ack(ws):
187189
await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}')
188190

189191

192+
class TemporaryFile:
193+
"""Class used to generate temporary files for the tests"""
194+
195+
def __init__(self, content: Union[str, bytearray]):
196+
197+
mode = "w" if isinstance(content, str) else "wb"
198+
199+
# We need to set the newline to '' so that the line returns
200+
# are not replaced by '\r\n' on windows
201+
newline = "" if isinstance(content, str) else None
202+
203+
self.file = tempfile.NamedTemporaryFile(
204+
mode=mode, newline=newline, delete=False
205+
)
206+
207+
with self.file as f:
208+
f.write(content)
209+
210+
@property
211+
def filename(self):
212+
return self.file.name
213+
214+
def __enter__(self):
215+
return self
216+
217+
def __exit__(self, type, value, traceback):
218+
os.unlink(self.filename)
219+
220+
190221
def get_server_handler(request):
191222
"""Get the server handler.
192223

0 commit comments

Comments
 (0)