Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ public Map<String, Object> getTokenResponse(RegisteredClient registeredClient, S

HttpHeaders basicAuth = new HttpHeaders();
basicAuth.setBasicAuth(registeredClient.getClientId(), "secret");
basicAuth.setContentType(MediaType.APPLICATION_FORM_URLENCODED);

MvcResult mvcResult = this.mockMvc.perform(post("/oauth2/token")
.params(parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,15 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
return null;
}

Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();

MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);

// code (REQUIRED)
String code = parameters.getFirst(OAuth2ParameterNames.CODE);
if (!StringUtils.hasText(code) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
*/
package org.springframework.security.oauth2.server.authorization.web.authentication;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;

import org.springframework.http.MediaType;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Utility methods for the OAuth 2.0 Protocol Endpoints.
Expand All @@ -37,7 +41,6 @@
*/
final class OAuth2EndpointUtils {
static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";

private OAuth2EndpointUtils() {
}

Expand All @@ -54,6 +57,49 @@ static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
return parameters;
}

static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method can be simplified. Given the existing getParameters() method, the only change required to make it function as getFormParameters() is:

...

parameterMap.forEach((key, values) -> {
	if (!request.getQueryString().contains(key) &&	// If not query parameter then it's a form parameter
			values.length > 0) {
		for (String value : values) {
			parameters.add(key, value);
		}
	}
});

...

And for getQueryParameters(), applying if (request.getQueryString().contains(key) ... would work.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. That's a simple way to implement that.

MultiValueMap<String, String> ret = new LinkedMultiValueMap<>();
try {
HttpMessageConverter converter = new FormHttpMessageConverter();
String contentType = request.getContentType();
MediaType type = StringUtils.hasText(contentType) ? MediaType.valueOf(contentType) : null;
ServletServerHttpRequest serverHttpRequest = new ServletServerHttpRequest(request);

if (converter.canRead(MultiValueMap.class, type)) {
ret = (MultiValueMap<String, String>) converter.read(null, serverHttpRequest);
}
} catch (Exception e) {
}

return ret;
}

static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
String queryParameters = request.getQueryString();
return parseQueryString(queryParameters);
}

static MultiValueMap<String, String> parseQueryString(String queryString) {
MultiValueMap<String, String> queryParameters = new LinkedMultiValueMap<>();
if (!StringUtils.hasText(queryString)) {
return queryParameters;
}

String[] parameters = queryString.split("&");
for (String parameter : parameters) {
String[] keyValuePair = parameter.split("=");
String value = keyValuePair.length < 2 ? null : keyValuePair[1];
if (value != null) {
try {
value = java.net.URLDecoder.decode(value, "UTF-8");
} catch (Exception e) {
}
}
queryParameters.add(keyValuePair[0], value);
}
return queryParameters;
}

static Map<String, Object> getParametersIfMatchesAuthorizationCodeGrantRequest(HttpServletRequest request, String... exclusions) {
if (!matchesAuthorizationCodeGrantRequest(request)) {
return Collections.emptyMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
Expand Down Expand Up @@ -355,7 +356,8 @@ private OAuth2AccessTokenResponse assertTokenRequestReturnsAccessTokenResponse(R
OAuth2Authorization authorization, String tokenEndpointUri) throws Exception {
MvcResult mvcResult = this.mvc.perform(post(tokenEndpointUri)
.params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove all changes where the content type is set in the request as it's redundant. Let's keep the changes in this PR to only what's required for the fix.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted all changes for the content type changes.

.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
Expand Down Expand Up @@ -406,7 +408,8 @@ public void requestWhenPublicClientWithPkceThenReturnAccessTokenResponse() throw
this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(status().isOk())
Expand Down Expand Up @@ -498,7 +501,8 @@ public void requestWhenCustomTokenGeneratorThenUsed() throws Exception {

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andExpect(status().isOk());

verify(this.tokenGenerator, times(2)).generate(any());
Expand Down Expand Up @@ -575,7 +579,8 @@ public void requestWhenConsentRequestThenReturnAccessTokenResponse() throws Exce

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
Expand Down Expand Up @@ -662,7 +667,8 @@ public void requestWhenCustomConsentCustomizerConfiguredThenUsed() throws Except

mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
Expand Down Expand Up @@ -757,7 +763,8 @@ public void requestWhenClientObtainsAccessTokenThenClientAuthenticationNotPersis
mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(status().isOk())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
Expand Down Expand Up @@ -294,7 +295,8 @@ public void requestWhenObtainReferenceAccessTokenAndIntrospectThenActive() throw
// @formatter:off
MvcResult mvcResult = this.mvc.perform(post(authorizationServerSettings.getTokenEndpoint())
.params(getAuthorizationCodeTokenRequestParameters(authorizedRegisteredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(authorizedRegisteredClient)))
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(authorizedRegisteredClient))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(status().isOk())
.andReturn();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
Expand Down Expand Up @@ -198,7 +199,8 @@ public void requestWhenAuthenticationRequestThenTokenResponseIncludesIdToken() t
mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret())))
registeredClient.getClientId(), registeredClient.getClientSecret()))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
Expand Down Expand Up @@ -239,7 +241,8 @@ public void requestWhenCustomTokenGeneratorThenUsed() throws Exception {
this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorization))
.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
registeredClient.getClientId(), registeredClient.getClientSecret())))
registeredClient.getClientId(), registeredClient.getClientSecret()))
.contentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE))
.andExpect(status().isOk());

verify(this.tokenGenerator, times(3)).generate(any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.mockito.ArgumentCaptor;

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
Expand Down Expand Up @@ -606,7 +607,7 @@ private static MockHttpServletRequest createAuthorizationCodeTokenRequest(Regist
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
request.setServletPath(requestUri);
request.setRemoteAddr(REMOTE_ADDRESS);

request.setContentType(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
Expand Down