Skip to content

Commit

Permalink
Merge branch '6.0.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Nov 10, 2023
2 parents f5453cc + 44a3700 commit 33b1ff5
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import jakarta.servlet.Filter;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;

Expand Down Expand Up @@ -77,6 +80,15 @@
public class HandlerMappingIntrospector
implements CorsConfigurationSource, ApplicationContextAware, InitializingBean {

static final String MAPPING_ATTRIBUTE =
HandlerMappingIntrospector.class.getName() + ".HandlerMapping";

static final String CORS_CONFIG_ATTRIBUTE =
HandlerMappingIntrospector.class.getName() + ".CorsConfig";

private static final CorsConfiguration NO_CORS_CONFIG = new CorsConfiguration();


@Nullable
private ApplicationContext applicationContext;

Expand Down Expand Up @@ -153,6 +165,58 @@ public List<HandlerMapping> getHandlerMappings() {
}


/**
* Return Filter that performs lookups, caches the results in request attributes,
* and clears the attributes after the filter chain returns.
* @since 6.0.14
*/
public Filter createCacheFilter() {
return (request, response, chain) -> {
MatchableHandlerMapping previousMapping = getCachedMapping(request);
CorsConfiguration previousCorsConfig = getCachedCorsConfiguration(request);
try {
HttpServletRequest wrappedRequest = new AttributesPreservingRequest((HttpServletRequest) request);
doWithHandlerMapping(wrappedRequest, false, (mapping, executionChain) -> {
MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrappedRequest);
CorsConfiguration corsConfig = getCorsConfiguration(wrappedRequest, executionChain);
setCache(request, matchableMapping, corsConfig);
return null;
});
chain.doFilter(request, response);
}
catch (Exception ex) {
throw new ServletException("HandlerMapping introspection failed", ex);
}
finally {
setCache(request, previousMapping, previousCorsConfig);
}
};
}

@Nullable
private static MatchableHandlerMapping getCachedMapping(ServletRequest request) {
return (MatchableHandlerMapping) request.getAttribute(MAPPING_ATTRIBUTE);
}

@Nullable
private static CorsConfiguration getCachedCorsConfiguration(ServletRequest request) {
return (CorsConfiguration) request.getAttribute(CORS_CONFIG_ATTRIBUTE);
}

private static void setCache(
ServletRequest request, @Nullable MatchableHandlerMapping mapping,
@Nullable CorsConfiguration corsConfig) {

if (mapping != null) {
request.setAttribute(MAPPING_ATTRIBUTE, mapping);
request.setAttribute(CORS_CONFIG_ATTRIBUTE, (corsConfig != null ? corsConfig : NO_CORS_CONFIG));
}
else {
request.removeAttribute(MAPPING_ATTRIBUTE);
request.removeAttribute(CORS_CONFIG_ATTRIBUTE);
}
}

/**
* Find the {@link HandlerMapping} that would handle the given request and
* return a {@link MatchableHandlerMapping} to use for path matching.
Expand All @@ -164,39 +228,60 @@ public List<HandlerMapping> getHandlerMappings() {
*/
@Nullable
public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception {
HttpServletRequest wrappedRequest = new AttributesPreservingRequest(request);

return doWithHandlerMapping(wrappedRequest, false, (mapping, executionChain) -> {
if (mapping instanceof MatchableHandlerMapping) {
PathPatternMatchableHandlerMapping pathPatternMapping = this.pathPatternMappings.get(mapping);
if (pathPatternMapping != null) {
RequestPath requestPath = ServletRequestPathUtils.getParsedRequestPath(wrappedRequest);
return new LookupPathMatchableHandlerMapping(pathPatternMapping, requestPath);
}
else {
String lookupPath = (String) wrappedRequest.getAttribute(UrlPathHelper.PATH_ATTRIBUTE);
return new LookupPathMatchableHandlerMapping((MatchableHandlerMapping) mapping, lookupPath);
}
MatchableHandlerMapping cachedMapping = getCachedMapping(request);
if (cachedMapping != null) {
return cachedMapping;
}
HttpServletRequest requestToUse = new AttributesPreservingRequest(request);
return doWithHandlerMapping(requestToUse, false,
(mapping, executionChain) -> createMatchableHandlerMapping(mapping, requestToUse));
}

private MatchableHandlerMapping createMatchableHandlerMapping(HandlerMapping mapping, HttpServletRequest request) {
if (mapping instanceof MatchableHandlerMapping) {
PathPatternMatchableHandlerMapping pathPatternMapping = this.pathPatternMappings.get(mapping);
if (pathPatternMapping != null) {
RequestPath requestPath = ServletRequestPathUtils.getParsedRequestPath(request);
return new LookupPathMatchableHandlerMapping(pathPatternMapping, requestPath);
}
else {
String lookupPath = (String) request.getAttribute(UrlPathHelper.PATH_ATTRIBUTE);
return new LookupPathMatchableHandlerMapping((MatchableHandlerMapping) mapping, lookupPath);
}
throw new IllegalStateException("HandlerMapping is not a MatchableHandlerMapping");
});
}
throw new IllegalStateException("HandlerMapping is not a MatchableHandlerMapping");
}

@Override
@Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
AttributesPreservingRequest wrappedRequest = new AttributesPreservingRequest(request);
return doWithHandlerMappingIgnoringException(wrappedRequest, (handlerMapping, executionChain) -> {
for (HandlerInterceptor interceptor : executionChain.getInterceptorList()) {
if (interceptor instanceof CorsConfigurationSource ccs) {
return ccs.getCorsConfiguration(wrappedRequest);
}
}
if (executionChain.getHandler() instanceof CorsConfigurationSource ccs) {
return ccs.getCorsConfiguration(wrappedRequest);
CorsConfiguration cachedCorsConfiguration = getCachedCorsConfiguration(request);
if (cachedCorsConfiguration != null) {
return (cachedCorsConfiguration != NO_CORS_CONFIG ? cachedCorsConfiguration : null);
}
try {
boolean ignoreException = true;
AttributesPreservingRequest requestToUse = new AttributesPreservingRequest(request);
return doWithHandlerMapping(requestToUse, ignoreException,
(handlerMapping, executionChain) -> getCorsConfiguration(requestToUse, executionChain));
}
catch (Exception ex) {
// HandlerMapping exceptions have been ignored. Some more basic error perhaps like request parsing
throw new IllegalStateException(ex);
}
}

@Nullable
private static CorsConfiguration getCorsConfiguration(HttpServletRequest request, HandlerExecutionChain chain) {
for (HandlerInterceptor interceptor : chain.getInterceptorList()) {
if (interceptor instanceof CorsConfigurationSource source) {
return source.getCorsConfiguration(request);
}
return null;
});
}
if (chain.getHandler() instanceof CorsConfigurationSource source) {
return source.getCorsConfiguration(request);
}
return null;
}

@Nullable
Expand Down Expand Up @@ -237,18 +322,6 @@ private <T> T doWithHandlerMapping(
return null;
}

@Nullable
private <T> T doWithHandlerMappingIgnoringException(
HttpServletRequest request, BiFunction<HandlerMapping, HandlerExecutionChain, T> matchHandler) {

try {
return doWithHandlerMapping(request, true, matchHandler);
}
catch (Exception ex) {
throw new IllegalStateException("HandlerMapping exception not suppressed", ex);
}
}


/**
* Request wrapper that buffers request attributes in order protect the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@

package org.springframework.web.servlet.handler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import jakarta.servlet.Filter;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -44,7 +49,9 @@
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.function.support.RouterFunctionMapping;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
import org.springframework.web.testfixture.servlet.MockFilterChain;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.util.ServletRequestPathUtils;
import org.springframework.web.util.pattern.PathPattern;
import org.springframework.web.util.pattern.PathPatternParser;
Expand Down Expand Up @@ -137,7 +144,7 @@ void getMatchable(boolean usePathPatterns) throws Exception {
@Test
void getMatchableWhereHandlerMappingDoesNotImplementMatchableInterface() {
StaticWebApplicationContext cxt = new StaticWebApplicationContext();
cxt.registerSingleton("mapping", TestHandlerMapping.class);
cxt.registerBean("mapping", HandlerMapping.class, () -> request -> new HandlerExecutionChain(new Object()));
cxt.refresh();

MockHttpServletRequest request = new MockHttpServletRequest();
Expand Down Expand Up @@ -193,6 +200,69 @@ void getCorsConfigurationActual() {
assertThat(corsConfig.getAllowedMethods()).isEqualTo(Collections.singletonList("POST"));
}

@Test
void cacheFilter() throws Exception {
testCacheFilter(new MockHttpServletRequest());
}

@Test
void cacheFilterRestoresPreviousValues() throws Exception {
TestMatchableHandlerMapping previousMapping = new TestMatchableHandlerMapping();
CorsConfiguration previousCorsConfig = new CorsConfiguration();

MockHttpServletRequest request = new MockHttpServletRequest();
request.setAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE, previousMapping);
request.setAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE, previousCorsConfig);

testCacheFilter(request);

assertThat(previousMapping.getInvocationCount()).isEqualTo(0);
assertThat(request.getAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE)).isSameAs(previousMapping);
assertThat(request.getAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE)).isSameAs(previousCorsConfig);
}

private void testCacheFilter(MockHttpServletRequest request) throws IOException, ServletException {
TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping();
StaticWebApplicationContext context = new StaticWebApplicationContext();
context.registerBean(TestMatchableHandlerMapping.class, () -> mapping);
context.refresh();

HandlerMappingIntrospector introspector = initIntrospector(context);
MockHttpServletResponse response = new MockHttpServletResponse();

Filter filter = (req, res, chain) -> {
try {
for (int i = 0; i < 10; i++) {
introspector.getMatchableHandlerMapping((HttpServletRequest) req);
introspector.getCorsConfiguration((HttpServletRequest) req);
}
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
chain.doFilter(req, res);
};

HttpServlet servlet = new HttpServlet() {

@Override
protected void service(HttpServletRequest req, HttpServletResponse res) {
try {
res.getWriter().print("Success");
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
};

new MockFilterChain(servlet, introspector.createCacheFilter(), filter)
.doFilter(request, response);

assertThat(response.getContentAsString()).isEqualTo("Success");
assertThat(mapping.getInvocationCount()).isEqualTo(1);
}

private HandlerMappingIntrospector initIntrospector(WebApplicationContext context) {
HandlerMappingIntrospector introspector = new HandlerMappingIntrospector();
introspector.setApplicationContext(context);
Expand All @@ -201,15 +271,6 @@ private HandlerMappingIntrospector initIntrospector(WebApplicationContext contex
}


private static class TestHandlerMapping implements HandlerMapping {

@Override
public HandlerExecutionChain getHandler(HttpServletRequest request) {
return new HandlerExecutionChain(new Object());
}
}


@Configuration
static class TestConfig {

Expand Down Expand Up @@ -248,6 +309,7 @@ void handle() {
}
}


private static class TestPathPatternParser extends PathPatternParser {

private final List<String> parsedPatterns = new ArrayList<>();
Expand All @@ -264,4 +326,25 @@ public PathPattern parse(String pathPattern) throws PatternParseException {
}
}


private static class TestMatchableHandlerMapping implements MatchableHandlerMapping {

private int invocationCount;

public int getInvocationCount() {
return this.invocationCount;
}

@Override
public HandlerExecutionChain getHandler(HttpServletRequest request) {
this.invocationCount++;
return new HandlerExecutionChain(new Object());
}

@Override
public RequestMatchResult match(HttpServletRequest request, String pattern) {
throw new UnsupportedOperationException();
}
}

}

0 comments on commit 33b1ff5

Please sign in to comment.