Skip to content

Commit

Permalink
Merge branch '6.1.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Jul 17, 2023
2 parents 8f5793a + cf2c8da commit 9dc7bdd
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
Expand All @@ -36,6 +39,7 @@
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;

/**
Expand Down Expand Up @@ -179,14 +183,47 @@ public C requestMatchers(RequestMatcher... requestMatchers) {
* @since 5.8
*/
public C requestMatchers(HttpMethod method, String... patterns) {
List<RequestMatcher> matchers = new ArrayList<>();
if (mvcPresent) {
matchers.addAll(createMvcMatchers(method, patterns));
if (!mvcPresent) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (!(this.context instanceof WebApplicationContext)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
WebApplicationContext context = (WebApplicationContext) this.context;
ServletContext servletContext = context.getServletContext();
if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
if (registrations == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
else {
matchers.addAll(RequestMatchers.antMatchers(method, patterns));
Assert.isTrue(registrations.size() == 1,
"This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher).");
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}

private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
if (registrations == null) {
return false;
}
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
}
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
return false;
}

/**
Expand Down Expand Up @@ -262,12 +299,7 @@ private RequestMatchers() {
* @return a {@link List} of {@link AntPathRequestMatcher} instances
*/
static List<RequestMatcher> antMatchers(HttpMethod httpMethod, String... antPatterns) {
String method = (httpMethod != null) ? httpMethod.toString() : null;
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : antPatterns) {
matchers.add(new AntPathRequestMatcher(pattern, method));
}
return matchers;
return Arrays.asList(antMatchersAsArray(httpMethod, antPatterns));
}

/**
Expand All @@ -281,6 +313,15 @@ static List<RequestMatcher> antMatchers(String... antPatterns) {
return antMatchers(null, antPatterns);
}

static RequestMatcher[] antMatchersAsArray(HttpMethod httpMethod, String... antPatterns) {
String method = (httpMethod != null) ? httpMethod.toString() : null;
RequestMatcher[] matchers = new RequestMatcher[antPatterns.length];
for (int index = 0; index < antPatterns.length; index++) {
matchers[index] = new AntPathRequestMatcher(antPatterns[index], method);
}
return matchers;
}

/**
* Create a {@link List} of {@link RegexRequestMatcher} instances.
* @param httpMethod the {@link HttpMethod} to use or {@code null} for any
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import jakarta.servlet.MultipartConfigElement;
import jakarta.servlet.Servlet;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.ServletSecurityElement;

import org.springframework.lang.NonNull;
import org.springframework.web.servlet.DispatcherServlet;

public class MockServletContext extends org.springframework.mock.web.MockServletContext {

private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();

public static MockServletContext mvc() {
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
return servletContext;
}

@NonNull
@Override
public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class<? extends Servlet> clazz) {
ServletRegistration.Dynamic dynamic = new MockServletRegistration(servletName, clazz);
this.registrations.put(servletName, dynamic);
return dynamic;
}

@NonNull
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations() {
return this.registrations;
}

private static class MockServletRegistration implements ServletRegistration.Dynamic {

private final String name;

private final Class<?> clazz;

MockServletRegistration(String name, Class<?> clazz) {
this.name = name;
this.clazz = clazz;
}

@Override
public void setLoadOnStartup(int loadOnStartup) {

}

@Override
public Set<String> setServletSecurity(ServletSecurityElement constraint) {
return null;
}

@Override
public void setMultipartConfig(MultipartConfigElement multipartConfig) {

}

@Override
public void setRunAsRole(String roleName) {

}

@Override
public void setAsyncSupported(boolean isAsyncSupported) {

}

@Override
public Set<String> addMapping(String... urlPatterns) {
return null;
}

@Override
public Collection<String> getMappings() {
return null;
}

@Override
public String getRunAsRole() {
return null;
}

@Override
public String getName() {
return this.name;
}

@Override
public String getClassName() {
return this.clazz.getName();
}

@Override
public boolean setInitParameter(String name, String value) {
return false;
}

@Override
public String getInitParameter(String name) {
return null;
}

@Override
public Set<String> setInitParameters(Map<String, String> initParameters) {
return null;
}

@Override
public Map<String, String> getInitParameters() {
return null;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@
import java.util.List;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.Servlet;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpMethod;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
import org.springframework.security.web.util.matcher.RegexRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.DispatcherServlet;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
Expand All @@ -54,12 +58,15 @@ public <O> O postProcess(O object) {

private TestRequestMatcherRegistry matcherRegistry;

private WebApplicationContext context;

@BeforeEach
public void setUp() {
this.matcherRegistry = new TestRequestMatcherRegistry();
ApplicationContext context = mock(ApplicationContext.class);
given(context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
this.matcherRegistry.setApplicationContext(context);
this.context = mock(WebApplicationContext.class);
given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
given(this.context.getServletContext()).willReturn(MockServletContext.mvc());
this.matcherRegistry.setApplicationContext(this.context);
mockMvcIntrospector(true);
}

Expand Down Expand Up @@ -147,6 +154,32 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva
"Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext");
}

@Test
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
servletContext.addServlet("servletOne", Servlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).isNotEmpty();
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
}

@Test
public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}

private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.access.hierarchicalroles.RoleHierarchy;
import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.authority.AuthorityUtils;
Expand Down Expand Up @@ -77,7 +77,7 @@ public class AuthorizeRequestsTests {

@BeforeEach
public void setup() {
this.servletContext = spy(new MockServletContext());
this.servletContext = spy(MockServletContext.mvc());
this.request = new MockHttpServletRequest("GET", "");
this.request.setMethod("GET");
this.response = new MockHttpServletResponse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.userdetails.User;
Expand Down Expand Up @@ -220,7 +220,7 @@ public void securityMatchersWhenMultiMvcMatcherThenAllPathsAreDenied() throws Ex
public void loadConfig(Class<?>... configs) {
this.context = new AnnotationConfigWebApplicationContext();
this.context.register(configs);
this.context.setServletContext(new MockServletContext());
this.context.setServletContext(MockServletContext.mvc());
this.context.refresh();
this.context.getAutowireCapableBeanFactory().autowireBean(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.MockServletContext;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
Expand Down Expand Up @@ -167,7 +167,7 @@ public void multiMvcMatchersConfig() throws Exception {
public void loadConfig(Class<?>... configs) {
this.context = new AnnotationConfigWebApplicationContext();
this.context.register(configs);
this.context.setServletContext(new MockServletContext());
this.context.setServletContext(MockServletContext.mvc());
this.context.refresh();
this.context.getAutowireCapableBeanFactory().autowireBean(this);
}
Expand Down
Loading

0 comments on commit 9dc7bdd

Please sign in to comment.