Skip to content

Commit acf05c2

Browse files
committed
Always return current ClientRegistration in loadAuthorizedClient
This changes `InMemoryOAuth2AuthorizedClientService.loadAuthorizedClient` (and its reactive counterpart) to always return `OAuth2AuthorizedClient` instances containing the current `ClientRegistration` as obtained from the `ClientRegistrationRepository`. Before this change, the first `ClientRegistration` instance was cached, with the effect that any changes made in the `ClientRegistrationRepository` (such as a new client secret) would not have taken effect. Closes gh-15511
1 parent 30c9860 commit acf05c2

File tree

4 files changed

+143
-19
lines changed

4 files changed

+143
-19
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -80,7 +80,13 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
8080
if (registration == null) {
8181
return null;
8282
}
83-
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
83+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
84+
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
85+
if (cachedAuthorizedClient == null) {
86+
return null;
87+
}
88+
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
89+
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
8490
}
8591

8692
@Override

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -62,8 +62,19 @@ public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String cl
6262
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
6363
Assert.hasText(principalName, "principalName cannot be empty");
6464
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
65-
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
66-
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
65+
.mapNotNull((clientRegistration) -> {
66+
OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
67+
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
68+
if (cachedAuthorizedClient == null) {
69+
return null;
70+
}
71+
// @formatter:off
72+
return new OAuth2AuthorizedClient(clientRegistration,
73+
cachedAuthorizedClient.getPrincipalName(),
74+
cachedAuthorizedClient.getAccessToken(),
75+
cachedAuthorizedClient.getRefreshToken());
76+
// @formatter:on
77+
});
6778
}
6879

6980
@Override

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,12 +28,9 @@
2828
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
2929
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3030

31-
import static org.assertj.core.api.Assertions.assertThat;
32-
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
33-
import static org.assertj.core.api.Assertions.assertThatObject;
34-
import static org.mockito.ArgumentMatchers.eq;
35-
import static org.mockito.BDDMockito.given;
36-
import static org.mockito.Mockito.mock;
31+
import static org.assertj.core.api.Assertions.*;
32+
import static org.mockito.ArgumentMatchers.*;
33+
import static org.mockito.BDDMockito.*;
3734

3835
/**
3936
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
@@ -79,9 +76,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept
7976
@Test
8077
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
8178
String registrationId = this.registration3.getRegistrationId();
79+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
80+
mock(OAuth2AccessToken.class));
8281
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
8382
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
84-
mock(OAuth2AuthorizedClient.class));
83+
authorizedClient);
8584
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
8685
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
8786
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
@@ -124,7 +123,35 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci
124123
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
125124
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
126125
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
127-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
126+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
127+
}
128+
129+
@Test
130+
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
131+
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
132+
.clientSecret("updated secret")
133+
.build();
134+
ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
135+
given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
136+
updatedRegistration);
137+
138+
Authentication authentication = mock(Authentication.class);
139+
given(authentication.getName()).willReturn(this.principalName1);
140+
141+
InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);
142+
143+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
144+
mock(OAuth2AccessToken.class));
145+
service.saveAuthorizedClient(authorizedClient, authentication);
146+
147+
OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
148+
this.principalName1, mock(OAuth2AccessToken.class));
149+
OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
150+
this.principalName1);
151+
OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
152+
this.principalName1);
153+
assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
154+
assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
128155
}
129156

130157
@Test
@@ -148,7 +175,7 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() {
148175
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
149176
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
150177
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
151-
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
178+
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
152179
}
153180

154181
@Test
@@ -180,4 +207,29 @@ public void removeAuthorizedClientWhenSavedThenRemoved() {
180207
assertThat(loadedAuthorizedClient).isNull();
181208
}
182209

210+
private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
211+
assertThat(actual).isNotNull();
212+
assertThat(actual.getClientRegistration().getRegistrationId())
213+
.isEqualTo(expected.getClientRegistration().getRegistrationId());
214+
assertThat(actual.getClientRegistration().getClientName())
215+
.isEqualTo(expected.getClientRegistration().getClientName());
216+
assertThat(actual.getClientRegistration().getRedirectUri())
217+
.isEqualTo(expected.getClientRegistration().getRedirectUri());
218+
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
219+
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
220+
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
221+
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
222+
assertThat(actual.getClientRegistration().getClientId())
223+
.isEqualTo(expected.getClientRegistration().getClientId());
224+
assertThat(actual.getClientRegistration().getClientSecret())
225+
.isEqualTo(expected.getClientRegistration().getClientSecret());
226+
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
227+
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
228+
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
229+
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
230+
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
231+
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
232+
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
233+
}
234+
183235
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@
1818

1919
import java.time.Duration;
2020
import java.time.Instant;
21+
import java.util.function.Consumer;
2122

2223
import org.junit.jupiter.api.BeforeEach;
2324
import org.junit.jupiter.api.Test;
2425
import org.junit.jupiter.api.extension.ExtendWith;
2526
import org.mockito.Mock;
2627
import org.mockito.junit.jupiter.MockitoExtension;
28+
import reactor.core.publisher.Flux;
2729
import reactor.core.publisher.Mono;
2830
import reactor.test.StepVerifier;
2931

@@ -35,8 +37,8 @@
3537
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
3638
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3739

38-
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
39-
import static org.mockito.BDDMockito.given;
40+
import static org.assertj.core.api.Assertions.*;
41+
import static org.mockito.BDDMockito.*;
4042

4143
/**
4244
* @author Rob Winch
@@ -153,11 +155,37 @@ public void loadAuthorizedClientWhenClientRegistrationFoundThenFound() {
153155
.saveAuthorizedClient(authorizedClient, this.principal)
154156
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
155157
StepVerifier.create(saveAndLoad)
156-
.expectNext(authorizedClient)
158+
.assertNext(isEqualTo(authorizedClient))
157159
.verifyComplete();
158160
// @formatter:on
159161
}
160162

163+
@Test
164+
@SuppressWarnings("unchecked")
165+
public void loadAuthorizedClientWhenClientRegistrationChangedThenCurrentVersionFound() {
166+
ClientRegistration changedClientRegistration = ClientRegistration
167+
.withClientRegistration(this.clientRegistration)
168+
.clientSecret("updated secret")
169+
.build();
170+
171+
given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
172+
.willReturn(Mono.just(this.clientRegistration), Mono.just(changedClientRegistration));
173+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
174+
this.principalName, this.accessToken);
175+
OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(
176+
changedClientRegistration, this.principalName, this.accessToken);
177+
178+
Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
179+
.saveAuthorizedClient(authorizedClient, this.principal)
180+
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
181+
.concatWith(
182+
this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
183+
StepVerifier.create(saveAndLoadTwice)
184+
.assertNext(isEqualTo(authorizedClient))
185+
.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
186+
.verifyComplete();
187+
}
188+
161189
@Test
162190
public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() {
163191
OAuth2AuthorizedClient authorizedClient = null;
@@ -246,4 +274,31 @@ public void removeAuthorizedClientWhenClientRegistrationFoundRemovedThenNotFound
246274
// @formatter:on
247275
}
248276

277+
private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
278+
return (actual) -> {
279+
assertThat(actual).isNotNull();
280+
assertThat(actual.getClientRegistration().getRegistrationId())
281+
.isEqualTo(expected.getClientRegistration().getRegistrationId());
282+
assertThat(actual.getClientRegistration().getClientName())
283+
.isEqualTo(expected.getClientRegistration().getClientName());
284+
assertThat(actual.getClientRegistration().getRedirectUri())
285+
.isEqualTo(expected.getClientRegistration().getRedirectUri());
286+
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
287+
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
288+
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
289+
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
290+
assertThat(actual.getClientRegistration().getClientId())
291+
.isEqualTo(expected.getClientRegistration().getClientId());
292+
assertThat(actual.getClientRegistration().getClientSecret())
293+
.isEqualTo(expected.getClientRegistration().getClientSecret());
294+
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
295+
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
296+
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
297+
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
298+
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
299+
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
300+
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
301+
};
302+
}
303+
249304
}

0 commit comments

Comments
 (0)