Skip to content

Commit

Permalink
Merge pull request #515 from escattone/PKCE-in-session-refresh-middle…
Browse files Browse the repository at this point in the history
…ware

add PKCE to SessionRefresh middleware
  • Loading branch information
akatsoulas authored Jan 10, 2024
2 parents f75ff62 + 8bf691f commit bf0d143
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 1 deletion.
31 changes: 30 additions & 1 deletion mozilla_django_oidc/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mozilla_django_oidc.utils import (
absolutify,
add_state_and_verifier_and_nonce_to_session,
generate_code_challenge,
import_from_settings,
)

Expand Down Expand Up @@ -152,7 +153,35 @@ def process_request(self, request):
nonce = get_random_string(self.OIDC_NONCE_SIZE)
params.update({"nonce": nonce})

add_state_and_verifier_and_nonce_to_session(request, state, params)
if self.get_settings("OIDC_USE_PKCE", False):
code_verifier_length = self.get_settings("OIDC_PKCE_CODE_VERIFIER_SIZE", 64)
# Check that code_verifier_length is between the min and max length
# defined in https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
if not (43 <= code_verifier_length <= 128):
raise ValueError("code_verifier_length must be between 43 and 128")

# Generate code_verifier and code_challenge pair
code_verifier = get_random_string(code_verifier_length)
code_challenge_method = self.get_settings(
"OIDC_PKCE_CODE_CHALLENGE_METHOD", "S256"
)
code_challenge = generate_code_challenge(
code_verifier, code_challenge_method
)

# Append code_challenge to authentication request parameters
params.update(
{
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
}
)
else:
code_verifier = None

add_state_and_verifier_and_nonce_to_session(
request, state, params, code_verifier
)

request.session["oidc_login_next"] = request.get_full_path()

Expand Down
121 changes: 121 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ def test_is_ajax(self, mock_middleware_random):
json_payload = json.loads(response.content.decode("utf-8"))
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])

@override_settings(OIDC_USE_PKCE=True)
def test_is_ajax_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo", HTTP_X_REQUESTED_WITH="XMLHttpRequest")
request.session = {}
request.user = self.user

response = self.middleware.process_request(request)
self.assertEqual(response.status_code, 403)
# The URL to go to is available both as a header and as a key
# in the JSON response.
self.assertTrue(response["refresh_url"])
url, qs = response["refresh_url"].split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))
json_payload = json.loads(response.content.decode("utf-8"))
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])

def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

Expand All @@ -101,6 +131,34 @@ def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_USE_PKCE=True)
def test_no_oidc_token_expiration_forces_renewal_with_pkce(
self, mock_middleware_random
):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo")
request.user = self.user
request.session = {}

response = self.middleware.process_request(request)

self.assertEqual(response.status_code, 302)
url, qs = response.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))

def test_expired_token_forces_renewal(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

Expand All @@ -124,6 +182,32 @@ def test_expired_token_forces_renewal(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_USE_PKCE=True)
def test_expired_token_forces_renewal_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo")
request.user = self.user
request.session = {"oidc_id_token_expiration": time.time() - 10}

response = self.middleware.process_request(request)

self.assertEqual(response.status_code, 302)
url, qs = response.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))


# This adds a "home page" we can test against.
def fakeview(req):
Expand Down Expand Up @@ -306,6 +390,43 @@ def test_expired_token_redirects_to_sso(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
@override_settings(OIDC_RP_CLIENT_ID="foo")
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)
@override_settings(OIDC_USE_PKCE=True)
@patch("mozilla_django_oidc.middleware.get_random_string")
def test_expired_token_redirects_to_sso_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

client = ClientWithUser()
client.login(username=self.user.username, password="password")

# Set expiration to some time in the past
session = client.session
session["oidc_id_token_expiration"] = time.time() - 100
session[
"_auth_user_backend"
] = "mozilla_django_oidc.auth.OIDCAuthenticationBackend"
session.save()

resp = client.get("/mdo_fake_view/")
self.assertEqual(resp.status_code, 302)

url, qs = resp.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
@override_settings(OIDC_RP_CLIENT_ID="foo")
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)
Expand Down

0 comments on commit bf0d143

Please sign in to comment.