Skip to content

Commit 241ea41

Browse files
committed
Allow RSA key used for JWT to be specified as a file path
- auth_jwt_auth_public_certs_url may file:// in addition to http/https - Log an error if payload does not contain an email address
1 parent 4a84738 commit 241ea41

File tree

4 files changed

+114
-15
lines changed

4 files changed

+114
-15
lines changed

redash/authentication/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ def jwt_token_load_user_from_request(request):
187187
if not payload:
188188
return
189189

190+
if "email" not in payload:
191+
logger.info("No email field in token, refusing to login")
192+
return
193+
190194
try:
191195
user = models.User.get_by_email_and_org(payload["email"], org)
192196
except models.NoResultFound:

redash/authentication/jwt_auth.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,50 @@
66

77
logger = logging.getLogger("jwt_auth")
88

9+
FILE_SCHEME_PREFIX = "file://"
10+
11+
12+
def get_public_key_from_file(url):
13+
file_path = url[len(FILE_SCHEME_PREFIX) :]
14+
with open(file_path) as key_file:
15+
key_str = key_file.read()
16+
17+
get_public_keys.key_cache[url] = [key_str]
18+
return key_str
19+
20+
21+
def get_public_key_from_net(url):
22+
r = requests.get(url)
23+
r.raise_for_status()
24+
data = r.json()
25+
if "keys" in data:
26+
public_keys = []
27+
for key_dict in data["keys"]:
28+
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
29+
public_keys.append(public_key)
30+
31+
get_public_keys.key_cache[url] = public_keys
32+
return public_keys
33+
else:
34+
get_public_keys.key_cache[url] = data
35+
return data
36+
937

1038
def get_public_keys(url):
1139
"""
1240
Returns:
1341
List of RSA public keys usable by PyJWT.
1442
"""
1543
key_cache = get_public_keys.key_cache
44+
keys = {}
1645
if url in key_cache:
17-
return key_cache[url]
46+
keys = key_cache[url]
1847
else:
19-
r = requests.get(url)
20-
r.raise_for_status()
21-
data = r.json()
22-
if "keys" in data:
23-
public_keys = []
24-
for key_dict in data["keys"]:
25-
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
26-
public_keys.append(public_key)
27-
28-
get_public_keys.key_cache[url] = public_keys
29-
return public_keys
48+
if url.startswith(FILE_SCHEME_PREFIX):
49+
keys = [get_public_key_from_file(url)]
3050
else:
31-
get_public_keys.key_cache[url] = data
32-
return data
51+
keys = get_public_key_from_net(url)
52+
return keys
3353

3454

3555
get_public_keys.key_cache = {}
@@ -58,4 +78,5 @@ def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms,
5878
break
5979
except Exception as e:
6080
logging.exception(e)
81+
6182
return payload, valid_token

requirements_dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
jwcrypto==1.5.0
12
pytest==7.4.0
23
pytest-cov==4.1.0
34
coverage==7.2.7

tests/test_authentication.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
import importlib
2+
import json
23
import os
4+
import subprocess
35
import time
46

7+
import jwcrypto.jwk
8+
import jwt
9+
import requests
510
from flask import request
6-
from mock import patch
11+
from mock import Mock, patch
712
from sqlalchemy.orm.exc import NoResultFound
813

914
from redash import models, settings
1015
from redash.authentication import (
1116
api_key_load_user_from_request,
1217
get_login_url,
1318
hmac_load_user_from_request,
19+
jwt_auth,
20+
org_settings,
1421
sign,
1522
)
1623
from redash.authentication.google_oauth import (
@@ -405,3 +412,69 @@ def test_disabled_user_should_not_receive_password_reset_link(self):
405412
self.assertEqual(response.status_code, 200)
406413
send_password_reset_email_mock.assert_not_called()
407414
send_user_disabled_email_mock.assert_called_with(user)
415+
416+
417+
class TestJWTAuthentication(BaseTestCase):
418+
def setUp(self):
419+
super(TestJWTAuthentication, self).setUp()
420+
self.auth_audience = "My Org"
421+
self.auth_issuer = "Admin"
422+
self.token_name = "jwt-token"
423+
self.rsa_private_key = "/tmp/jwtRS256.key"
424+
self.rsa_public_key = "/tmp/jwtRS256.pem"
425+
426+
if not os.path.exists(self.rsa_public_key):
427+
subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"])
428+
subprocess.check_output(
429+
["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key]
430+
)
431+
432+
org_settings["auth_jwt_login_enabled"] = True
433+
org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key)
434+
org_settings["auth_jwt_auth_issuer"] = self.auth_issuer
435+
org_settings["auth_jwt_auth_audience"] = self.auth_audience
436+
org_settings["auth_jwt_auth_header_name"] = self.token_name
437+
438+
def tearDown(self):
439+
org_settings["auth_jwt_login_enabled"] = False
440+
org_settings["auth_jwt_auth_public_certs_url"] = ""
441+
org_settings["auth_jwt_auth_issuer"] = ""
442+
org_settings["auth_jwt_auth_audience"] = ""
443+
org_settings["auth_jwt_auth_header_name"] = ""
444+
445+
def test_jwt_no_token(self):
446+
response = self.get_request("/data_sources", org=self.factory.org)
447+
self.assertEqual(response.status_code, 302)
448+
449+
def test_jwt_from_pem_file(self):
450+
user = self.factory.create_user()
451+
452+
issued_at_timestamp = time.time()
453+
expiration_timestamp = issued_at_timestamp + 60
454+
455+
data = {
456+
"aud": self.auth_audience,
457+
"email": user.email,
458+
"exp": expiration_timestamp,
459+
"iat": issued_at_timestamp,
460+
"iss": self.auth_issuer,
461+
}
462+
with open(self.rsa_private_key) as keyfile:
463+
sign_key = keyfile.read().strip()
464+
token_data = jwt.encode(data, sign_key, algorithm="RS256")
465+
466+
response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data})
467+
self.assertEqual(response.status_code, 200)
468+
469+
@patch.object(requests, "get")
470+
def test_jwk_decode(self, mock_get):
471+
with open(self.rsa_public_key, "rb") as keyfile:
472+
public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read())
473+
jwk_keys = {"keys": [json.loads(public_key.export())]}
474+
475+
mockresponse = Mock()
476+
mockresponse.json = lambda: jwk_keys
477+
mock_get.return_value = mockresponse
478+
479+
keys = jwt_auth.get_public_keys("http://localhost/key.jwt")
480+
self.assertEqual(keys[0].key_size, 4096)

0 commit comments

Comments
 (0)