Skip to content

Add Predicate for authorizationConsentRequired for device code grant #2048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;

/**
* An {@link OAuth2AuthenticationContext} that holds an
* {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is
* used when validating the OAuth 2.0 Device Verification Request parameters, as well as
* determining if authorization consent is required.
*
* @author Dinesh Gupta
* @since 2.0.0
* @see OAuth2AuthenticationContext
* @see OAuth2DeviceVerificationAuthenticationToken
* @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate)
*/
public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext {

private final Map<Object, Object> context;

private OAuth2DeviceVerificationAuthenticationContext(Map<Object, Object> context) {
this.context = Collections.unmodifiableMap(new HashMap<>(context));
}

@SuppressWarnings("unchecked")
@Nullable
@Override
public <T extends Authentication> T getAuthentication() {
return (T) get(OAuth2DeviceVerificationAuthenticationToken.class);
}

@Override
public boolean hasKey(Object key) {
Assert.notNull(key, "key cannot be null");
return this.context.containsKey(key);
}

@SuppressWarnings("unchecked")
@Nullable
@Override
public <V> V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null;
}

/**
* Returns the {@link RegisteredClient registered client}.
* @return the {@link RegisteredClient}
*/
public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class);
}

/**
* Returns the {@link OAuth2Authorization authorization}.
* @return the {@link OAuth2Authorization}, or {@code null} if not available
*/
@Nullable
public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class);
}

/**
* Returns the {@link OAuth2AuthorizationConsent authorization consent}.
* @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
*/
@Nullable
public OAuth2AuthorizationConsent getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class);
}

/**
* Returns the requested scopes. Never {@code null}; always a {@link Set} (possibly
* empty).
* @return the requested scopes
*/
@SuppressWarnings("unchecked")
public Set<String> getRequestedScopes() {
Set<String> scopes = get(Set.class);
return scopes != null ? scopes : Collections.emptySet();
}

/**
* Constructs a new {@link Builder} with the provided
* {@link OAuth2DeviceVerificationAuthenticationToken}.
* @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken}
* @return the {@link Builder}
*/
public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) {
return new Builder(authentication);
}

/**
* A builder for {@link OAuth2DeviceVerificationAuthenticationContext}.
*/
public static final class Builder {

private final Map<Object, Object> context = new HashMap<>();

private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) {
Assert.notNull(authentication, "authentication cannot be null");
context.put(OAuth2DeviceVerificationAuthenticationToken.class, authentication);
}

/**
* Sets the {@link RegisteredClient registered client}.
* @param registeredClient the {@link RegisteredClient}
* @return the {@link Builder} for further configuration
*/
public Builder registeredClient(RegisteredClient registeredClient) {
context.put(RegisteredClient.class, registeredClient);
return this;
}

/**
* Sets the {@link OAuth2Authorization authorization}.
* @param authorization the {@link OAuth2Authorization}
* @return the {@link Builder} for further configuration
*/
public Builder authorization(@Nullable OAuth2Authorization authorization) {
if (authorization != null) {
context.put(OAuth2Authorization.class, authorization);
}
return this;
}

/**
* Sets the {@link OAuth2AuthorizationConsent authorization consent}.
* @param authorizationConsent the {@link OAuth2AuthorizationConsent}
* @return the {@link Builder} for further configuration
*/
public Builder authorizationConsent(@Nullable OAuth2AuthorizationConsent authorizationConsent) {
if (authorizationConsent != null) {
context.put(OAuth2AuthorizationConsent.class, authorizationConsent);
}
return this;
}

/**
* Sets the requested scopes. Never {@code null}; always a {@link Set} (possibly
* empty).
* @param requestedScopes the requested scopes
* @return the {@link Builder} for further configuration
*/
public Builder requestedScopes(@Nullable Set<String> requestedScopes) {
context.put(Set.class, requestedScopes != null ? requestedScopes : Collections.emptySet());
return this;
}

/**
* Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}.
* @return the {@link OAuth2DeviceVerificationAuthenticationContext}
*/
public OAuth2DeviceVerificationAuthenticationContext build() {
Assert.notNull(context.get(RegisteredClient.class), "registeredClient cannot be null");
return new OAuth2DeviceVerificationAuthenticationContext(context);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.security.Principal;
import java.util.Base64;
import java.util.Set;
import java.util.function.Predicate;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -78,6 +79,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut

private final OAuth2AuthorizationConsentService authorizationConsentService;

private Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired;

/**
* Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the
* provided parameters.
Expand Down Expand Up @@ -140,12 +143,19 @@ public Authentication authenticate(Authentication authentication) throws Authent
this.logger.trace("Retrieved registered client");
}

OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
.with(deviceVerificationAuthentication)
.registeredClient(registeredClient)
.authorization(authorization);

Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
authenticationContextBuilder.requestedScopes(requestedScopes);

OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService
.findById(registeredClient.getId(), principal.getName());
authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);

if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) {
if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
String state = DEFAULT_STATE_GENERATOR.generateKey();
authorization = OAuth2Authorization.from(authorization)
.principalName(principal.getName())
Expand Down Expand Up @@ -201,13 +211,38 @@ public boolean supports(Class<?> authentication) {
return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication);
}

private static boolean requiresAuthorizationConsent(Set<String> requestedScopes,
OAuth2AuthorizationConsent authorizationConsent) {
/**
* Sets the {@code Predicate} used to determine if authorization consent is required
* during the OAuth 2.0 Device Verification flow.
*
* <p>
* The {@link OAuth2DeviceVerificationAuthenticationContext} provides the predicate
* access to the following context attributes:
* <ul>
* <li>The {@link RegisteredClient} associated with the authorization request.</li>
* <li>The {@link OAuth2Authorization} associated with the device verification.</li>
* <li>The {@link OAuth2AuthorizationConsent} previously granted to the
* {@link RegisteredClient}, or {@code null} if not available.</li>
* </ul>
* </p>
* @param authorizationConsentRequired the {@code Predicate} used to determine if
* authorization consent is required for device verification
* @since 2.0.0
*/
public void setAuthorizationConsentRequired(
Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired) {
Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null");
this.authorizationConsentRequired = authorizationConsentRequired;
}

private static boolean isAuthorizationConsentRequired(
OAuth2DeviceVerificationAuthenticationContext authenticationContext) {

if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) {
if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
.getScopes()
.containsAll(authenticationContext.getRequestedScopes())) {
return false;
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -50,10 +51,12 @@
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -124,6 +127,13 @@ public void constructorWhenAuthorizationConsentServiceIsNullThenThrowIllegalArgu
// @formatter:on
}

@Test
public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("authorizationConsentRequired cannot be null");
}

@Test
public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue();
Expand Down Expand Up @@ -381,6 +391,81 @@ public void authenticateWhenAuthorizationConsentExistsAndRequestedScopesDoNotMat
.isEqualTo(authenticationResult.getState());
}

@Test
void authenticateWhenPredicateTrueThenReturnsConsentToken() {
@SuppressWarnings("unchecked")
Predicate<OAuth2DeviceVerificationAuthenticationContext> consentPredicate = mock(Predicate.class);
given(consentPredicate.test(any())).willReturn(true);
authenticationProvider.setAuthorizationConsentRequired(consentPredicate);

RegisteredClient client = TestRegisteredClients.registeredClient().build();
given(registeredClientRepository.findById(client.getId())).willReturn(client);

OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(client)
.token(createDeviceCode())
.token(createUserCode())
.attribute(OAuth2ParameterNames.SCOPE, client.getScopes())
.build();

TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
principal.setAuthenticated(true);

OAuth2DeviceVerificationAuthenticationToken authRequest = new OAuth2DeviceVerificationAuthenticationToken(
principal, USER_CODE, Collections.emptyMap());

given(authorizationService.findByToken(USER_CODE,
OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE))
.willReturn(authorization);
given(authorizationConsentService.findById(client.getId(), principal.getName())).willReturn(null);

Authentication result = authenticationProvider.authenticate(authRequest);

assertThat(result).isInstanceOf(OAuth2DeviceAuthorizationConsentAuthenticationToken.class);
OAuth2DeviceAuthorizationConsentAuthenticationToken consentToken = (OAuth2DeviceAuthorizationConsentAuthenticationToken) result;

assertThat(consentToken.isAuthenticated()).isTrue();
assertThat(consentToken.getClientId()).isEqualTo(client.getClientId());
assertThat(consentToken.getPrincipal()).isEqualTo(authRequest.getPrincipal());
assertThat(consentToken.getUserCode()).isEqualTo(authRequest.getUserCode());
assertThat(consentToken.getRequestedScopes()).containsExactlyInAnyOrderElementsOf(client.getScopes());
assertThat(consentToken.getState()).isNotNull();

verify(consentPredicate).test(any());
}

@Test
void authenticateWhenPredicateFalseThenSkipsConsentPage() {
RegisteredClient client = TestRegisteredClients.registeredClient()
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(false).build())
.build();

authenticationProvider.setAuthorizationConsentRequired(
ctx -> ctx.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent());

OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(client)
.token(createDeviceCode())
.token(createUserCode())
.attribute(OAuth2ParameterNames.SCOPE, client.getScopes())
.build();

TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
principal.setAuthenticated(true);

OAuth2DeviceVerificationAuthenticationToken authRequest = new OAuth2DeviceVerificationAuthenticationToken(
principal, USER_CODE, Collections.emptyMap());

given(registeredClientRepository.findById(client.getId())).willReturn(client);
given(authorizationService.findByToken(USER_CODE,
OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE))
.willReturn(authorization);
given(authorizationConsentService.findById(client.getId(), principal.getName())).willReturn(null);

Authentication result = authenticationProvider.authenticate(authRequest);

assertThat(result).isInstanceOf(OAuth2DeviceVerificationAuthenticationToken.class);
assertThat(result.isAuthenticated()).isTrue();
}

private static void mockAuthorizationServerContext() {
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build();
TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(
Expand Down