Skip to content

Commit

Permalink
Allow Jwt assertion to be resolved
Browse files Browse the repository at this point in the history
Closes gh-9812
  • Loading branch information
jgrandja committed Jan 10, 2022
1 parent 6c5fd38 commit 525f404
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1098,3 +1098,9 @@ class OAuth2ResourceServerController {
}
----
====

[NOTE]
`JwtBearerReactiveOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.

[TIP]
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerReactiveOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Mono<Jwt>>`.
Original file line number Diff line number Diff line change
Expand Up @@ -1352,3 +1352,9 @@ class OAuth2ResourceServerController {
}
----
====

[NOTE]
`JwtBearerOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.

[TIP]
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Jwt>`.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;

import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
Expand All @@ -45,6 +46,8 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth

private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();

private Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = this::resolveJwtAssertion;

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
Expand Down Expand Up @@ -75,10 +78,10 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
// need for re-authorization
return null;
}
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
Jwt jwt = this.jwtAssertionResolver.apply(context);
if (jwt == null) {
return null;
}
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
// As per spec, in section 4.1 Using Assertions as Authorization Grants
// https://tools.ietf.org/html/rfc7521#section-4.1
//
Expand All @@ -97,6 +100,13 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
tokenResponse.getAccessToken());
}

private Jwt resolveJwtAssertion(OAuth2AuthorizationContext context) {
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
return null;
}
return (Jwt) context.getPrincipal().getPrincipal();
}

private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration,
JwtBearerGrantRequest jwtBearerGrantRequest) {
try {
Expand All @@ -123,6 +133,17 @@ public void setAccessTokenResponseClient(
this.accessTokenResponseClient = accessTokenResponseClient;
}

/**
* Sets the resolver used for resolving the {@link Jwt} assertion.
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
* assertion
* @since 5.7
*/
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver) {
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
this.jwtAssertionResolver = jwtAssertionResolver;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;

import reactor.core.publisher.Mono;

Expand All @@ -45,6 +46,8 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re

private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient();

private Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = this::resolveJwtAssertion;

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
Expand Down Expand Up @@ -74,10 +77,7 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
// need for re-authorization
return Mono.empty();
}
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
return Mono.empty();
}
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();

// As per spec, in section 4.1 Using Assertions as Authorization Grants
// https://tools.ietf.org/html/rfc7521#section-4.1
//
Expand All @@ -90,13 +90,26 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
// issued with a reasonably short lifetime. Clients can refresh an
// expired access token by requesting a new one using the same
// assertion, if it is still valid, or with a new assertion.
return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt))

// @formatter:off
return this.jwtAssertionResolver.apply(context)
.map((jwt) -> new JwtBearerGrantRequest(clientRegistration, jwt))
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(),
ex))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken()));
// @formatter:on
}

private Mono<Jwt> resolveJwtAssertion(OAuth2AuthorizationContext context) {
// @formatter:off
return Mono.just(context)
.map((ctx) -> ctx.getPrincipal().getPrincipal())
.filter((principal) -> principal instanceof Jwt)
.cast(Jwt.class);
// @formatter:on
}

private boolean hasTokenExpired(OAuth2Token token) {
Expand All @@ -115,6 +128,17 @@ public void setAccessTokenResponseClient(
this.accessTokenResponseClient = accessTokenResponseClient;
}

/**
* Sets the resolver used for resolving the {@link Jwt} assertion.
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
* assertion
* @since 5.7
*/
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver) {
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
this.jwtAssertionResolver = jwtAssertionResolver;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -42,6 +43,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link JwtBearerOAuth2AuthorizedClientProvider}.
Expand Down Expand Up @@ -87,6 +89,13 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument
.withMessage("accessTokenResponseClient cannot be null");
}

@Test
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
.withMessage("jwtAssertionResolver cannot be null");
}

@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down Expand Up @@ -198,7 +207,7 @@ public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThen
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
Expand All @@ -209,7 +218,7 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableTo
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
Expand All @@ -224,4 +233,25 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize()
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

@Test
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = mock(Function.class);
given(jwtAssertionResolver.apply(any())).willReturn(this.jwtAssertion);
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
verify(jwtAssertionResolver).apply(any());
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.function.Function;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -93,6 +94,13 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument
.withMessage("accessTokenResponseClient cannot be null");
}

@Test
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
.withMessage("jwtAssertionResolver cannot be null");
}

@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down Expand Up @@ -222,7 +230,7 @@ public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThen
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
Expand Down Expand Up @@ -251,7 +259,7 @@ public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() {
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
Expand All @@ -266,4 +274,25 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize()
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

@Test
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = mock(Function.class);
given(jwtAssertionResolver.apply(any())).willReturn(Mono.just(this.jwtAssertion));
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
verify(jwtAssertionResolver).apply(any());
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

}

0 comments on commit 525f404

Please sign in to comment.