Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing custom headers through client #1255

Merged
merged 3 commits into from
Jul 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

### Enhancements

- None
- Allow passing of custom headers in `Client` calls - [#1255](https://github.com/PrefectHQ/prefect/pull/1255)

### Task Library

Expand Down
33 changes: 27 additions & 6 deletions src/prefect/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def local_token_path(self) -> str:
# -------------------------------------------------------------------------
# Utilities

def get(self, path: str, server: str = None, **params: BuiltIn) -> dict:
def get(
self, path: str, server: str = None, headers: dict = None, **params: BuiltIn
) -> dict:
"""
Convenience function for calling the Prefect API with token auth and GET request

Expand All @@ -101,18 +103,23 @@ def get(self, path: str, server: str = None, **params: BuiltIn) -> dict:
http://prefect-server/v1/auth/login, path would be 'auth/login'.
- server (str, optional): the server to send the GET request to;
defaults to `self.graphql_server`
- headers (dict, optional): Headers to pass with the request
- **params (dict): GET parameters

Returns:
- dict: Dictionary representation of the request made
"""
response = self._request(method="GET", path=path, params=params, server=server)
response = self._request(
method="GET", path=path, params=params, server=server, headers=headers
)
if response.text:
return response.json()
else:
return {}

def post(self, path: str, server: str = None, **params: BuiltIn) -> dict:
def post(
self, path: str, server: str = None, headers: dict = None, **params: BuiltIn
) -> dict:
"""
Convenience function for calling the Prefect API with token auth and POST request

Expand All @@ -121,12 +128,15 @@ def post(self, path: str, server: str = None, **params: BuiltIn) -> dict:
http://prefect-server/v1/auth/login, path would be 'auth/login'.
- server (str, optional): the server to send the POST request to;
defaults to `self.graphql_server`
- headers(dict): headers to pass with the request
- **params (dict): POST parameters

Returns:
- dict: Dictionary representation of the request made
"""
response = self._request(method="POST", path=path, params=params, server=server)
response = self._request(
method="POST", path=path, params=params, server=server, headers=headers
)
if response.text:
return response.json()
else:
Expand All @@ -136,6 +146,7 @@ def graphql(
self,
query: Any,
raise_on_error: bool = True,
headers: dict = None,
**variables: Union[bool, dict, str, int]
) -> GraphQLResult:
"""
Expand All @@ -146,6 +157,8 @@ def graphql(
parsed by prefect.utilities.graphql.parse_graphql().
- raise_on_error (bool): if True, a `ClientError` will be raised if the GraphQL
returns any `errors`.
- headers (dict): any additional headers that should be passed as part of the
request
- **variables (kwarg): Variables to be filled into a query with the key being
equivalent to the variables that are accepted by the query

Expand All @@ -160,6 +173,7 @@ def graphql(
query=parse_graphql(query),
variables=json.dumps(variables),
server=self.graphql_server,
headers=headers,
)

if raise_on_error and "errors" in result:
Expand All @@ -168,7 +182,12 @@ def graphql(
return as_nested_dict(result, GraphQLResult) # type: ignore

def _request(
self, method: str, path: str, params: dict = None, server: str = None
self,
method: str,
path: str,
params: dict = None,
server: str = None,
headers: dict = None,
) -> "requests.models.Response":
"""
Runs any specified request (GET, POST, DELETE) against the server
Expand All @@ -179,6 +198,7 @@ def _request(
- params (dict, optional): Parameters used for the request
- server (str, optional): The server to make requests against, base API
server is used if not specified
- headers (dict, optional): Headers to pass with the request

Returns:
- requests.models.Response: The response returned from the request
Expand All @@ -202,7 +222,8 @@ def _request(

params = params or {}

headers = {"Authorization": "Bearer {}".format(self.token)}
headers = headers or {}
headers.update({"Authorization": "Bearer {}".format(self.token)})
session = requests.Session()
retries = Retry(
total=6,
Expand Down
51 changes: 51 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,57 @@ def test_client_posts_raises_with_no_token(monkeypatch):
assert "Client.login" in str(exc.value)


def test_headers_are_passed_to_get(monkeypatch):
get = MagicMock()
session = MagicMock()
session.return_value.get = get
monkeypatch.setattr("requests.Session", session)
with set_temporary_config(
{"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
):
client = Client()
client.get("/foo/bar", headers={"x": "y", "Authorization": "z"})
assert get.called
assert get.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
}


def test_headers_are_passed_to_post(monkeypatch):
post = MagicMock()
session = MagicMock()
session.return_value.post = post
monkeypatch.setattr("requests.Session", session)
with set_temporary_config(
{"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
):
client = Client()
client.post("/foo/bar", headers={"x": "y", "Authorization": "z"})
assert post.called
assert post.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
}


def test_headers_are_passed_to_graphql(monkeypatch):
post = MagicMock()
session = MagicMock()
session.return_value.post = post
monkeypatch.setattr("requests.Session", session)
with set_temporary_config(
{"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
):
client = Client()
client.graphql("query {}", headers={"x": "y", "Authorization": "z"})
assert post.called
assert post.call_args[1]["headers"] == {
"x": "y",
"Authorization": "Bearer secret_token",
}


def test_client_posts_to_graphql_server(monkeypatch):
post = MagicMock(
return_value=MagicMock(json=MagicMock(return_value=dict(success=True)))
Expand Down