Skip to content

Commit

Permalink
Merge pull request #76 from eclecticiq/refresh-jwt-on-unauthorized
Browse files Browse the repository at this point in the history
Refresh JWT and retry request once on UNAUTHORIZED status message
  • Loading branch information
rjprins authored Nov 9, 2020
2 parents 6d6ed1a + 92907c5 commit 277dcbf
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 36 deletions.
35 changes: 22 additions & 13 deletions cabby/abstract.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from furl import furl
import logging

from libtaxii.common import generate_message_id
import libtaxii

from . import dispatcher, utils
from .converters import to_detailed_service_instance_entity
from .exceptions import (
NoURIProvidedError, ServiceNotFoundError,
AmbiguousServicesError, ClientException
AmbiguousServicesError,
ClientException,
NoURIProvidedError,
ServiceNotFoundError,
UnsuccessfulStatusError,
)
from six.moves import map

Expand Down Expand Up @@ -137,6 +140,7 @@ def refresh_jwt_token(self, session=None):
self._prepare_url(self.jwt_url),
self.username,
self.password)
session.auth = dispatcher.JWTAuth(self.jwt_token)
return self.jwt_token

def prepare_generic_session(self):
Expand Down Expand Up @@ -180,19 +184,24 @@ def _execute_request(self, request, uri=None, service_type=None):
if self.jwt_url and self.username and self.password:
if not self.jwt_token:
self.refresh_jwt_token(session=session)
session = dispatcher.set_jwt_token(session, self.jwt_token)

message = dispatcher.send_taxii_request(
session,
self._prepare_url(uri),
request,
taxii_binding=self.taxii_binding,
timeout=self.timeout)

return message
for attempt in (1, 2):
try:
return dispatcher.send_taxii_request(
session,
self._prepare_url(uri),
request,
taxii_binding=self.taxii_binding,
timeout=self.timeout)
except UnsuccessfulStatusError as exc:
# Refresh the token once if authorization failed and retry
if attempt == 1 and exc.status == libtaxii.ST_UNAUTHORIZED:
self.refresh_jwt_token(session=session)
continue
raise

def _generate_id(self):
return generate_message_id(version=self.services_version)
return libtaxii.common.generate_message_id(version=self.services_version)

def _get_service(self, service_type):
candidates = self.get_services(service_type=service_type)
Expand Down
4 changes: 1 addition & 3 deletions cabby/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


class ClientException(Exception):
pass

Expand All @@ -18,7 +16,7 @@ class InvalidResponseError(ClientException):
class UnsuccessfulStatusError(ClientException):

def __init__(self, taxii_status, *args, **kwargs):
msg = "Server Error: {}".format(_status_to_message(self.raw))
msg = "Server Error: {}".format(_status_to_message(taxii_status))
super(UnsuccessfulStatusError, self).__init__(msg, *args, **kwargs)

self.status = taxii_status.status_type
Expand Down
13 changes: 12 additions & 1 deletion tests/fixtures10.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# flake8: noqa

HOST = 'example.locahost'
HOST = 'example.localhost'

DISCOVERY_PATH = '/some/discovery/path'
DISCOVERY_URI_HTTP = "http://%s%s" % (HOST, DISCOVERY_PATH)
Expand Down Expand Up @@ -115,3 +115,14 @@
INBOX_RESPONSE = '''
<taxii:Status_Message xmlns:taxii="http://taxii.mitre.org/messages/taxii_xml_binding-1" xmlns:taxii_11="http://taxii.mitre.org/messages/taxii_xml_binding-1.1" xmlns:tdq="http://taxii.mitre.org/query/taxii_default_query-1" message_id="48205" in_response_to="13777" status_type="SUCCESS"/>
'''

STATUS_MESSAGE_UNAUTHORIZED = """\
<?xml version="1.0"?>
<taxii:Status_Message
xmlns:taxii="http://taxii.mitre.org/messages/taxii_xml_binding-1"
xmlns:taxii_11="http://taxii.mitre.org/messages/taxii_xml_binding-1.1"
xmlns:tdq="http://taxii.mitre.org/query/taxii_default_query-1"
message_id="123123"
in_response_to="0"
status_type="UNAUTHORIZED"/>
"""
11 changes: 11 additions & 0 deletions tests/fixtures11.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,14 @@
INBOX_RESPONSE = '''
<taxii_11:Status_Message xmlns:taxii="http://taxii.mitre.org/messages/taxii_xml_binding-1" xmlns:taxii_11="http://taxii.mitre.org/messages/taxii_xml_binding-1.1" xmlns:tdq="http://taxii.mitre.org/query/taxii_default_query-1" message_id="83710" in_response_to="57915" status_type="SUCCESS"/>
'''

STATUS_MESSAGE_UNAUTHORIZED = """\
<?xml version="1.0"?>
<taxii_11:Status_Message
xmlns:taxii="http://taxii.mitre.org/messages/taxii_xml_binding-1"
xmlns:taxii_11="http://taxii.mitre.org/messages/taxii_xml_binding-1.1"
xmlns:tdq="http://taxii.mitre.org/query/taxii_default_query-1"
message_id="1209263632454336330"
in_response_to="0"
status_type="UNAUTHORIZED"/>
"""
97 changes: 78 additions & 19 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

CUSTOM_HEADER_NAME = 'X-custom-header'
CUSTOM_HEADER_VALUE = 'header value with space!'
JWT_PATH = '/management/auth/'
JWT_URL = "http://example.localhost" + JWT_PATH


def get_fix(version):
Expand All @@ -36,22 +38,24 @@ def make_client(version, **kwargs):
return client


def register_uri(uri, body, version, headers=None, mock=None, **kwargs):
content_type = VID_TAXII_XML_11 if version == 11 else VID_TAXII_XML_10
headers = headers or {}
headers.update({
'X-TAXII-Content-Type': content_type
})
if not mock:
mock = responses
mock.add(
def register_uri(uri, body, version, headers=None):
if headers is None:
headers = {}
headers.update(make_taxii_headers(version))
responses.add(
method=responses.POST,
url=uri,
body=body,
content_type='application/xml',
content_type="application/xml",
stream=True,
adding_headers=headers,
**kwargs)
)


def make_taxii_headers(version):
return {
"X-TAXII-Content-Type": VID_TAXII_XML_11 if version == 11 else VID_TAXII_XML_10
}


def get_sent_message(version, mock=None):
Expand Down Expand Up @@ -131,10 +135,6 @@ def test_invalid_response_status(version):
@pytest.mark.parametrize("version", [11, 10])
@responses.activate
def test_jwt_auth_response(version):
jwt_path = '/management/auth/'
jwt_url = 'http://{}{}'.format(get_fix(version).HOST, jwt_path)

token = 'dummy'
username = 'dummy-username'
password = 'dummy-password'

Expand All @@ -147,17 +147,18 @@ def jwt_request_callback(request):
assert body['username'] == username
assert body['password'] == password

content = json.dumps({'token': token}).encode()
content = json.dumps({"token": "dummy"}).encode()
return (200, {}, content)

# https://github.com/getsentry/responses/pull/268
responses.mock._matches.append(responses.CallbackResponse(
method=responses.POST,
url=jwt_url,
url=JWT_URL,
callback=jwt_request_callback,
content_type='application/json',
stream=True,
))

discovery_uri = get_fix(version).DISCOVERY_URI_HTTP

register_uri(
Expand All @@ -172,7 +173,7 @@ def jwt_request_callback(request):
client.set_auth(
username=username,
password=password,
jwt_auth_url=jwt_path
jwt_auth_url=JWT_PATH
)
services = client.discover_services(uri=discovery_uri)
assert len(services) == 4
Expand All @@ -182,7 +183,7 @@ def jwt_request_callback(request):
client.set_auth(
username=username,
password=password,
jwt_auth_url=jwt_url
jwt_auth_url=JWT_URL
)
services = client.discover_services(uri=discovery_uri)
assert len(services) == 4
Expand Down Expand Up @@ -247,3 +248,61 @@ def timeout_request_callback(request):

with pytest.raises(requests.exceptions.Timeout):
client.discover_services(uri=uri)


@pytest.mark.parametrize("version", [11, 10])
@responses.activate
def test_retry_once_on_unauthorized(version):
# Test if the client refreshes the JWT if it receives an UNAUTHORIZED
# status message.
# Flow is as follows when client.poll() is called:
# 1. Authenticate and get first_token
# 2. Do poll request with first_token: Get UNAUTHORIZED response.
# 3. Authenticate again and get second_token
# 4. Do poll request with second_token: Get POLL_RESPONSE.

# Set up two responses with tokens for auth request
first_token = "first"
second_token = "second"
for token in (first_token, second_token):
responses.add(
method=responses.POST,
url=JWT_URL,
json={"token": token},
content_type="application/json",
stream=True,
)

client = make_client(version)
client.set_auth(username="username", password="pass", jwt_auth_url=JWT_PATH)

# Set up two responses for poll request: First is UNAUTHORIZED, the second is
# a normal POLL_RESPONSE if the token was refreshed.
attempts = []

def poll_callback(request):
attempts.append(request)
_, _, token = request.headers["Authorization"].partition("Bearer ")
if len(attempts) == 1:
assert token == first_token
return (
200,
make_taxii_headers(version),
get_fix(version).STATUS_MESSAGE_UNAUTHORIZED,
)
else:
assert len(attempts) == 2
assert token == second_token
return (200, make_taxii_headers(version), get_fix(version).POLL_RESPONSE)

responses.mock._matches.append(
responses.CallbackResponse(
responses.POST,
url="http://example.localhost/poll",
callback=poll_callback,
stream=True,
)
)

list(client.poll(collection_name="X", uri="/poll"))
assert client.jwt_token == second_token

0 comments on commit 277dcbf

Please sign in to comment.