Skip to content

Commit b9cc6e1

Browse files
yhao3sjohnr
authored andcommitted
Ensure ID Token is updated after refresh token
Signed-off-by: Hao <kyrieeeee2@gmail.com>
1 parent 96b9820 commit b9cc6e1

File tree

9 files changed

+595
-4
lines changed

9 files changed

+595
-4
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
3535
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
3636
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
37+
import org.springframework.context.ApplicationContext;
38+
import org.springframework.context.ApplicationContextAware;
39+
import org.springframework.context.ApplicationEventPublisher;
3740
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
3841
import org.springframework.context.annotation.Bean;
3942
import org.springframework.context.annotation.Configuration;
@@ -160,7 +163,7 @@ private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
160163
* @since 6.2.0
161164
*/
162165
static final class OAuth2AuthorizedClientManagerRegistrar
163-
implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
166+
implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
164167

165168
static final String BEAN_NAME = "authorizedClientManagerRegistrar";
166169

@@ -179,6 +182,8 @@ static final class OAuth2AuthorizedClientManagerRegistrar
179182

180183
private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
181184

185+
private ApplicationEventPublisher eventPublisher;
186+
182187
private ListableBeanFactory beanFactory;
183188

184189
@Override
@@ -302,6 +307,10 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
302307
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
303308
}
304309

310+
if (this.eventPublisher != null) {
311+
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
312+
}
313+
305314
return authorizedClientProvider;
306315
}
307316

@@ -423,6 +432,11 @@ private <T> T getBeanOfType(ResolvableType resolvableType) {
423432
return objectProvider.getIfAvailable();
424433
}
425434

435+
@Override
436+
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
437+
this.eventPublisher = applicationContext;
438+
}
439+
426440
}
427441

428442
}

config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
5858
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
5959
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
60+
import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
6061
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
6162
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
6263
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
@@ -394,6 +395,15 @@ public void init(B http) throws Exception {
394395
oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
395396
}
396397
http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));
398+
399+
RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler();
400+
if (this.getSecurityContextHolderStrategy() != null) {
401+
refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy());
402+
}
403+
if (jwtDecoderFactory != null) {
404+
refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory);
405+
}
406+
registerDelegateApplicationListener(refreshOidcIdTokenHandler);
397407
}
398408
else {
399409
http.authenticationProvider(new OidcAuthenticationRequestChecker());

config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
3535
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
3636
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
37+
import org.springframework.context.ApplicationEventPublisher;
3738
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
3839
import org.springframework.core.ResolvableType;
3940
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
@@ -197,6 +198,12 @@ private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
197198
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
198199
}
199200

201+
ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
202+
ResolvableType.forClass(ApplicationEventPublisher.class));
203+
if (applicationEventPublisher != null) {
204+
authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
205+
}
206+
200207
return authorizedClientProvider;
201208
}
202209

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Map;
2626
import java.util.function.Consumer;
2727

28+
import org.springframework.context.ApplicationEventPublisher;
2829
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
2930
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
3031
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
@@ -359,6 +360,8 @@ public final class RefreshTokenGrantBuilder implements Builder {
359360

360361
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
361362

363+
private ApplicationEventPublisher eventPublisher;
364+
362365
private Duration clockSkew;
363366

364367
private Clock clock;
@@ -379,6 +382,17 @@ public RefreshTokenGrantBuilder accessTokenResponseClient(
379382
return this;
380383
}
381384

385+
/**
386+
* Sets the {@link ApplicationEventPublisher} used when an access token is
387+
* refreshed.
388+
* @param eventPublisher the {@link ApplicationEventPublisher}
389+
* @return the {@link RefreshTokenGrantBuilder}
390+
*/
391+
public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
392+
this.eventPublisher = eventPublisher;
393+
return this;
394+
}
395+
382396
/**
383397
* Sets the maximum acceptable clock skew, which is used when checking the access
384398
* token expiry. An access token is considered expired if
@@ -414,6 +428,9 @@ public OAuth2AuthorizedClientProvider build() {
414428
if (this.accessTokenResponseClient != null) {
415429
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
416430
}
431+
if (this.eventPublisher != null) {
432+
authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
433+
}
417434
if (this.clockSkew != null) {
418435
authorizedClientProvider.setClockSkew(this.clockSkew);
419436
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
import java.util.HashSet;
2525
import java.util.Set;
2626

27+
import org.springframework.context.ApplicationEventPublisher;
28+
import org.springframework.context.ApplicationEventPublisherAware;
2729
import org.springframework.lang.Nullable;
2830
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
2931
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
3032
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
33+
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
3134
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3235
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
3336
import org.springframework.security.oauth2.core.OAuth2Token;
@@ -43,10 +46,13 @@
4346
* @see OAuth2AuthorizedClientProvider
4447
* @see DefaultRefreshTokenTokenResponseClient
4548
*/
46-
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
49+
public final class RefreshTokenOAuth2AuthorizedClientProvider
50+
implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {
4751

4852
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
4953

54+
private ApplicationEventPublisher eventPublisher;
55+
5056
private Duration clockSkew = Duration.ofSeconds(60);
5157

5258
private Clock clock = Clock.systemUTC();
@@ -91,8 +97,17 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
9197
authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
9298
authorizedClient.getRefreshToken(), scopes);
9399
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
94-
return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
95-
context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
100+
101+
OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
102+
authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
103+
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
104+
105+
if (this.eventPublisher != null) {
106+
this.eventPublisher
107+
.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
108+
}
109+
110+
return updatedOAuth2AuthorizedClient;
96111
}
97112

98113
private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
@@ -149,4 +164,9 @@ public void setClock(Clock clock) {
149164
this.clock = clock;
150165
}
151166

167+
@Override
168+
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
169+
this.eventPublisher = applicationEventPublisher;
170+
}
171+
152172
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client.event;
18+
19+
import org.springframework.context.ApplicationEvent;
20+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
21+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
22+
23+
/**
24+
* An event that is published when an OAuth2 access token is refreshed.
25+
*/
26+
public class OAuth2TokenRefreshedEvent extends ApplicationEvent {
27+
28+
private final OAuth2AuthorizedClient authorizedClient;
29+
30+
private final OAuth2AccessTokenResponse accessTokenResponse;
31+
32+
public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
33+
OAuth2AccessTokenResponse accessTokenResponse) {
34+
super(source);
35+
this.authorizedClient = authorizedClient;
36+
this.accessTokenResponse = accessTokenResponse;
37+
}
38+
39+
public OAuth2AuthorizedClient getAuthorizedClient() {
40+
return this.authorizedClient;
41+
}
42+
43+
public OAuth2AccessTokenResponse getAccessTokenResponse() {
44+
return this.accessTokenResponse;
45+
}
46+
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright 2002-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client.oidc.authentication;
18+
19+
import java.util.Map;
20+
21+
import org.springframework.context.ApplicationListener;
22+
import org.springframework.security.core.Authentication;
23+
import org.springframework.security.core.context.SecurityContext;
24+
import org.springframework.security.core.context.SecurityContextHolder;
25+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
26+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
27+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
28+
import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
29+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
30+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
31+
import org.springframework.security.oauth2.core.OAuth2Error;
32+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
33+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
34+
import org.springframework.security.oauth2.core.oidc.OidcScopes;
35+
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
36+
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
37+
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
38+
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
39+
import org.springframework.security.oauth2.jwt.Jwt;
40+
import org.springframework.security.oauth2.jwt.JwtDecoder;
41+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
42+
import org.springframework.security.oauth2.jwt.JwtException;
43+
import org.springframework.util.Assert;
44+
45+
/**
46+
* An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
47+
*/
48+
public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {
49+
50+
private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token";
51+
52+
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
53+
54+
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
55+
.getContextHolderStrategy();
56+
57+
private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new OidcIdTokenDecoderFactory();
58+
59+
@Override
60+
public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
61+
OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
62+
63+
if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) {
64+
return;
65+
}
66+
67+
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
68+
if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) {
69+
return;
70+
}
71+
if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) {
72+
return;
73+
}
74+
75+
OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
76+
77+
String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
78+
if (idToken == null || idToken.isBlank()) {
79+
OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE,
80+
"ID token is missing in the token response", null);
81+
throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString());
82+
}
83+
84+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
85+
OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);
86+
updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken);
87+
}
88+
89+
/**
90+
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
91+
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
92+
*/
93+
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
94+
this.securityContextHolderStrategy = securityContextHolderStrategy;
95+
}
96+
97+
/**
98+
* Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature
99+
* verification. The factory returns a {@link JwtDecoder} associated to the provided
100+
* {@link ClientRegistration}.
101+
* @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken}
102+
* signature verification
103+
*/
104+
public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
105+
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
106+
this.jwtDecoderFactory = jwtDecoderFactory;
107+
}
108+
109+
private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser,
110+
OidcIdToken refreshedOidcToken) {
111+
OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
112+
defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
113+
114+
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
115+
context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
116+
oauth2Authentication.getAuthorizedClientRegistrationId()));
117+
118+
this.securityContextHolderStrategy.setContext(context);
119+
}
120+
121+
private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
122+
OAuth2AccessTokenResponse accessTokenResponse) {
123+
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
124+
Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
125+
return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
126+
}
127+
128+
private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) {
129+
try {
130+
Map<String, Object> parameters = accessTokenResponse.getAdditionalParameters();
131+
return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN));
132+
}
133+
catch (JwtException ex) {
134+
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
135+
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
136+
}
137+
}
138+
139+
}

0 commit comments

Comments
 (0)