Skip to content

Fix regressions introduced in 3.1.0 #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/database_blacklist/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def login():
refresh_token = create_refresh_token(identity=username)

# Store the tokens in our store with a status of not currently revoked.
add_token_to_database(access_token)
add_token_to_database(refresh_token)
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
add_token_to_database(refresh_token, app.config['JWT_IDENTITY_CLAIM'])

ret = {
'access_token': access_token,
Expand All @@ -72,7 +72,7 @@ def refresh():
# Do the same thing that we did in the login endpoint here
current_user = get_jwt_identity()
access_token = create_access_token(identity=current_user)
add_token_to_database(access_token)
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
return jsonify({'access_token': access_token}), 201

# Provide a way for a user to look at their tokens
Expand Down
5 changes: 3 additions & 2 deletions examples/database_blacklist/blacklist_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def _epoch_utc_to_datetime(epoch_utc):
return datetime.fromtimestamp(epoch_utc)


def add_token_to_database(encoded_token):
def add_token_to_database(encoded_token, identity_claim):
"""
Adds a new token to the database. It is not revoked when it is added.
:param identity_claim:
"""
decoded_token = decode_token(encoded_token)
jti = decoded_token['jti']
token_type = decoded_token['type']
user_identity = decoded_token['identity']
user_identity = decoded_token[identity_claim]
expires = _epoch_utc_to_datetime(decoded_token['exp'])
revoked = False

Expand Down
8 changes: 4 additions & 4 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def jwt_required(fn):
def wrapper(*args, **kwargs):
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -53,7 +53,7 @@ def wrapper(*args, **kwargs):
try:
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
except NoAuthorizationError:
pass
return fn(*args, **kwargs)
Expand All @@ -77,7 +77,7 @@ def wrapper(*args, **kwargs):
raise FreshTokenRequired('Fresh token required')

ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -92,7 +92,7 @@ def jwt_refresh_token_required(fn):
def wrapper(*args, **kwargs):
jwt_data = _decode_jwt_from_request(request_type='refresh')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand Down
1 change: 1 addition & 0 deletions tests/test_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def setUp(self):
self.app = Flask(__name__)
self.app.secret_key = 'super=secret'
self.app.config['JWT_BLACKLIST_ENABLED'] = True
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()
self.blacklist = set()
Expand Down
54 changes: 30 additions & 24 deletions tests/test_jwt_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,25 @@ def test_encode_access_token(self):
algorithm = 'HS256'
token_expire_delta = timedelta(minutes=5)
user_claims = {'foo': 'bar'}
identity_claim = 'identity'

# Check with a fresh token
with self.app.test_request_context():
identity = 'user1'
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=True, user_claims=user_claims, csrf=False,
identity_claim='identity')
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
self.assertNotIn('csrf', data)
self.assertEqual(data['identity'], identity)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['fresh'], True)
self.assertEqual(data['type'], 'access')
self.assertEqual(data['user_claims'], user_claims)
Expand All @@ -61,18 +62,18 @@ def test_encode_access_token(self):
identity = 12345 # identity can be anything json serializable
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=False, user_claims=user_claims, csrf=True,
identity_claim='identity')
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
self.assertIn('csrf', data)
self.assertEqual(data['identity'], identity)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['fresh'], False)
self.assertEqual(data['type'], 'access')
self.assertEqual(data['user_claims'], user_claims)
Expand All @@ -86,16 +87,17 @@ def test_encode_invalid_access_token(self):
# Check with non-serializable json
with self.app.test_request_context():
user_claims = datetime
identity_claim = 'identity'
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True, identity_claim='identity')
csrf=True, identity_claim=identity_claim)

user_claims = {'foo': timedelta(hours=4)}
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True, identity_claim='identity')
csrf=True, identity_claim=identity_claim)

def test_encode_refresh_token(self):
secret = 'super-totally-secret-key'
Expand Down Expand Up @@ -212,25 +214,27 @@ def test_decode_jwt(self):

def test_decode_invalid_jwt(self):
with self.app.test_request_context():
identity_claim = 'identity'
# Verify underlying pyjwt expires verification works
with self.assertRaises(jwt.ExpiredSignatureError):
token_data = {
'exp': datetime.utcnow() - timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing jti
with self.assertRaises(JWTDecodeError):

token_data = {
'exp': datetime.utcnow() + timedelta(minutes=5),
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing identity
with self.assertRaises(JWTDecodeError):
Expand All @@ -241,83 +245,85 @@ def test_decode_invalid_jwt(self):
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Non-matching identity claim
with self.assertRaises(JWTDecodeError):
token_data = {
'exp': datetime.utcnow() + timedelta(minutes=5),
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh'
}
other_identity_claim = 'sub'
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
self.assertNotEqual(identity_claim, other_identity_claim)
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='sub')
csrf=False, identity_claim=other_identity_claim)

# Missing type
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing fresh in access token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'user_claims': {}
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing user claims in access token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'fresh': True
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Bad token type
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'banana',
'fresh': True,
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing csrf in csrf enabled token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'fresh': True,
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True,
identity_claim='identity')
identity_claim=identity_claim)

def test_create_jwt_with_object(self):
# Complex object to test building a JWT from. Normally if you are using
Expand Down
9 changes: 7 additions & 2 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def setUp(self):
self.app.config['JWT_ALGORITHM'] = 'HS256'
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()

Expand Down Expand Up @@ -454,6 +455,9 @@ def claims():
claims_keys = [claim for claim in jwt]
return jsonify(claims_keys), 200

# Grab custom identity claim
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']

# Login
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
Expand All @@ -466,7 +470,7 @@ def claims():
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
Expand Down Expand Up @@ -836,12 +840,13 @@ def test_access_endpoints_with_cookie_missing_csrf_field(self):

def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
now = datetime.utcnow()
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']
token_data = {
'exp': now + timedelta(minutes=5),
'iat': now,
'nbf': now,
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh',
'csrf': 404
}
Expand Down