Skip to content

Commit 6fcc869

Browse files
committed
Polish ForwardedHeaderFilter
Issue: SPR-13614
1 parent 0679947 commit 6fcc869

File tree

2 files changed

+48
-44
lines changed

2 files changed

+48
-44
lines changed

spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.http.HttpRequest;
3333
import org.springframework.http.server.ServletServerHttpRequest;
34+
import org.springframework.util.CollectionUtils;
3435
import org.springframework.web.util.UriComponents;
3536
import org.springframework.web.util.UriComponentsBuilder;
3637

@@ -92,6 +93,10 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
9293

9394
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
9495

96+
public static final Enumeration<String> EMPTY_HEADER_VALUES =
97+
Collections.enumeration(Collections.<String>emptyList());
98+
99+
95100
private final String scheme;
96101

97102
private final boolean secure;
@@ -100,9 +105,9 @@ private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWra
100105

101106
private final int port;
102107

103-
private final String portInUrl;
108+
private final StringBuffer requestUrl;
104109

105-
private final Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
110+
private final Map<String, List<String>> headers;
106111

107112

108113
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
@@ -116,51 +121,35 @@ public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
116121
this.secure = "https".equals(scheme);
117122
this.host = uriComponents.getHost();
118123
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
119-
this.portInUrl = (port == -1 ? "" : ":" + port);
124+
this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
125+
this.headers = initHeaders(request);
126+
}
120127

128+
private static StringBuffer initRequestUrl(String scheme, String host, int port, String path) {
129+
StringBuffer sb = new StringBuffer();
130+
sb.append(scheme).append("://").append(host);
131+
sb.append(port == -1 ? "" : ":" + port);
132+
sb.append(path);
133+
return sb;
134+
}
135+
136+
/**
137+
* Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
138+
*/
139+
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
140+
Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
121141
Enumeration<String> headerNames = request.getHeaderNames();
122142
while (headerNames.hasMoreElements()) {
123143
String name = headerNames.nextElement();
124-
this.headers.put(name, Collections.list(request.getHeaders(name)));
144+
headers.put(name, Collections.list(request.getHeaders(name)));
125145
}
126146
for (String name : FORWARDED_HEADER_NAMES) {
127-
this.headers.remove(name);
147+
headers.remove(name);
128148
}
149+
return headers;
129150
}
130151

131152

132-
@Override
133-
public String getHeader(String name) {
134-
Map.Entry<String, List<String>> header = getHeaderEntry(name);
135-
if (header == null || header.getValue() == null || header.getValue().isEmpty()) {
136-
return null;
137-
}
138-
return header.getValue().get(0);
139-
}
140-
141-
protected Map.Entry<String, List<String>> getHeaderEntry(String name) {
142-
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
143-
if (entry.getKey().equalsIgnoreCase(name)) {
144-
return entry;
145-
}
146-
}
147-
return null;
148-
}
149-
150-
@Override
151-
public Enumeration<String> getHeaderNames() {
152-
return Collections.enumeration(this.headers.keySet());
153-
}
154-
155-
@Override
156-
public Enumeration<String> getHeaders(String name) {
157-
Map.Entry<String, List<String>> header = getHeaderEntry(name);
158-
if (header == null || header.getValue() == null) {
159-
return Collections.enumeration(Collections.<String>emptyList());
160-
}
161-
return Collections.enumeration(header.getValue());
162-
}
163-
164153
@Override
165154
public String getScheme() {
166155
return this.scheme;
@@ -183,10 +172,26 @@ public boolean isSecure() {
183172

184173
@Override
185174
public StringBuffer getRequestURL() {
186-
StringBuffer sb = new StringBuffer();
187-
sb.append(this.scheme).append("://").append(this.host).append(this.portInUrl);
188-
sb.append(getRequestURI());
189-
return sb;
175+
return this.requestUrl;
176+
}
177+
178+
// Override header accessors in order to not expose forwarded headers
179+
180+
@Override
181+
public String getHeader(String name) {
182+
List<String> value = this.headers.get(name);
183+
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
184+
}
185+
186+
@Override
187+
public Enumeration<String> getHeaders(String name) {
188+
List<String> value = this.headers.get(name);
189+
return (CollectionUtils.isEmpty(value) ? EMPTY_HEADER_VALUES : Collections.enumeration(value));
190+
}
191+
192+
@Override
193+
public Enumeration<String> getHeaderNames() {
194+
return Collections.enumeration(this.headers.keySet());
190195
}
191196
}
192197

spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public void shouldNotFilter() throws Exception {
5353
}
5454

5555
@Test
56-
public void xForwardedHeaders() throws Exception {
56+
public void forwardedRequest() throws Exception {
5757
MockHttpServletRequest request = new MockHttpServletRequest();
5858
request.setScheme("http");
5959
request.setServerName("localhost");
@@ -65,10 +65,9 @@ public void xForwardedHeaders() throws Exception {
6565
request.addHeader("foo", "bar");
6666

6767
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
68-
6968
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
70-
7169
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
70+
7271
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
7372
assertEquals("https", actual.getScheme());
7473
assertEquals("84.198.58.199", actual.getServerName());

0 commit comments

Comments
 (0)