Skip to content

Commit

Permalink
- Improvements with OKTA OIDC provider integration
Browse files Browse the repository at this point in the history
  • Loading branch information
afabiani committed Nov 5, 2024
1 parent 0257c75 commit 00dfe39
Showing 1 changed file with 113 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.http.*;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2ClientContext;
Expand All @@ -63,6 +64,7 @@
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpMessageConverterExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.context.request.RequestContextHolder;

Expand Down Expand Up @@ -123,84 +125,121 @@ public SessionToken refresh(String refreshToken, String accessToken) {
}

/**
* Invokes the refresh endpoint and return a session token holding the updated tokens details.
* Invokes the refresh endpoint to get a new session token with updated token details.
*
* @param refreshToken the refresh token.
* @param accessToken the access token.
* @param configuration the OAuth2Configuration.
* @return the SessionToken.
* <p>This method attempts to refresh the session by exchanging the provided refresh token for a
* new access token. If the refresh token is invalid or the request fails after several retries,
* the session is cleared, and the user is redirected to the login page.
*
* @param refreshToken the refresh token to use for obtaining new access and refresh tokens
* @param accessToken the current access token
* @param configuration the OAuth2Configuration containing client credentials and endpoint URI
* @return a SessionToken containing the new token details, or null if the refresh process
* failed
*/
protected SessionToken doRefresh(
String refreshToken, String accessToken, OAuth2Configuration configuration) {
SessionToken sessionToken = null;
int maxRetries = 3;
int attempt = 0;
boolean success = false;

// Setup HTTP headers and body for the request
RestTemplate restTemplate = new RestTemplate();
HttpHeaders headers = getHttpHeaders(accessToken, configuration);

MultiValueMap<String, String> requestBody = new LinkedMultiValueMap<>();
requestBody.add("grant_type", "refresh_token");
requestBody.add("refresh_token", refreshToken);
requestBody.add("client_secret", configuration.getClientSecret());
requestBody.add("client_id", configuration.getClientId());

HttpEntity<MultiValueMap<String, String>> requestEntity =
new HttpEntity<>(requestBody, headers);

OAuth2AccessToken newToken = null;
try {
newToken =
restTemplate
.exchange(
configuration
.buildRefreshTokenURI(), // Use exchange method for POST
// request
HttpMethod.POST,
requestEntity, // Include request body
OAuth2AccessToken.class)
.getBody();
} catch (Exception ex) {
LOGGER.error("Error trying to obtain a refresh token.", ex);
}
while (attempt < maxRetries && !success) {
attempt++;
LOGGER.info("Attempting to refresh token, attempt {} of {}", attempt, maxRetries);

if (newToken != null && newToken.getValue() != null && !newToken.getValue().isEmpty()) {
// update the Authentication
OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
OAuth2RefreshToken refreshTokenToUse =
newRefreshToken != null
&& newRefreshToken.getValue() != null
&& !newRefreshToken.getValue().isEmpty()
? newRefreshToken
: new DefaultOAuth2RefreshToken(refreshToken);
updateAuthToken(accessToken, newToken, refreshTokenToUse, configuration);
sessionToken =
sessionToken(
newToken.getValue(),
refreshTokenToUse.getValue(),
newToken.getExpiration());
} else if (accessToken != null) {
// update the Authentication
sessionToken = sessionToken(accessToken, refreshToken, null);
} else {
// the refresh token was invalid. let's clear the session and send a remote logout.
// then redirect to the login entry point.
LOGGER.info(
"Unable to refresh the token. The following request was performed: {}. Redirecting to login.",
configuration.buildRefreshTokenURI("offline"));
doLogout(null);
try {
getResponse()
.sendRedirect(
"../../openid/"
+ configuration.getProvider().toLowerCase()
+ "/login");
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service. ", e);
throw new RuntimeException(e);
ResponseEntity<OAuth2AccessToken> response =
restTemplate.exchange(
configuration.buildRefreshTokenURI(),
HttpMethod.POST,
requestEntity,
OAuth2AccessToken.class);

if (response.getStatusCode().is2xxSuccessful()) {
OAuth2AccessToken newToken = response.getBody();
if (newToken != null
&& newToken.getValue() != null
&& !newToken.getValue().isEmpty()) {
// Process and update the new token details
OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken();
OAuth2RefreshToken refreshTokenToUse =
(newRefreshToken != null && newRefreshToken.getValue() != null)
? newRefreshToken
: new DefaultOAuth2RefreshToken(refreshToken);

updateAuthToken(accessToken, newToken, refreshTokenToUse, configuration);
sessionToken =
sessionToken(
newToken.getValue(),
refreshTokenToUse.getValue(),
newToken.getExpiration());

LOGGER.info("Token refreshed successfully on attempt {}", attempt);
success = true;
} else {
LOGGER.warn("Received empty or null token on attempt {}", attempt);
}
} else if (response.getStatusCode().is4xxClientError()) {
// For client errors (e.g., 400, 401, 403), do not retry.
LOGGER.error(
"Client error occurred: {}. Stopping further attempts.",
response.getStatusCode());
break;
} else {
// For server errors (5xx), continue retrying
LOGGER.warn("Server error occurred: {}. Retrying...", response.getStatusCode());
}
} catch (RestClientException ex) {
LOGGER.error("Attempt {}: Error refreshing token: {}", attempt, ex.getMessage());
if (attempt == maxRetries) {
LOGGER.error("Max retries reached. Unable to refresh token.");
}
}
}

// Handle unsuccessful refresh
if (!success) {
handleRefreshFailure(accessToken, refreshToken, configuration);
}
return sessionToken;
}

/**
* Handles the refresh failure by clearing the session, logging out remotely, and redirecting to
* login.
*
* @param accessToken the current access token
* @param refreshToken the current refresh token
* @param configuration the OAuth2Configuration with endpoint details
*/
private void handleRefreshFailure(
String accessToken, String refreshToken, OAuth2Configuration configuration) {
LOGGER.info(
"Unable to refresh token after max retries. Clearing session and redirecting to login.");
doLogout(null);

try {
String redirectUrl =
"../../openid/" + configuration.getProvider().toLowerCase() + "/login";
getResponse().sendRedirect(redirectUrl);
} catch (IOException e) {
LOGGER.error("Error while sending redirect to login service: ", e);
throw new RuntimeException("Failed to redirect to login", e);
}
}

private static HttpHeaders getHttpHeaders(
String accessToken, OAuth2Configuration configuration) {
HttpHeaders headers = new HttpHeaders();
Expand Down Expand Up @@ -240,23 +279,26 @@ private void updateAuthToken(

if (LOGGER.isDebugEnabled())
LOGGER.info("Updating the cache and the SecurityContext with new Auth details");
TokenDetails details = getTokenDetails(authentication);
String idToken = details.getIdToken();
cache().removeEntry(oldToken);
PreAuthenticatedAuthenticationToken updated =
new PreAuthenticatedAuthenticationToken(
authentication.getPrincipal(),
authentication.getCredentials(),
authentication.getAuthorities());
DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(newToken);
if (refreshToken != null) {
accessToken.setRefreshToken(refreshToken);
if (authentication != null && !(authentication instanceof AnonymousAuthenticationToken)) {
TokenDetails details = getTokenDetails(authentication);
String idToken = details.getIdToken();
cache().removeEntry(oldToken);
PreAuthenticatedAuthenticationToken updated =
new PreAuthenticatedAuthenticationToken(
authentication.getPrincipal(),
authentication.getCredentials(),
authentication.getAuthorities());
DefaultOAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(newToken);
if (refreshToken != null) {
accessToken.setRefreshToken(refreshToken);
}
if (LOGGER.isDebugEnabled())
LOGGER.debug(
"Creating new details. AccessToken: {} IdToken: {}", accessToken, idToken);
updated.setDetails(new TokenDetails(accessToken, idToken, conf.getBeanName()));
cache().putCacheEntry(newToken.getValue(), updated);
SecurityContextHolder.getContext().setAuthentication(updated);
}
if (LOGGER.isDebugEnabled())
LOGGER.debug("Creating new details. AccessToken: {} IdToken: {}", accessToken, idToken);
updated.setDetails(new TokenDetails(accessToken, idToken, conf.getBeanName()));
cache().putCacheEntry(newToken.getValue(), updated);
SecurityContextHolder.getContext().setAuthentication(updated);
}

private OAuth2AccessToken retrieveAccessToken(String accessToken) {
Expand Down

0 comments on commit 00dfe39

Please sign in to comment.