From 94b0eab4c9dc3cf84ed96039e5c97dbc134458d7 Mon Sep 17 00:00:00 2001 From: Alessio Fabiani Date: Tue, 12 Nov 2024 15:00:29 +0100 Subject: [PATCH] - Improvements with OKTA OIDC provider integration (#386) --- .../services/rest/model/SessionToken.java | 32 +- .../KeycloakSessionServiceDelegate.java | 18 +- .../security/oauth2/OAuth2Configuration.java | 300 ++++++++++-------- .../oauth2/OAuth2SessionServiceDelegate.java | 141 +++++--- .../OpenIdConnectConfiguration.java | 5 +- .../RefreshTokenServiceTest.java | 120 ++++++- 6 files changed, 431 insertions(+), 185 deletions(-) diff --git a/src/modules/rest/api/src/main/java/it/geosolutions/geostore/services/rest/model/SessionToken.java b/src/modules/rest/api/src/main/java/it/geosolutions/geostore/services/rest/model/SessionToken.java index 12674768..722d7301 100644 --- a/src/modules/rest/api/src/main/java/it/geosolutions/geostore/services/rest/model/SessionToken.java +++ b/src/modules/rest/api/src/main/java/it/geosolutions/geostore/services/rest/model/SessionToken.java @@ -5,10 +5,12 @@ @XmlRootElement public class SessionToken { - String token_type; - String access_token; - String refresh_token; - Long expires; + private String token_type; + private String access_token; + private String refresh_token; + private Long expires; + private String error; + private String warning; @XmlElement(name = "token_type") public String getTokenType() { @@ -38,11 +40,29 @@ public void setExpires(Long expires) { } @XmlElement(name = "refresh_token") + public String getRefreshToken() { + return refresh_token; + } + public void setRefreshToken(String refresh_token) { this.refresh_token = refresh_token; } - public String getRefreshToken() { - return refresh_token; + @XmlElement(name = "error") + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + @XmlElement(name = "warning") + public String getWarning() { + return warning; + } + + public void setWarning(String warning) { + this.warning = warning; } } diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/keycloak/KeycloakSessionServiceDelegate.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/keycloak/KeycloakSessionServiceDelegate.java index bb3582bc..66ec6c77 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/keycloak/KeycloakSessionServiceDelegate.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/keycloak/KeycloakSessionServiceDelegate.java @@ -138,6 +138,16 @@ private SessionToken sessionToken(String accessToken, String refreshToken, Date public void doLogout(String sessionId) { HttpServletRequest request = OAuth2Utils.getRequest(); HttpServletResponse response = OAuth2Utils.getResponse(); + AdapterConfig configuration = + GeoStoreContext.bean(KeyCloakConfiguration.class).readAdapterConfig(); + + // Check if request, response, or configuration are null + if (request == null || response == null || configuration == null) { + LOGGER.warn( + "Request, response, or configuration is null, unable to proceed with logout."); + return; + } + KeyCloakHelper helper = GeoStoreContext.bean(KeyCloakHelper.class); KeycloakDeployment deployment = helper.getDeployment(request, response); Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); @@ -146,12 +156,10 @@ public void doLogout(String sessionId) { refreshToken = ((KeycloakTokenDetails) authentication.getDetails()).getRefreshToken(); } String logoutUrl = deployment.getLogoutUrl().build().toString(); - AdapterConfig adapterConfig = - GeoStoreContext.bean(KeyCloakConfiguration.class).readAdapterConfig(); - Configuration clientConfiguration = helper.getClientConfiguration(adapterConfig); + Configuration clientConfiguration = helper.getClientConfiguration(configuration); Http http = new Http(clientConfiguration, (params, headers) -> {}); - String clientId = adapterConfig.getResource(); - String secret = (String) adapterConfig.getCredentials().get("secret"); + String clientId = configuration.getResource(); + String secret = (String) configuration.getCredentials().get("secret"); try { http.post(logoutUrl) .form() diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2Configuration.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2Configuration.java index 9e4b4ddf..8d57120a 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2Configuration.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2Configuration.java @@ -42,37 +42,103 @@ import org.springframework.web.util.UriComponentsBuilder; /** - * This class represents the geostore configuration for an OAuth2/OpenId provider. An - * OAuth2Configuration bean should be provided for each OAuth2 provider. The bean id has to be + * This class represents the OAuth2/OpenID Connect configuration for GeoStore. It includes settings + * for endpoints, client credentials, and other OAuth2 provider details. Each OAuth2 provider + * requires a specific OAuth2Configuration bean, identified with the naming convention * {providerName}OAuth2Config. */ public class OAuth2Configuration extends IdPConfiguration { + private static final Logger LOGGER = LogManager.getLogger(OAuth2Configuration.class); + + // Constants public static final String CONFIG_NAME_SUFFIX = "OAuth2Config"; public static final String CONFIGURATION_NAME = "CONFIGURATION_NAME"; - private static final Logger LOGGER = - LogManager.getLogger(OAuth2GeoStoreAuthenticationFilter.class); - protected String clientId; - protected String clientSecret; - protected String accessTokenUri; - protected String authorizationUri; - protected String checkTokenEndpointUrl; - protected String logoutUri; - protected boolean globalLogoutEnabled = false; - protected String scopes; - protected String idTokenUri; - protected String discoveryUrl; - protected String revokeEndpoint; - protected boolean enableRedirectEntryPoint = false; - protected String principalKey; - protected String rolesClaim; - protected String groupsClaim; - - /** - * Get an authentication entry point instance meant to handle redirect to the authorization - * page. - * - * @return the authentication entry point. + + // OAuth2 provider client details + private String clientId; + private String clientSecret; + + // OAuth2 URIs and endpoints + private String accessTokenUri; + private String authorizationUri; + private String checkTokenEndpointUrl; + private String logoutUri; + private String revokeEndpoint; + + // Additional settings + private boolean globalLogoutEnabled = false; + private String scopes; + private String idTokenUri; + private String discoveryUrl; + private boolean enableRedirectEntryPoint = false; + private String principalKey; + private String rolesClaim; + private String groupsClaim; + + // Retry and backoff configurations + private long initialBackoffDelay = 1000; // Default: 1 second + private double backoffMultiplier = 2.0; // Default multiplier + private int maxRetries = 3; // Default max retries + + /** + * Gets the maximum number of retries allowed for refreshing tokens. + * + * @return maxRetries - the maximum retry attempts. + */ + public int getMaxRetries() { + return maxRetries; + } + + /** + * Sets the maximum number of retries allowed for refreshing tokens. + * + * @param maxRetries - the maximum retry attempts to set. + */ + public void setMaxRetries(int maxRetries) { + this.maxRetries = maxRetries; + } + + /** + * Gets the initial backoff delay (in milliseconds) for retry attempts. + * + * @return initialBackoffDelay - the initial delay in milliseconds. + */ + public long getInitialBackoffDelay() { + return initialBackoffDelay; + } + + /** + * Sets the initial backoff delay (in milliseconds) for retry attempts. + * + * @param initialBackoffDelay - the initial delay in milliseconds. + */ + public void setInitialBackoffDelay(long initialBackoffDelay) { + this.initialBackoffDelay = initialBackoffDelay; + } + + /** + * Gets the multiplier applied to backoff delay for each retry attempt. + * + * @return backoffMultiplier - the multiplier for exponential backoff. + */ + public double getBackoffMultiplier() { + return backoffMultiplier; + } + + /** + * Sets the multiplier for exponential backoff delay between retry attempts. + * + * @param backoffMultiplier - the multiplier for backoff. + */ + public void setBackoffMultiplier(double backoffMultiplier) { + this.backoffMultiplier = backoffMultiplier; + } + + /** + * Provides an entry point to redirect to the authorization page for authentication. + * + * @return the AuthenticationEntryPoint handling authorization redirection. */ public AuthenticationEntryPoint getAuthenticationEntryPoint() { return (request, response, authException) -> { @@ -82,77 +148,80 @@ public AuthenticationEntryPoint getAuthenticationEntryPoint() { } /** - * Build the authorization uri to the OAuth2 provider. + * Builds the authorization URI, adding response type, client ID, scope, and redirect URI. * - * @return the authorization uri completed with the various query strings. + * @return the complete authorization URI. */ public String buildLoginUri() { return buildLoginUri(null, new String[] {}); } /** - * Build the authorization uri to the OAuth2 provider. + * Builds the authorization URI with an optional access type. * - * @param accessType the access type request param value. Can be null. - * @return the authorization uri completed with the various query strings. + * @param accessType - the access type, e.g., "offline" or "online"; can be null. + * @return the complete authorization URI. */ public String buildLoginUri(String accessType) { return buildLoginUri(accessType, new String[] {}); } /** - * @param accessType the access type request param value. Can be null. - * @param additionalScopes additional scopes aren't set at from geostore-ovr.properties. Can be - * null. - * @return the + * Builds the authorization URI with access type and additional scopes. + * + * @param accessType - the type of access requested, can be null. + * @param additionalScopes - additional scopes required beyond configured scopes. + * @return the complete authorization URI. */ public String buildLoginUri(String accessType, String... additionalScopes) { - final StringBuilder loginUri = new StringBuilder(getAuthorizationUri()); - loginUri.append("?") - .append("response_type=code") - .append("&") - .append("client_id=") + StringBuilder loginUri = new StringBuilder(getAuthorizationUri()); + loginUri.append("?response_type=code") + .append("&client_id=") .append(getClientId()) - .append("&") - .append("scope=") + .append("&scope=") .append(getScopes().replace(",", "%20")); - for (String s : additionalScopes) { - loginUri.append("%20").append(s); + + for (String scope : additionalScopes) { + loginUri.append("%20").append(scope); + } + + loginUri.append("&redirect_uri=").append(getRedirectUri()); + + if (accessType != null) { + loginUri.append("&access_type=").append(accessType); } - loginUri.append("&").append("redirect_uri=").append(getRedirectUri()); - if (accessType != null) loginUri.append("&").append("access_type=").append(accessType); - String finalUrl = loginUri.toString(); - if (LOGGER.isDebugEnabled()) - LOGGER.info("Going to request authorization to this endpoint {}", finalUrl); - return finalUrl; + + LOGGER.debug("Authorization endpoint URI built: {}", loginUri); + return loginUri.toString(); } /** - * Builds the refresh token URI. + * Constructs a URI to refresh the access token. * - * @return the complete refresh token uri. + * @return the refresh token URI. */ public String buildRefreshTokenURI() { return buildRefreshTokenURI(null); } /** - * Builds the refresh token URI. + * Constructs a URI to refresh the access token with an optional access type. * - * @param accessType the access type request param. - * @return the complete refresh token uri. + * @param accessType - access type to be appended to the URI. + * @return the complete refresh token URI. */ public String buildRefreshTokenURI(String accessType) { - final StringBuilder refreshUri = new StringBuilder(getAccessTokenUri()); - refreshUri - .append("?") - .append("&") - .append("client_id=") - .append(getClientId()) - .append("&") - .append("scope=") - .append(getScopes().replace(",", "%20")); - if (accessType != null) refreshUri.append("&").append("access_type=").append(accessType); + StringBuilder refreshUri = + new StringBuilder(getAccessTokenUri()) + .append("?client_id=") + .append(getClientId()) + .append("&scope=") + .append(getScopes().replace(",", "%20")); + + if (accessType != null) { + refreshUri.append("&access_type=").append(accessType); + } + return refreshUri.toString(); } @@ -337,16 +406,14 @@ public String getProvider() { } /** - * Append the request params to the URL. + * Appends query parameters to a URL. * - * @param params the request params. - * @param url the url. - * @return the complete url. + * @param params - the request parameters. + * @param url - the base URL. + * @return the URL with appended parameters. */ protected String appendParameters(MultiValueMap params, String url) { - UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url); - builder.queryParams(params); - return builder.build().toUriString(); + return UriComponentsBuilder.fromHttpUrl(url).queryParams(params).build().toUriString(); } protected static void getLogoutRequestParams( @@ -358,28 +425,25 @@ protected static void getLogoutRequestParams( } /** - * Build the revoke endpoint. + * Builds the endpoint for token revocation. * - * @param token the access_token to revoke. - * @return the revoke endpoint. + * @param token - the token to be revoked. + * @param accessToken - the access token for authorization. + * @param configuration - OAuth2 configuration. + * @return the configured revoke endpoint, or null if not available. */ public Endpoint buildRevokeEndpoint( String token, String accessToken, OAuth2Configuration configuration) { - Endpoint result = null; - if (revokeEndpoint != null) { - HttpHeaders headers = getHttpHeaders(accessToken, configuration); + if (revokeEndpoint == null) return null; - MultiValueMap bodyParams = new LinkedMultiValueMap<>(); - bodyParams.add("token", token); - bodyParams.add("client_id", clientId); + HttpHeaders headers = getHttpHeaders(accessToken, configuration); + MultiValueMap bodyParams = new LinkedMultiValueMap<>(); + bodyParams.add("token", token); + bodyParams.add("client_id", clientId); - HttpEntity> requestEntity = - new HttpEntity<>(bodyParams, headers); - - result = new Endpoint(HttpMethod.POST, revokeEndpoint); - result.setRequestEntity(requestEntity); - } - return result; + HttpEntity> requestEntity = + new HttpEntity<>(bodyParams, headers); + return new Endpoint(HttpMethod.POST, revokeEndpoint, requestEntity); } private static HttpHeaders getHttpHeaders( @@ -390,27 +454,23 @@ private static HttpHeaders getHttpHeaders( } /** - * Build the logout endpoint. + * Builds the endpoint for logout. * - * @param token the current access_token. - * @return the logout endpoint. + * @param token - the token for the session to end. + * @param accessToken - access token to authorize the logout. + * @param configuration - OAuth2 configuration. + * @return the logout endpoint with parameters appended, or null if logoutUri is null. */ public Endpoint buildLogoutEndpoint( String token, String accessToken, OAuth2Configuration configuration) { - Endpoint result = null; - if (logoutUri != null) { - HttpHeaders headers = getHeaders(accessToken, configuration); + if (logoutUri == null) return null; - MultiValueMap params = new LinkedMultiValueMap<>(); - getLogoutRequestParams(token, clientId, params); - - HttpEntity> requestEntity = - new HttpEntity<>(null, headers); + HttpHeaders headers = getHeaders(accessToken, configuration); + MultiValueMap params = new LinkedMultiValueMap<>(); + getLogoutRequestParams(token, clientId, params); - result = new Endpoint(HttpMethod.GET, appendParameters(params, logoutUri)); - result.setRequestEntity(requestEntity); - } - return result; + return new Endpoint( + HttpMethod.GET, appendParameters(params, logoutUri), new HttpEntity<>(headers)); } private static HttpHeaders getHeaders(String accessToken, OAuth2Configuration configuration) { @@ -493,18 +553,17 @@ public void setGroupsClaim(String groupsClaim) { this.groupsClaim = groupsClaim; } - /** Class the representing and endpoint with a HTTP method. */ + /** Represents a configurable HTTP endpoint with method and request entity. */ public static class Endpoint { - private String url; - - private HttpMethod method; + private final String url; + private final HttpMethod method; + private final HttpEntity requestEntity; - private HttpEntity requestEntity; - - public Endpoint(HttpMethod method, String url) { + public Endpoint(HttpMethod method, String url, HttpEntity requestEntity) { this.method = method; this.url = url; + this.requestEntity = requestEntity; } /** @return the url. */ @@ -512,37 +571,14 @@ public String getUrl() { return url; } - /** - * Set the url. - * - * @param url the url. - */ - public void setUrl(String url) { - this.url = url; - } - /** @return the HttpMethod. */ public HttpMethod getMethod() { return method; } - /** - * Set the HttpMethod. - * - * @param method the HttpMethod. - */ - public void setMethod(HttpMethod method) { - this.method = method; - } - /** @return */ public HttpEntity getRequestEntity() { return requestEntity; } - - /** @param requestEntity */ - public void setRequestEntity(HttpEntity requestEntity) { - this.requestEntity = requestEntity; - } } } diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java index 1165d4c3..9d52ea07 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/OAuth2SessionServiceDelegate.java @@ -54,6 +54,7 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2ClientContext; import org.springframework.security.oauth2.client.OAuth2RestTemplate; +import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; import org.springframework.security.oauth2.client.token.AccessTokenRequest; import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken; import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken; @@ -90,10 +91,12 @@ public OAuth2SessionServiceDelegate( @Override public SessionToken refresh(String refreshToken, String accessToken) { + String errorMessage = ""; + String warningMessage = ""; HttpServletRequest request = getRequest(); - if (accessToken == null) + if (accessToken == null || accessToken.isEmpty()) accessToken = OAuth2Utils.tokenFromParamsOrBearer(ACCESS_TOKEN_PARAM, request); - if (accessToken == null) + if (accessToken == null || accessToken.isEmpty()) throw new NotFoundWebEx("Either the accessToken or the refresh token are missing"); OAuth2AccessToken currentToken = retrieveAccessToken(accessToken); @@ -111,22 +114,48 @@ public SessionToken refresh(String refreshToken, String accessToken) { if (LOGGER.isDebugEnabled()) LOGGER.info("Going to refresh the token."); try { sessionToken = doRefresh(refreshTokenToUse, accessToken, configuration); + } catch (UserRedirectRequiredException e) { + // Log the warning and set the warning message in the session token + warningMessage = "A redirect is required to get the user's approval."; + LOGGER.warn(warningMessage); } catch (NullPointerException npe) { - LOGGER.error("Current configuration wasn't correctly initialized."); + // Log the error and set the error message in the session token + errorMessage = "Current configuration wasn't correctly initialized."; + LOGGER.error("Current configuration wasn't correctly initialized.", npe); + } catch (Exception e) { + // Log the error and set the error message in the session token + errorMessage = "An error occurred during token refresh: " + e.getMessage(); + LOGGER.error(errorMessage); } } - if (sessionToken == null) + if (sessionToken == null && !isTokenExpired(currentToken)) { + if (warningMessage.isEmpty()) + warningMessage = + "Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"; sessionToken = sessionToken(accessToken, refreshTokenToUse, currentToken.getExpiration()); + } - request.setAttribute( - OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken()); - request.setAttribute( - OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType()); + if (sessionToken != null) { + if (!warningMessage.isEmpty()) sessionToken.setWarning(warningMessage); + if (!errorMessage.isEmpty()) sessionToken.setError(errorMessage); + request.setAttribute( + OAuth2AuthenticationDetails.ACCESS_TOKEN_VALUE, sessionToken.getAccessToken()); + request.setAttribute( + OAuth2AuthenticationDetails.ACCESS_TOKEN_TYPE, sessionToken.getTokenType()); + } return sessionToken; } + private boolean isTokenExpired(OAuth2AccessToken token) { + return token != null + && !token.getValue().isEmpty() + && (token.getExpiration() == null + || (token.getExpiration() != null + && token.getExpiration().before(new Date()))); + } + /** * Invokes the refresh endpoint to get a new session token with updated token details. * @@ -143,12 +172,10 @@ public SessionToken refresh(String refreshToken, String accessToken) { protected SessionToken doRefresh( String refreshToken, String accessToken, OAuth2Configuration configuration) { SessionToken sessionToken = null; - int maxRetries = 3; int attempt = 0; - boolean success = false; + String errorMessage = ""; + String warningMessage = ""; - // Setup HTTP headers and body for the request - // Use restTemplate() method to get RestTemplate instance OAuth2RestTemplate restTemplate = restTemplate(); HttpHeaders headers = getHttpHeaders(accessToken, configuration); MultiValueMap requestBody = new LinkedMultiValueMap<>(); @@ -159,7 +186,10 @@ protected SessionToken doRefresh( HttpEntity> requestEntity = new HttpEntity<>(requestBody, headers); - while (attempt < maxRetries && !success) { + long backoffDelay = configuration.getInitialBackoffDelay(); + int maxRetries = configuration.getMaxRetries() > 0 ? configuration.getMaxRetries() : 1; + + while (attempt < maxRetries) { attempt++; LOGGER.info("Attempting to refresh token, attempt {} of {}", attempt, maxRetries); @@ -173,10 +203,7 @@ protected SessionToken doRefresh( if (response.getStatusCode().is2xxSuccessful()) { OAuth2AccessToken newToken = response.getBody(); - if (newToken != null - && newToken.getValue() != null - && !newToken.getValue().isEmpty()) { - // Process and update the new token details + if (newToken != null && !isTokenExpired(newToken)) { OAuth2RefreshToken newRefreshToken = newToken.getRefreshToken(); OAuth2RefreshToken refreshTokenToUse = (newRefreshToken != null && newRefreshToken.getValue() != null) @@ -191,44 +218,72 @@ protected SessionToken doRefresh( newToken.getExpiration()); LOGGER.info("Token refreshed successfully on attempt {}", attempt); - success = true; + attempt = maxRetries; // Exit retry loop on redirect exception + break; } else { - LOGGER.warn("Received empty or null token on attempt {}", attempt); + LOGGER.warn("Received invalid or expired token on attempt {}", attempt); } } else if (response.getStatusCode().is4xxClientError()) { - // For client errors (e.g., 400, 401, 403), do not retry. LOGGER.error( "Client error occurred: {}. Stopping further attempts.", response.getStatusCode()); break; } else { - // For server errors (5xx), continue retrying LOGGER.warn("Server error occurred: {}. Retrying...", response.getStatusCode()); } + } catch (UserRedirectRequiredException e) { + // Handle redirect exception, set warning, and prevent further attempts + warningMessage = "A redirect is required to get the user's approval."; + LOGGER.warn(warningMessage); + sessionToken = sessionToken(accessToken, refreshToken, null); // Keep current token + sessionToken.setWarning(warningMessage); + attempt = maxRetries; // Exit retry loop on redirect exception + break; } catch (RestClientException ex) { LOGGER.error("Attempt {}: Error refreshing token: {}", attempt, ex.getMessage()); if (attempt == maxRetries) { - LOGGER.error("Max retries reached. Unable to refresh token."); + errorMessage = "Max retries reached. Unable to refresh token."; + LOGGER.error(errorMessage); + break; + } + } + + // Apply backoff delay before the next retry, unless it's the last attempt + if (attempt < maxRetries) { + try { + LOGGER.info("Waiting for {} ms before next retry.", backoffDelay); + Thread.sleep(backoffDelay); + } catch (InterruptedException e) { + LOGGER.warn("Backoff delay interrupted", e); + Thread.currentThread().interrupt(); // Preserve interrupt status } + backoffDelay *= + configuration.getBackoffMultiplier(); // Increase delay for next attempt } } - // Handle unsuccessful refresh - if (!success) { - handleRefreshFailure(accessToken, refreshToken, configuration); + // Only call handleRefreshFailure if sessionToken is null after all attempts and errors + if (sessionToken == null) { + try { + handleRefreshFailure(accessToken, refreshToken, configuration); + } catch (Exception e) { + errorMessage = + "Could not successfully perform the 'doLogout' procedure due to an internal error."; + LOGGER.error(errorMessage, e); + } } return sessionToken; } /** * Handles the refresh failure by clearing the session, logging out remotely, and redirecting to - * login. + * log in. * * @param accessToken the current access token * @param refreshToken the current refresh token * @param configuration the OAuth2Configuration with endpoint details */ - private void handleRefreshFailure( + public void handleRefreshFailure( String accessToken, String refreshToken, OAuth2Configuration configuration) { LOGGER.info( "Unable to refresh token after max retries. Clearing session and redirecting to login."); @@ -248,12 +303,12 @@ private static HttpHeaders getHttpHeaders( String accessToken, OAuth2Configuration configuration) { HttpHeaders headers = new HttpHeaders(); if (configuration != null - && configuration.clientId != null - && configuration.clientSecret != null) + && configuration.getClientId() != null + && configuration.getClientSecret() != null) headers.setBasicAuth( - configuration.clientId, - configuration - .clientSecret); // Set client ID and client secret for authentication + configuration.getClientId(), + configuration.getClientSecret()); // Set client ID and client secret for + // authentication else if (accessToken != null && !accessToken.isEmpty()) { headers.set("Authorization", "Bearer " + accessToken); } @@ -340,6 +395,14 @@ public void doLogout(String sessionId) { HttpServletRequest request = getRequest(); HttpServletResponse response = getResponse(); OAuth2RestTemplate restTemplate = restTemplate(); + OAuth2Configuration configuration = configuration(); + + // Check if request, response, or configuration are null + if (request == null || response == null || configuration == null) { + LOGGER.warn( + "Request, response, or configuration is null, unable to proceed with logout."); + return; + } String token = null; String accessToken = null; @@ -355,7 +418,9 @@ public void doLogout(String sessionId) { } if (token == null) { - if (restTemplate.getOAuth2ClientContext().getAccessToken() != null) { + if (restTemplate != null + && restTemplate.getOAuth2ClientContext() != null + && restTemplate.getOAuth2ClientContext().getAccessToken() != null) { token = restTemplate .getOAuth2ClientContext() @@ -375,7 +440,9 @@ public void doLogout(String sessionId) { } if (accessToken == null) { - if (restTemplate.getOAuth2ClientContext().getAccessToken() != null) { + if (restTemplate != null + && restTemplate.getOAuth2ClientContext() != null + && restTemplate.getOAuth2ClientContext().getAccessToken() != null) { accessToken = restTemplate.getOAuth2ClientContext().getAccessToken().getValue(); } if (accessToken == null) { @@ -389,8 +456,7 @@ public void doLogout(String sessionId) { } } - OAuth2Configuration configuration = configuration(); - if (configuration != null && configuration.isEnabled()) { + if (configuration.isEnabled()) { if (token != null && accessToken != null && !token.isEmpty() @@ -536,6 +602,9 @@ protected OAuth2Configuration configuration() { if (enabledConfig.isPresent()) { return enabledConfig.get(); } + } else { + LOGGER.error("OAuth2Configuration is not initialized properly."); + throw new IllegalStateException("Configuration is required but not initialized."); } return null; } diff --git a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/openid_connect/OpenIdConnectConfiguration.java b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/openid_connect/OpenIdConnectConfiguration.java index e86a4cd2..9688722b 100644 --- a/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/openid_connect/OpenIdConnectConfiguration.java +++ b/src/modules/rest/impl/src/main/java/it/geosolutions/geostore/services/rest/security/oauth2/openid_connect/OpenIdConnectConfiguration.java @@ -123,13 +123,12 @@ public Endpoint buildLogoutEndpoint( params.put( "post_logout_redirect_uri", Collections.singletonList(getPostLogoutRedirectUri())); - getLogoutRequestParams(token, clientId, params); + getLogoutRequestParams(token, getClientId(), params); HttpEntity> requestEntity = new HttpEntity<>(null, headers); - result = new Endpoint(HttpMethod.GET, appendParameters(params, uri)); - result.setRequestEntity(requestEntity); + result = new Endpoint(HttpMethod.GET, appendParameters(params, uri), requestEntity); } return result; } diff --git a/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java index 68d63ceb..162a3158 100644 --- a/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java +++ b/src/modules/rest/impl/src/test/java/it/geosolutions/geostore/rest/security/oauth2/openid_connect/RefreshTokenServiceTest.java @@ -9,7 +9,7 @@ import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2SessionServiceDelegate; import it.geosolutions.geostore.services.rest.security.oauth2.OAuth2Utils; import it.geosolutions.geostore.services.rest.security.oauth2.TokenDetails; -import java.util.Date; +import java.util.*; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.jupiter.api.*; @@ -23,6 +23,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2RestTemplate; +import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException; import org.springframework.security.oauth2.common.*; import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpServerErrorException; @@ -52,6 +53,7 @@ void setUp() { // Create an instance of the test subclass serviceDelegate = spy(new TestOAuth2SessionServiceDelegate()); + // Ensure restTemplate is set correctly serviceDelegate.setRestTemplate(restTemplate); serviceDelegate.setConfiguration(configuration); serviceDelegate.authenticationCache = authenticationCache; @@ -65,6 +67,7 @@ void setUp() { // Mock configuration behavior when(configuration.isEnabled()).thenReturn(true); + when(configuration.getMaxRetries()).thenReturn(3); when(configuration.getClientId()).thenReturn("testClientId"); when(configuration.getClientSecret()).thenReturn("testClientSecret"); when(configuration.buildRefreshTokenURI()).thenReturn("https://example.com/oauth2/token"); @@ -113,7 +116,7 @@ void testRefreshWithValidTokens() { String refreshToken = "providedRefreshToken"; String accessToken = "providedAccessToken"; - // Mock the RestTemplate exchange method to simulate a successful token refresh + // Mock a successful refresh response DefaultOAuth2AccessToken newAccessToken = new DefaultOAuth2AccessToken("newAccessToken"); OAuth2RefreshToken newRefreshToken = new DefaultOAuth2RefreshToken("newRefreshToken"); newAccessToken.setRefreshToken(newRefreshToken); @@ -123,6 +126,14 @@ void testRefreshWithValidTokens() { ResponseEntity responseEntity = new ResponseEntity<>(newAccessToken, HttpStatus.OK); + // Mock configuration and restTemplate behavior + when(configuration.isEnabled()).thenReturn(true); + when(configuration.getClientId()).thenReturn("testClientId"); + when(configuration.getClientSecret()).thenReturn("testClientSecret"); + when(configuration.buildRefreshTokenURI()).thenReturn("https://example.com/oauth2/token"); + when(configuration.getInitialBackoffDelay()).thenReturn(1000L); + when(configuration.getMaxRetries()).thenReturn(3); + when(restTemplate.exchange( anyString(), eq(HttpMethod.POST), @@ -130,6 +141,10 @@ void testRefreshWithValidTokens() { eq(OAuth2AccessToken.class))) .thenReturn(responseEntity); + // Mock request and response to avoid NullPointerExceptions in doLogout + when(serviceDelegate.getRequest()).thenReturn(mockRequest); + when(serviceDelegate.getResponse()).thenReturn(mockResponse); + // Act SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); @@ -145,6 +160,13 @@ void testRefreshWithValidTokens() { sessionToken.getExpires() > System.currentTimeMillis(), "Token expiration should be in the future"); assertEquals("bearer", sessionToken.getTokenType(), "Token type should be 'bearer'"); + + // Verify that the cache was updated with the new token + verify(authenticationCache).putCacheEntry(eq("newAccessToken"), any(Authentication.class)); + + // Verify that handleRefreshFailure (and therefore doLogout) was never called + verify(serviceDelegate, never()) + .handleRefreshFailure(anyString(), anyString(), any(OAuth2Configuration.class)); } @Test @@ -174,6 +196,13 @@ void testRefreshWithInvalidRefreshToken() { "existingRefreshToken", sessionToken.getRefreshToken(), "Refresh token should remain unchanged"); + assertNotNull(sessionToken.getWarning(), "Warning message should be set"); + assertTrue( + sessionToken + .getWarning() + .contains( + "Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"), + "Expected error message in SessionToken"); } @Test @@ -204,7 +233,13 @@ void testRefreshWithServerError() { "existingRefreshToken", sessionToken.getRefreshToken(), "Refresh token should remain unchanged after server error"); - // You can also verify that the method retried the expected number of times + assertNotNull(sessionToken.getWarning(), "Warning message should be set"); + assertTrue( + sessionToken + .getWarning() + .contains( + "Refresh Session Token was NULL for some reason... Seeding it with previous Access Token!"), + "Expected error message in SessionToken"); verify(restTemplate, times(3)) .exchange( anyString(), @@ -242,6 +277,10 @@ void testRefreshWithNullResponse() { "existingRefreshToken", sessionToken.getRefreshToken(), "Refresh token should remain unchanged"); + assertNotNull(sessionToken.getWarning(), "Warning message should be set"); + assertTrue( + sessionToken.getWarning().contains("Seeding it with previous Access Token!"), + "Expected warning message in SessionToken"); } @Test @@ -397,6 +436,81 @@ void testRefreshWithExpiredAccessToken() { "Token expiration should be in the future"); } + @Test + void testRefreshWithExpiredTokenAndUnsuccessfulRefresh() { + // Arrange + String refreshToken = "expiredRefreshToken"; + String accessToken = "expiredAccessToken"; + + // Set the current access token to be expired + mockOAuth2AccessToken.setExpiration( + new Date(System.currentTimeMillis() - 1000)); // Set expiration in the past + serviceDelegate.currentAccessToken = mockOAuth2AccessToken; + + // Mock the RestTemplate exchange method to simulate failure in all attempts to refresh the + // token + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenThrow(new HttpClientErrorException(HttpStatus.UNAUTHORIZED)); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNull( + sessionToken, + "SessionToken should be null when the token is expired and cannot be refreshed"); + verify(restTemplate, times(3)) + .exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class)); + } + + @Test + void testRefreshWithUserRedirectRequiredException() { + // Arrange + String refreshToken = "providedRefreshToken"; + String accessToken = "providedAccessToken"; + + // Set the mock RestTemplate to throw UserRedirectRequiredException + when(restTemplate.exchange( + anyString(), + eq(HttpMethod.POST), + any(HttpEntity.class), + eq(OAuth2AccessToken.class))) + .thenThrow(new UserRedirectRequiredException("redirect_uri", new HashMap<>())); + + // Act + SessionToken sessionToken = serviceDelegate.refresh(refreshToken, accessToken); + + // Assert + assertNotNull( + sessionToken, "SessionToken should not be null even when redirect is required"); + assertEquals( + "providedAccessToken", + sessionToken.getAccessToken(), + "Access token should remain unchanged"); + assertEquals( + "existingRefreshToken", + sessionToken.getRefreshToken(), + "Refresh token should remain unchanged"); + assertNotNull(sessionToken.getWarning(), "Warning message should be set"); + assertTrue( + sessionToken + .getWarning() + .contains("A redirect is required to get the user's approval"), + "Expected redirect warning message in SessionToken"); + + // Ensure handleRefreshFailure (and doLogout) was not called + verify(serviceDelegate, never()) + .handleRefreshFailure(anyString(), anyString(), any(OAuth2Configuration.class)); + } + /** Test subclass of OAuth2SessionServiceDelegate for testing purposes. */ class TestOAuth2SessionServiceDelegate extends OAuth2SessionServiceDelegate {