Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support configuring OAuth2AuthorizationRequestResolver as bean #15237

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,7 +58,7 @@
* {@link ClientRegistrationRepository} {@code @Bean} may be registered instead.
*
* <h2>Security Filters</h2>
*
* <p>
* The following {@code Filter}'s are populated for {@link #authorizationCodeGrant()}:
*
* <ul>
Expand All @@ -67,7 +67,7 @@
* </ul>
*
* <h2>Shared Objects Created</h2>
*
* <p>
* The following shared objects are populated:
*
* <ul>
Expand All @@ -76,7 +76,7 @@
* </ul>
*
* <h2>Shared Objects Used</h2>
*
* <p>
* The following shared objects are used:
*
* <ul>
Expand Down Expand Up @@ -283,10 +283,12 @@ private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
if (this.authorizationRequestResolver != null) {
return this.authorizationRequestResolver;
}
ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
.getClientRegistrationRepository(getBuilder());
return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository,
OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
ResolvableType resolvableType = ResolvableType.forClass(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequestResolver bean = getBeanOrNull(resolvableType);
return (bean != null) ? bean
: new DefaultOAuth2AuthorizationRequestResolver(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(getBuilder()),
OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
}

private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4532,9 +4532,12 @@ private ReactiveClientRegistrationRepository getClientRegistrationRepository() {
}

private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() {
OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter;
if (this.authorizationRequestResolver != null) {
return new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver);
ServerOAuth2AuthorizationRequestResolver result = this.authorizationRequestResolver;
if (result == null) {
result = getBeanOrNull(ServerOAuth2AuthorizationRequestResolver.class);
}
if (result != null) {
return new OAuth2AuthorizationRequestRedirectWebFilter(result);
}
return new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -285,6 +285,18 @@ public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorization
verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString());
}

@Test
public void configureWhenCustomAuthorizationRequestResolverBeanPresentThenAuthorizationRequestIncludesCustomParameters()
throws Exception {
this.spring.register(OAuth2ClientBeanConfig.class).autowire();
// @formatter:off
this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection())
.andReturn();
// @formatter:on
verify(authorizationRequestResolver).resolve(any());
}

@EnableWebSecurity
@Configuration
@EnableWebMvc
Expand Down Expand Up @@ -362,4 +374,59 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() {

}

@EnableWebSecurity
@Configuration
@EnableWebMvc
static class OAuth2ClientBeanConfig {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests()
.anyRequest().authenticated()
.and()
.requestCache()
.requestCache(requestCache)
.and()
.oauth2Client()
.authorizationCodeGrant()
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient);
return http.build();
// @formatter:on
}

@Bean
ClientRegistrationRepository clientRegistrationRepository() {
return clientRegistrationRepository;
}

@Bean
OAuth2AuthorizedClientRepository authorizedClientRepository() {
return authorizedClientRepository;
}

@Bean
OAuth2AuthorizationRequestResolver authorizationRequestResolver() {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
given(authorizationRequestResolver.resolve(any()))
.willAnswer((invocation) -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
return authorizationRequestResolver;
}

@RestController
class ResourceController {

@GetMapping("/resource1")
String resource1(
@RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) {
return "resource1";
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -457,6 +458,7 @@ public void oauth2LoginWhenCustomBeansThenUsed() {
OidcUser user = TestOidcUsers.create();
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = config.userService;
given(userService.loadUser(any())).willReturn(Mono.just(user));
ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
// @formatter:off
webTestClient.get()
.uri("/login/oauth2/code/google")
Expand All @@ -466,6 +468,7 @@ public void oauth2LoginWhenCustomBeansThenUsed() {
verify(config.jwtDecoderFactory).createDecoder(any());
verify(tokenResponseClient).getTokenResponse(any());
verify(securityContextRepository).save(any(), any());
verify(resolver).resolve(any());
}

// gh-5562
Expand Down Expand Up @@ -837,6 +840,10 @@ static class OAuth2LoginWithCustomBeansConfig {

ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class);

ServerOAuth2AuthorizationRequestResolver resolver = spy(
new DefaultServerOAuth2AuthorizationRequestResolver(new InMemoryReactiveClientRegistrationRepository(
TestClientRegistrations.clientRegistration().build())));

@Bean
SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
// @formatter:off
Expand Down Expand Up @@ -864,6 +871,11 @@ ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return this.jwtDecoderFactory;
}

@Bean
ServerOAuth2AuthorizationRequestResolver resolver() {
return this.resolver;
}

@Bean
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient() {
return this.tokenResponseClient;
Expand Down