Skip to content

Commit 4cf0b59

Browse files
committed
Add ForwardedHeaderFilter
The new Filter is simply a new way of packaging the ability to extract X-Forwarded-* headers already available via UriComponentsBuilder. The Filter wraps the request and the effect is that anything using the request will see the original schem, host, and port. Issue: SPR-13614
1 parent 7b861c9 commit 4cf0b59

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
* Copyright 2002-2016 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.web.filter;
17+
18+
import java.io.IOException;
19+
import java.util.Collections;
20+
import java.util.Enumeration;
21+
import java.util.HashSet;
22+
import java.util.LinkedHashMap;
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.Set;
26+
import javax.servlet.FilterChain;
27+
import javax.servlet.ServletException;
28+
import javax.servlet.http.HttpServletRequest;
29+
import javax.servlet.http.HttpServletRequestWrapper;
30+
import javax.servlet.http.HttpServletResponse;
31+
32+
import org.springframework.http.HttpRequest;
33+
import org.springframework.http.server.ServletServerHttpRequest;
34+
import org.springframework.web.util.UriComponents;
35+
import org.springframework.web.util.UriComponentsBuilder;
36+
37+
38+
/**
39+
* Filter that wraps the request in order to override its
40+
* {@link HttpServletRequest#getServerName() getServerName()},
41+
* {@link HttpServletRequest#getServerPort() getServerPort()},
42+
* {@link HttpServletRequest#getScheme() getScheme()}, and
43+
* {@link HttpServletRequest#isSecure() isSecure()} methods with values derived
44+
* from "Fowarded" or "X-Forwarded-*" headers. In effect the wrapped request
45+
* reflects the client-originated protocol and address.
46+
*
47+
* @author Rossen Stoyanchev
48+
* @since 4.3
49+
*/
50+
public class ForwardedHeaderFilter extends OncePerRequestFilter {
51+
52+
private static final Set<String> FORWARDED_HEADER_NAMES;
53+
54+
static {
55+
FORWARDED_HEADER_NAMES = new HashSet<String>(4);
56+
FORWARDED_HEADER_NAMES.add("Forwarded");
57+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
58+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
59+
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
60+
}
61+
62+
63+
@Override
64+
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
65+
Enumeration<String> headerNames = request.getHeaderNames();
66+
while (headerNames.hasMoreElements()) {
67+
String name = headerNames.nextElement();
68+
if (FORWARDED_HEADER_NAMES.contains(name)) {
69+
return false;
70+
}
71+
}
72+
return true;
73+
}
74+
75+
@Override
76+
protected boolean shouldNotFilterAsyncDispatch() {
77+
return false;
78+
}
79+
80+
@Override
81+
protected boolean shouldNotFilterErrorDispatch() {
82+
return false;
83+
}
84+
85+
@Override
86+
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
87+
FilterChain filterChain) throws ServletException, IOException {
88+
89+
filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response);
90+
}
91+
92+
93+
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
94+
95+
private final String scheme;
96+
97+
private final boolean secure;
98+
99+
private final String host;
100+
101+
private final int port;
102+
103+
private final String portInUrl;
104+
105+
private final Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
106+
107+
108+
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
109+
super(request);
110+
111+
HttpRequest httpRequest = new ServletServerHttpRequest(request);
112+
UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build();
113+
int port = uriComponents.getPort();
114+
115+
this.scheme = uriComponents.getScheme();
116+
this.secure = "https".equals(scheme);
117+
this.host = uriComponents.getHost();
118+
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
119+
this.portInUrl = (port == -1 ? "" : ":" + port);
120+
121+
Enumeration<String> headerNames = request.getHeaderNames();
122+
while (headerNames.hasMoreElements()) {
123+
String name = headerNames.nextElement();
124+
this.headers.put(name, Collections.list(request.getHeaders(name)));
125+
}
126+
for (String name : FORWARDED_HEADER_NAMES) {
127+
this.headers.remove(name);
128+
}
129+
}
130+
131+
132+
@Override
133+
public String getHeader(String name) {
134+
Map.Entry<String, List<String>> header = getHeaderEntry(name);
135+
if (header == null || header.getValue() == null || header.getValue().isEmpty()) {
136+
return null;
137+
}
138+
return header.getValue().get(0);
139+
}
140+
141+
protected Map.Entry<String, List<String>> getHeaderEntry(String name) {
142+
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
143+
if (entry.getKey().equalsIgnoreCase(name)) {
144+
return entry;
145+
}
146+
}
147+
return null;
148+
}
149+
150+
@Override
151+
public Enumeration<String> getHeaderNames() {
152+
return Collections.enumeration(this.headers.keySet());
153+
}
154+
155+
@Override
156+
public Enumeration<String> getHeaders(String name) {
157+
Map.Entry<String, List<String>> header = getHeaderEntry(name);
158+
if (header == null || header.getValue() == null) {
159+
return Collections.enumeration(Collections.<String>emptyList());
160+
}
161+
return Collections.enumeration(header.getValue());
162+
}
163+
164+
@Override
165+
public String getScheme() {
166+
return this.scheme;
167+
}
168+
169+
@Override
170+
public String getServerName() {
171+
return this.host;
172+
}
173+
174+
@Override
175+
public int getServerPort() {
176+
return this.port;
177+
}
178+
179+
@Override
180+
public boolean isSecure() {
181+
return this.secure;
182+
}
183+
184+
@Override
185+
public StringBuffer getRequestURL() {
186+
StringBuffer sb = new StringBuffer();
187+
sb.append(this.scheme).append("://").append(this.host).append(this.portInUrl);
188+
sb.append(getRequestURI());
189+
return sb;
190+
}
191+
}
192+
193+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2002-2016 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.web.filter;
17+
18+
import javax.servlet.ServletException;
19+
import javax.servlet.http.HttpServlet;
20+
import javax.servlet.http.HttpServletRequest;
21+
22+
import org.junit.Test;
23+
24+
import org.springframework.mock.web.test.MockFilterChain;
25+
import org.springframework.mock.web.test.MockHttpServletRequest;
26+
import org.springframework.mock.web.test.MockHttpServletResponse;
27+
28+
import static org.junit.Assert.assertEquals;
29+
import static org.junit.Assert.assertFalse;
30+
import static org.junit.Assert.assertNull;
31+
import static org.junit.Assert.assertTrue;
32+
33+
/**
34+
* Unit tests for {@link ForwardedHeaderFilter}.
35+
* @author Rossen Stoyanchev
36+
*/
37+
public class ForwardedHeaderFilterTests {
38+
39+
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
40+
41+
42+
@Test
43+
public void shouldFilter() throws Exception {
44+
testShouldFilter("Forwarded");
45+
testShouldFilter("X-Forwarded-Host");
46+
testShouldFilter("X-Forwarded-Port");
47+
testShouldFilter("X-Forwarded-Proto");
48+
}
49+
50+
@Test
51+
public void shouldNotFilter() throws Exception {
52+
assertTrue(this.filter.shouldNotFilter(new MockHttpServletRequest()));
53+
}
54+
55+
@Test
56+
public void xForwardedHeaders() throws Exception {
57+
MockHttpServletRequest request = new MockHttpServletRequest();
58+
request.setScheme("http");
59+
request.setServerName("localhost");
60+
request.setServerPort(80);
61+
request.setRequestURI("/mvc-showcase");
62+
request.addHeader("X-Forwarded-Proto", "https");
63+
request.addHeader("X-Forwarded-Host", "84.198.58.199");
64+
request.addHeader("X-Forwarded-Port", "443");
65+
request.addHeader("foo", "bar");
66+
67+
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
68+
69+
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
70+
71+
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
72+
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
73+
assertEquals("https", actual.getScheme());
74+
assertEquals("84.198.58.199", actual.getServerName());
75+
assertEquals(443, actual.getServerPort());
76+
assertTrue(actual.isSecure());
77+
78+
assertNull(actual.getHeader("X-Forwarded-Proto"));
79+
assertNull(actual.getHeader("X-Forwarded-Host"));
80+
assertNull(actual.getHeader("X-Forwarded-Port"));
81+
assertEquals("bar", actual.getHeader("foo"));
82+
}
83+
84+
85+
private void testShouldFilter(String headerName) throws ServletException {
86+
MockHttpServletRequest request = new MockHttpServletRequest();
87+
request.addHeader(headerName, "1");
88+
assertFalse(this.filter.shouldNotFilter(request));
89+
}
90+
91+
92+
}

0 commit comments

Comments
 (0)