Skip to content

Commit

Permalink
Add setBodyExtractor
Browse files Browse the repository at this point in the history
  • Loading branch information
bishoybasily authored and jzheaux committed Sep 22, 2021
1 parent c3ba233 commit 8606904
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
Expand All @@ -35,6 +36,7 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -68,6 +70,9 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T

private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;

private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();

AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
}

Expand Down Expand Up @@ -204,7 +209,7 @@ Set<String> defaultScopes(T grantRequest) {
* @return the token response from the response body.
*/
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse())
return response.body(this.bodyExtractor)
.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
}

Expand Down Expand Up @@ -291,4 +296,17 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
};
}

/**
* Sets the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}
* @param bodyExtractor the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}
* @since 5.6
*/
public final void setBodyExtractor(
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor) {
Assert.notNull(bodyExtractor, "bodyExtractor cannot be null");
this.bodyExtractor = bodyExtractor;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
Expand All @@ -41,11 +43,14 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -409,4 +414,25 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

String accessTokenSuccessResponse = "{}";

WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();

BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));

customClient.setBodyExtractor(extractor);

this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest())
.block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,26 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -280,4 +285,26 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

enqueueJson("{}");

WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient();

BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));

customClient.setBodyExtractor(extractor);

OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build());

OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(request).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,26 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -286,4 +291,29 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

String accessTokenSuccessResponse = "{}";

WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient();

BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));

customClient.setBodyExtractor(extractor);

ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);

this.server.enqueue(jsonResponse(accessTokenSuccessResponse));

OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
Expand All @@ -39,10 +41,13 @@
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.web.reactive.function.BodyExtractor;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -289,4 +294,27 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

// gh-10260
@Test
public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {

String accessTokenSuccessResponse = "{}";

WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient();

BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(extractor.extract(any(), any())).willReturn(Mono.just(response));

customClient.setBodyExtractor(extractor);

OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);

this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block();
assertThat(accessTokenResponse.getAccessToken()).isNotNull();

}

}

0 comments on commit 8606904

Please sign in to comment.