|
1 | 1 | import importlib
|
| 2 | +import json |
2 | 3 | import os
|
| 4 | +import subprocess |
3 | 5 | import time
|
4 | 6 |
|
| 7 | +import jwcrypto.jwk |
| 8 | +import jwt |
| 9 | +import requests |
5 | 10 | from flask import request
|
6 |
| -from mock import patch |
| 11 | +from mock import Mock, patch |
7 | 12 | from sqlalchemy.orm.exc import NoResultFound
|
8 | 13 |
|
9 | 14 | from redash import models, settings
|
10 | 15 | from redash.authentication import (
|
11 | 16 | api_key_load_user_from_request,
|
12 | 17 | get_login_url,
|
13 | 18 | hmac_load_user_from_request,
|
| 19 | + jwt_auth, |
| 20 | + org_settings, |
14 | 21 | sign,
|
15 | 22 | )
|
16 | 23 | from redash.authentication.google_oauth import (
|
@@ -405,3 +412,69 @@ def test_disabled_user_should_not_receive_password_reset_link(self):
|
405 | 412 | self.assertEqual(response.status_code, 200)
|
406 | 413 | send_password_reset_email_mock.assert_not_called()
|
407 | 414 | send_user_disabled_email_mock.assert_called_with(user)
|
| 415 | + |
| 416 | + |
| 417 | +class TestJWTAuthentication(BaseTestCase): |
| 418 | + def setUp(self): |
| 419 | + super(TestJWTAuthentication, self).setUp() |
| 420 | + self.auth_audience = "My Org" |
| 421 | + self.auth_issuer = "Admin" |
| 422 | + self.token_name = "jwt-token" |
| 423 | + self.rsa_private_key = "/tmp/jwtRS256.key" |
| 424 | + self.rsa_public_key = "/tmp/jwtRS256.pem" |
| 425 | + |
| 426 | + if not os.path.exists(self.rsa_public_key): |
| 427 | + subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"]) |
| 428 | + subprocess.check_output( |
| 429 | + ["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key] |
| 430 | + ) |
| 431 | + |
| 432 | + org_settings["auth_jwt_login_enabled"] = True |
| 433 | + org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key) |
| 434 | + org_settings["auth_jwt_auth_issuer"] = self.auth_issuer |
| 435 | + org_settings["auth_jwt_auth_audience"] = self.auth_audience |
| 436 | + org_settings["auth_jwt_auth_header_name"] = self.token_name |
| 437 | + |
| 438 | + def tearDown(self): |
| 439 | + org_settings["auth_jwt_login_enabled"] = False |
| 440 | + org_settings["auth_jwt_auth_public_certs_url"] = "" |
| 441 | + org_settings["auth_jwt_auth_issuer"] = "" |
| 442 | + org_settings["auth_jwt_auth_audience"] = "" |
| 443 | + org_settings["auth_jwt_auth_header_name"] = "" |
| 444 | + |
| 445 | + def test_jwt_no_token(self): |
| 446 | + response = self.get_request("/data_sources", org=self.factory.org) |
| 447 | + self.assertEqual(response.status_code, 302) |
| 448 | + |
| 449 | + def test_jwt_from_pem_file(self): |
| 450 | + user = self.factory.create_user() |
| 451 | + |
| 452 | + issued_at_timestamp = time.time() |
| 453 | + expiration_timestamp = issued_at_timestamp + 60 |
| 454 | + |
| 455 | + data = { |
| 456 | + "aud": self.auth_audience, |
| 457 | + "email": user.email, |
| 458 | + "exp": expiration_timestamp, |
| 459 | + "iat": issued_at_timestamp, |
| 460 | + "iss": self.auth_issuer, |
| 461 | + } |
| 462 | + with open(self.rsa_private_key) as keyfile: |
| 463 | + sign_key = keyfile.read().strip() |
| 464 | + token_data = jwt.encode(data, sign_key, algorithm="RS256") |
| 465 | + |
| 466 | + response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data}) |
| 467 | + self.assertEqual(response.status_code, 200) |
| 468 | + |
| 469 | + @patch.object(requests, "get") |
| 470 | + def test_jwk_decode(self, mock_get): |
| 471 | + with open(self.rsa_public_key, "rb") as keyfile: |
| 472 | + public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read()) |
| 473 | + jwk_keys = {"keys": [json.loads(public_key.export())]} |
| 474 | + |
| 475 | + mockresponse = Mock() |
| 476 | + mockresponse.json = lambda: jwk_keys |
| 477 | + mock_get.return_value = mockresponse |
| 478 | + |
| 479 | + keys = jwt_auth.get_public_keys("http://localhost/key.jwt") |
| 480 | + self.assertEqual(keys[0].key_size, 4096) |
0 commit comments