diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 2b4c577eb57..d93ec3f7bd0 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -19,8 +19,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import jakarta.servlet.DispatcherType; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletRegistration; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; @@ -36,6 +39,7 @@ import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; /** @@ -179,14 +183,47 @@ public C requestMatchers(RequestMatcher... requestMatchers) { * @since 5.8 */ public C requestMatchers(HttpMethod method, String... patterns) { - List matchers = new ArrayList<>(); - if (mvcPresent) { - matchers.addAll(createMvcMatchers(method, patterns)); + if (!mvcPresent) { + return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + } + if (!(this.context instanceof WebApplicationContext)) { + return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + } + WebApplicationContext context = (WebApplicationContext) this.context; + ServletContext servletContext = context.getServletContext(); + if (servletContext == null) { + return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + } + Map registrations = servletContext.getServletRegistrations(); + if (registrations == null) { + return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); + } + if (!hasDispatcherServlet(registrations)) { + return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } - else { - matchers.addAll(RequestMatchers.antMatchers(method, patterns)); + Assert.isTrue(registrations.size() == 1, + "This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher)."); + return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); + } + + private boolean hasDispatcherServlet(Map registrations) { + if (registrations == null) { + return false; + } + Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", + null); + for (ServletRegistration registration : registrations.values()) { + try { + Class clazz = Class.forName(registration.getClassName()); + if (dispatcherServlet.isAssignableFrom(clazz)) { + return true; + } + } + catch (ClassNotFoundException ex) { + return false; + } } - return requestMatchers(matchers.toArray(new RequestMatcher[0])); + return false; } /** @@ -262,12 +299,7 @@ private RequestMatchers() { * @return a {@link List} of {@link AntPathRequestMatcher} instances */ static List antMatchers(HttpMethod httpMethod, String... antPatterns) { - String method = (httpMethod != null) ? httpMethod.toString() : null; - List matchers = new ArrayList<>(); - for (String pattern : antPatterns) { - matchers.add(new AntPathRequestMatcher(pattern, method)); - } - return matchers; + return Arrays.asList(antMatchersAsArray(httpMethod, antPatterns)); } /** @@ -281,6 +313,15 @@ static List antMatchers(String... antPatterns) { return antMatchers(null, antPatterns); } + static RequestMatcher[] antMatchersAsArray(HttpMethod httpMethod, String... antPatterns) { + String method = (httpMethod != null) ? httpMethod.toString() : null; + RequestMatcher[] matchers = new RequestMatcher[antPatterns.length]; + for (int index = 0; index < antPatterns.length; index++) { + matchers[index] = new AntPathRequestMatcher(antPatterns[index], method); + } + return matchers; + } + /** * Create a {@link List} of {@link RegexRequestMatcher} instances. * @param httpMethod the {@link HttpMethod} to use or {@code null} for any diff --git a/config/src/test/java/org/springframework/security/config/MockServletContext.java b/config/src/test/java/org/springframework/security/config/MockServletContext.java new file mode 100644 index 00000000000..bace54fff6a --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/MockServletContext.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + +import jakarta.servlet.MultipartConfigElement; +import jakarta.servlet.Servlet; +import jakarta.servlet.ServletRegistration; +import jakarta.servlet.ServletSecurityElement; + +import org.springframework.lang.NonNull; +import org.springframework.web.servlet.DispatcherServlet; + +public class MockServletContext extends org.springframework.mock.web.MockServletContext { + + private final Map registrations = new LinkedHashMap<>(); + + public static MockServletContext mvc() { + MockServletContext servletContext = new MockServletContext(); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); + return servletContext; + } + + @NonNull + @Override + public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class clazz) { + ServletRegistration.Dynamic dynamic = new MockServletRegistration(servletName, clazz); + this.registrations.put(servletName, dynamic); + return dynamic; + } + + @NonNull + @Override + public Map getServletRegistrations() { + return this.registrations; + } + + private static class MockServletRegistration implements ServletRegistration.Dynamic { + + private final String name; + + private final Class clazz; + + MockServletRegistration(String name, Class clazz) { + this.name = name; + this.clazz = clazz; + } + + @Override + public void setLoadOnStartup(int loadOnStartup) { + + } + + @Override + public Set setServletSecurity(ServletSecurityElement constraint) { + return null; + } + + @Override + public void setMultipartConfig(MultipartConfigElement multipartConfig) { + + } + + @Override + public void setRunAsRole(String roleName) { + + } + + @Override + public void setAsyncSupported(boolean isAsyncSupported) { + + } + + @Override + public Set addMapping(String... urlPatterns) { + return null; + } + + @Override + public Collection getMappings() { + return null; + } + + @Override + public String getRunAsRole() { + return null; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public String getClassName() { + return this.clazz.getName(); + } + + @Override + public boolean setInitParameter(String name, String value) { + return false; + } + + @Override + public String getInitParameter(String name) { + return null; + } + + @Override + public Set setInitParameters(Map initParameters) { + return null; + } + + @Override + public Map getInitParameters() { + return null; + } + + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 2ca18279f53..1a1aa1f3400 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -19,18 +19,22 @@ import java.util.List; import jakarta.servlet.DispatcherType; +import jakarta.servlet.Servlet; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; +import org.springframework.security.config.MockServletContext; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -54,12 +58,15 @@ public O postProcess(O object) { private TestRequestMatcherRegistry matcherRegistry; + private WebApplicationContext context; + @BeforeEach public void setUp() { this.matcherRegistry = new TestRequestMatcherRegistry(); - ApplicationContext context = mock(ApplicationContext.class); - given(context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR); - this.matcherRegistry.setApplicationContext(context); + this.context = mock(WebApplicationContext.class); + given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR); + given(this.context.getServletContext()).willReturn(MockServletContext.mvc()); + this.matcherRegistry.setApplicationContext(this.context); mockMvcIntrospector(true); } @@ -147,6 +154,32 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva "Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext"); } + @Test + public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class); + servletContext.addServlet("servletOne", Servlet.class); + servletContext.addServlet("servletTwo", Servlet.class); + requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class); + } + + @Test + public void requestMatchersWhenAmbiguousServletsThenException() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); + servletContext.addServlet("servletTwo", Servlet.class); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); + } + private void mockMvcIntrospector(boolean isPresent) { ApplicationContext context = this.matcherRegistry.getApplicationContext(); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java index fdd697b8ff9..1b98cc4df78 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeRequestsTests.java @@ -28,10 +28,10 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockServletContext; import org.springframework.security.access.hierarchicalroles.RoleHierarchy; import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.config.MockServletContext; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.core.authority.AuthorityUtils; @@ -77,7 +77,7 @@ public class AuthorizeRequestsTests { @BeforeEach public void setup() { - this.servletContext = spy(new MockServletContext()); + this.servletContext = spy(MockServletContext.mvc()); this.request = new MockHttpServletRequest("GET", ""); this.request.setMethod("GET"); this.response = new MockHttpServletResponse(); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersTests.java index 0f1fdd15f21..4fa5e3031c5 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersTests.java @@ -32,7 +32,7 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockServletContext; +import org.springframework.security.config.MockServletContext; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.core.userdetails.User; @@ -220,7 +220,7 @@ public void securityMatchersWhenMultiMvcMatcherThenAllPathsAreDenied() throws Ex public void loadConfig(Class... configs) { this.context = new AnnotationConfigWebApplicationContext(); this.context.register(configs); - this.context.setServletContext(new MockServletContext()); + this.context.setServletContext(MockServletContext.mvc()); this.context.refresh(); this.context.getAutowireCapableBeanFactory().autowireBean(this); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java index cb7c4bf5e6f..42f458cd742 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/UrlAuthorizationConfigurerTests.java @@ -30,8 +30,8 @@ import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockServletContext; import org.springframework.security.config.Customizer; +import org.springframework.security.config.MockServletContext; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.core.userdetails.PasswordEncodedUser; @@ -167,7 +167,7 @@ public void multiMvcMatchersConfig() throws Exception { public void loadConfig(Class... configs) { this.context = new AnnotationConfigWebApplicationContext(); this.context.register(configs); - this.context.setServletContext(new MockServletContext()); + this.context.setServletContext(MockServletContext.mvc()); this.context.refresh(); this.context.getAutowireCapableBeanFactory().autowireBean(this); } diff --git a/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java b/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java index 78fbe32577e..b165c20b609 100644 --- a/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java +++ b/config/src/test/java/org/springframework/security/config/test/SpringTestContext.java @@ -28,8 +28,8 @@ import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor; import org.springframework.mock.web.MockServletConfig; -import org.springframework.mock.web.MockServletContext; import org.springframework.security.config.BeanIds; +import org.springframework.security.config.MockServletContext; import org.springframework.security.config.util.InMemoryXmlWebApplicationContext; import org.springframework.test.context.web.GenericXmlWebContextLoader; import org.springframework.test.web.servlet.MockMvc; @@ -132,7 +132,7 @@ public SpringTestContext addFilter(Filter filter) { public ConfigurableWebApplicationContext getContext() { if (!this.context.isRunning()) { - this.context.setServletContext(new MockServletContext()); + this.context.setServletContext(MockServletContext.mvc()); this.context.setServletConfig(new MockServletConfig()); this.context.refresh(); } @@ -140,7 +140,7 @@ public ConfigurableWebApplicationContext getContext() { } public void autowire() { - this.context.setServletContext(new MockServletContext()); + this.context.setServletContext(MockServletContext.mvc()); this.context.setServletConfig(new MockServletConfig()); for (Consumer postProcessor : this.postProcessors) { postProcessor.accept(this.context);