Skip to content

Commit ef9896d

Browse files
committed
Refactor RefreshOidcIdTokenHandler and Add tests
Signed-off-by: Hao <kyrieeeee2@gmail.com>
1 parent c037fff commit ef9896d

File tree

2 files changed

+321
-11
lines changed

2 files changed

+321
-11
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.security.oauth2.core.OAuth2Error;
3232
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3333
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
34+
import org.springframework.security.oauth2.core.oidc.OidcScopes;
3435
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
3536
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
3637
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
@@ -46,6 +47,8 @@
4647
*/
4748
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {
4849

50+
private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token";
51+
4952
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
5053

5154
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
@@ -56,20 +59,31 @@ public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2Toke
5659
@Override
5760
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
5861
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
62+
63+
if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) {
64+
return;
65+
}
66+
67+
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
68+
if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) {
69+
return;
70+
}
71+
if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) {
72+
return;
73+
}
74+
5975
OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
76+
77+
String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
78+
if (idToken == null || idToken.isBlank()) {
79+
OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE,
80+
"ID token is missing in the token response", null);
81+
throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString());
82+
}
83+
6084
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
6185
OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);
62-
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
63-
if (authentication instanceof OAuth2AuthenticationToken oauth2AuthenticationToken) {
64-
if (authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser) {
65-
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
66-
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
67-
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
68-
context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
69-
oauth2AuthenticationToken.getAuthorizedClientRegistrationId()));
70-
this.securityContextHolderStrategy.setContext(context);
71-
}
72-
}
86+
updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken);
7387
}
7488

7589
/**
@@ -92,6 +106,18 @@ public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwt
92106
this.jwtDecoderFactory = jwtDecoderFactory;
93107
}
94108

109+
private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser,
110+
OidcIdToken refreshedOidcToken) {
111+
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
112+
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
113+
114+
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
115+
context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
116+
oauth2Authentication.getAuthorizedClientRegistrationId()));
117+
118+
this.securityContextHolderStrategy.setContext(context);
119+
}
120+
95121
private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
96122
OAuth2AccessTokenResponse accessTokenResponse) {
97123
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client.oidc.authentication;
18+
19+
import java.time.Instant;
20+
import java.util.Collections;
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
24+
import org.junit.jupiter.api.BeforeEach;
25+
import org.junit.jupiter.api.Test;
26+
import org.mockito.ArgumentCaptor;
27+
28+
import org.springframework.security.authentication.TestingAuthenticationToken;
29+
import org.springframework.security.core.authority.AuthorityUtils;
30+
import org.springframework.security.core.context.SecurityContext;
31+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
32+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
33+
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
34+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
35+
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
36+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
37+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
38+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
39+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
40+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
41+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
42+
import org.springframework.security.oauth2.core.oidc.OidcScopes;
43+
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
44+
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
45+
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
46+
import org.springframework.security.oauth2.jwt.Jwt;
47+
import org.springframework.security.oauth2.jwt.JwtDecoder;
48+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
49+
import org.springframework.security.oauth2.jwt.JwtException;
50+
51+
import static org.assertj.core.api.Assertions.assertThat;
52+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
53+
import static org.mockito.ArgumentMatchers.any;
54+
import static org.mockito.BDDMockito.given;
55+
import static org.mockito.Mockito.mock;
56+
import static org.mockito.Mockito.never;
57+
import static org.mockito.Mockito.verify;
58+
59+
class RefreshOidcIdTokenHandlerTests {
60+
61+
private static final String EXISTING_ID_TOKEN_VALUE = "id-token-value";
62+
63+
private static final String REFRESHED_ID_TOKEN_VALUE = "new-id-token-value";
64+
65+
private static final String EXISTING_ACCESS_TOKEN_VALUE = "token-value";
66+
67+
private static final String REFRESHED_ACCESS_TOKEN_VALUE = "new-token-value";
68+
69+
private RefreshOidcIdTokenHandler handler;
70+
71+
private RefreshTokenOAuth2AuthorizedClientProvider provider;
72+
73+
private ClientRegistration clientRegistration;
74+
75+
private OAuth2AuthorizedClient authorizedClient;
76+
77+
private JwtDecoder jwtDecoder;
78+
79+
private SecurityContext securityContext;
80+
81+
private OidcIdToken existingIdToken;
82+
83+
@BeforeEach
84+
void setUp() {
85+
this.handler = new RefreshOidcIdTokenHandler();
86+
87+
this.clientRegistration = createClientRegistrationWithScopes(OidcScopes.OPENID);
88+
this.authorizedClient = createAuthorizedClient(this.clientRegistration);
89+
90+
this.provider = mock(RefreshTokenOAuth2AuthorizedClientProvider.class);
91+
92+
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = mock(JwtDecoderFactory.class);
93+
this.jwtDecoder = mock(JwtDecoder.class);
94+
SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class);
95+
this.securityContext = mock(SecurityContext.class);
96+
97+
this.handler.setJwtDecoderFactory(jwtDecoderFactory);
98+
this.handler.setSecurityContextHolderStrategy(securityContextHolderStrategy);
99+
100+
given(jwtDecoderFactory.createDecoder(any())).willReturn(this.jwtDecoder);
101+
given(securityContextHolderStrategy.createEmptyContext()).willReturn(this.securityContext);
102+
given(securityContextHolderStrategy.getContext()).willReturn(this.securityContext);
103+
104+
Map<String, Object> claims = new HashMap<>();
105+
claims.put("sub", "subject");
106+
Jwt existingIdTokenJwt = new Jwt(EXISTING_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
107+
Map.of("alg", "RS256"), claims);
108+
Jwt refreshedIdTokenJwt = new Jwt(REFRESHED_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
109+
Map.of("alg", "RS256"), claims);
110+
111+
this.existingIdToken = new OidcIdToken(existingIdTokenJwt.getTokenValue(), existingIdTokenJwt.getIssuedAt(),
112+
existingIdTokenJwt.getExpiresAt(), existingIdTokenJwt.getClaims());
113+
114+
given(this.jwtDecoder.decode(existingIdTokenJwt.getTokenValue())).willReturn(existingIdTokenJwt);
115+
given(this.jwtDecoder.decode(refreshedIdTokenJwt.getTokenValue())).willReturn(refreshedIdTokenJwt);
116+
}
117+
118+
@Test
119+
void handleEventWhenValidIdTokenThenUpdatesSecurityContext() {
120+
121+
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
122+
this.existingIdToken);
123+
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
124+
existingUser.getAuthorities(), "registration-id");
125+
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
126+
127+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
128+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
129+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
130+
.expiresIn(3600)
131+
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
132+
.build();
133+
134+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
135+
accessTokenResponse);
136+
this.handler.onApplicationEvent(event);
137+
138+
ArgumentCaptor<OAuth2AuthenticationToken> authenticationCaptor = ArgumentCaptor
139+
.forClass(OAuth2AuthenticationToken.class);
140+
verify(this.securityContext).setAuthentication(authenticationCaptor.capture());
141+
142+
OAuth2AuthenticationToken newAuthentication = authenticationCaptor.getValue();
143+
assertThat(newAuthentication.getPrincipal()).isInstanceOf(DefaultOidcUser.class);
144+
DefaultOidcUser newUser = (DefaultOidcUser) newAuthentication.getPrincipal();
145+
assertThat(newUser.getIdToken().getTokenValue()).isEqualTo(REFRESHED_ID_TOKEN_VALUE);
146+
}
147+
148+
@Test
149+
void handleEventWhenAuthorizedClientIsNotOidcThenDoesNothing() {
150+
151+
this.clientRegistration = createClientRegistrationWithScopes("read");
152+
this.authorizedClient = createAuthorizedClient(this.clientRegistration);
153+
154+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
155+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
156+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
157+
.expiresIn(3600)
158+
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
159+
.build();
160+
161+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
162+
accessTokenResponse);
163+
164+
this.handler.onApplicationEvent(event);
165+
166+
verify(this.securityContext, never()).setAuthentication(any());
167+
verify(this.jwtDecoder, never()).decode(any());
168+
}
169+
170+
@Test
171+
void handleEventWhenAuthenticationNotOAuth2AuthenticationTokenThenDoesNothing() {
172+
173+
given(this.securityContext.getAuthentication()).willReturn(mock(TestingAuthenticationToken.class));
174+
175+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
176+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
177+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
178+
.expiresIn(3600)
179+
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
180+
.build();
181+
182+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
183+
accessTokenResponse);
184+
185+
this.handler.onApplicationEvent(event);
186+
187+
verify(this.securityContext, never()).setAuthentication(any());
188+
}
189+
190+
@Test
191+
void handleEventWhenNotOidcUserThenDoesNothing() {
192+
193+
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(
194+
new DefaultOAuth2User(Collections.emptySet(),
195+
Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"),
196+
AuthorityUtils.createAuthorityList("ROLE_USER"), "registration-id");
197+
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
198+
199+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
200+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
201+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
202+
.expiresIn(3600)
203+
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
204+
.build();
205+
206+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
207+
accessTokenResponse);
208+
209+
this.handler.onApplicationEvent(event);
210+
211+
verify(this.securityContext, never()).setAuthentication(any());
212+
}
213+
214+
@Test
215+
void handleEventWhenMissingIdTokenThenThrowsException() {
216+
217+
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
218+
this.existingIdToken);
219+
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
220+
existingUser.getAuthorities(), "registration-id");
221+
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
222+
223+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
224+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
225+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
226+
.expiresIn(3600)
227+
.additionalParameters(new HashMap<>()) // missing ID token
228+
.build();
229+
230+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
231+
accessTokenResponse);
232+
233+
assertThatExceptionOfType(OAuth2AuthenticationException.class)
234+
.isThrownBy(() -> this.handler.onApplicationEvent(event))
235+
.withMessageContaining("missing_id_token");
236+
}
237+
238+
@Test
239+
void handleEventWhenInvalidIdTokenThenThrowsException() {
240+
241+
DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
242+
this.existingIdToken);
243+
OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
244+
existingUser.getAuthorities(), "registration-id");
245+
given(this.securityContext.getAuthentication()).willReturn(existingAuth);
246+
247+
given(this.jwtDecoder.decode(any())).willThrow(new JwtException("Invalid token"));
248+
249+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
250+
.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
251+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
252+
.expiresIn(3600)
253+
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "invalid-id-token"))
254+
.build();
255+
256+
OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
257+
accessTokenResponse);
258+
259+
assertThatExceptionOfType(OAuth2AuthenticationException.class)
260+
.isThrownBy(() -> this.handler.onApplicationEvent(event))
261+
.withMessageContaining("invalid_id_token");
262+
}
263+
264+
private ClientRegistration createClientRegistrationWithScopes(String... scope) {
265+
return ClientRegistration.withRegistrationId("registration-id")
266+
.clientId("client-id")
267+
.clientSecret("secret")
268+
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
269+
.redirectUri("http://localhost")
270+
.scope(scope)
271+
.authorizationUri("https://provider.com/oauth2/authorize")
272+
.tokenUri("https://provider.com/oauth2/token")
273+
.jwkSetUri("https://provider.com/jwk")
274+
.userInfoUri("https://provider.com/user")
275+
.build();
276+
}
277+
278+
private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) {
279+
return new OAuth2AuthorizedClient(clientRegistration, "principal-name",
280+
new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, EXISTING_ACCESS_TOKEN_VALUE, Instant.now(),
281+
Instant.now().plusSeconds(3600)));
282+
}
283+
284+
}

0 commit comments

Comments
 (0)