31
31
32
32
import org .springframework .http .HttpRequest ;
33
33
import org .springframework .http .server .ServletServerHttpRequest ;
34
+ import org .springframework .util .CollectionUtils ;
34
35
import org .springframework .web .util .UriComponents ;
35
36
import org .springframework .web .util .UriComponentsBuilder ;
36
37
@@ -92,6 +93,10 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
92
93
93
94
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
94
95
96
+ public static final Enumeration <String > EMPTY_HEADER_VALUES =
97
+ Collections .enumeration (Collections .<String >emptyList ());
98
+
99
+
95
100
private final String scheme ;
96
101
97
102
private final boolean secure ;
@@ -100,9 +105,9 @@ private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWra
100
105
101
106
private final int port ;
102
107
103
- private final String portInUrl ;
108
+ private final StringBuffer requestUrl ;
104
109
105
- private final Map <String , List <String >> headers = new LinkedHashMap < String , List < String >>() ;
110
+ private final Map <String , List <String >> headers ;
106
111
107
112
108
113
public ForwardedHeaderRequestWrapper (HttpServletRequest request ) {
@@ -116,51 +121,35 @@ public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
116
121
this .secure = "https" .equals (scheme );
117
122
this .host = uriComponents .getHost ();
118
123
this .port = (port == -1 ? (this .secure ? 443 : 80 ) : port );
119
- this .portInUrl = (port == -1 ? "" : ":" + port );
124
+ this .requestUrl = initRequestUrl (this .scheme , this .host , port , request .getRequestURI ());
125
+ this .headers = initHeaders (request );
126
+ }
120
127
128
+ private static StringBuffer initRequestUrl (String scheme , String host , int port , String path ) {
129
+ StringBuffer sb = new StringBuffer ();
130
+ sb .append (scheme ).append ("://" ).append (host );
131
+ sb .append (port == -1 ? "" : ":" + port );
132
+ sb .append (path );
133
+ return sb ;
134
+ }
135
+
136
+ /**
137
+ * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
138
+ */
139
+ private static Map <String , List <String >> initHeaders (HttpServletRequest request ) {
140
+ Map <String , List <String >> headers = new LinkedHashMap <String , List <String >>();
121
141
Enumeration <String > headerNames = request .getHeaderNames ();
122
142
while (headerNames .hasMoreElements ()) {
123
143
String name = headerNames .nextElement ();
124
- this . headers .put (name , Collections .list (request .getHeaders (name )));
144
+ headers .put (name , Collections .list (request .getHeaders (name )));
125
145
}
126
146
for (String name : FORWARDED_HEADER_NAMES ) {
127
- this . headers .remove (name );
147
+ headers .remove (name );
128
148
}
149
+ return headers ;
129
150
}
130
151
131
152
132
- @ Override
133
- public String getHeader (String name ) {
134
- Map .Entry <String , List <String >> header = getHeaderEntry (name );
135
- if (header == null || header .getValue () == null || header .getValue ().isEmpty ()) {
136
- return null ;
137
- }
138
- return header .getValue ().get (0 );
139
- }
140
-
141
- protected Map .Entry <String , List <String >> getHeaderEntry (String name ) {
142
- for (Map .Entry <String , List <String >> entry : this .headers .entrySet ()) {
143
- if (entry .getKey ().equalsIgnoreCase (name )) {
144
- return entry ;
145
- }
146
- }
147
- return null ;
148
- }
149
-
150
- @ Override
151
- public Enumeration <String > getHeaderNames () {
152
- return Collections .enumeration (this .headers .keySet ());
153
- }
154
-
155
- @ Override
156
- public Enumeration <String > getHeaders (String name ) {
157
- Map .Entry <String , List <String >> header = getHeaderEntry (name );
158
- if (header == null || header .getValue () == null ) {
159
- return Collections .enumeration (Collections .<String >emptyList ());
160
- }
161
- return Collections .enumeration (header .getValue ());
162
- }
163
-
164
153
@ Override
165
154
public String getScheme () {
166
155
return this .scheme ;
@@ -183,10 +172,26 @@ public boolean isSecure() {
183
172
184
173
@ Override
185
174
public StringBuffer getRequestURL () {
186
- StringBuffer sb = new StringBuffer ();
187
- sb .append (this .scheme ).append ("://" ).append (this .host ).append (this .portInUrl );
188
- sb .append (getRequestURI ());
189
- return sb ;
175
+ return this .requestUrl ;
176
+ }
177
+
178
+ // Override header accessors in order to not expose forwarded headers
179
+
180
+ @ Override
181
+ public String getHeader (String name ) {
182
+ List <String > value = this .headers .get (name );
183
+ return (CollectionUtils .isEmpty (value ) ? null : value .get (0 ));
184
+ }
185
+
186
+ @ Override
187
+ public Enumeration <String > getHeaders (String name ) {
188
+ List <String > value = this .headers .get (name );
189
+ return (CollectionUtils .isEmpty (value ) ? EMPTY_HEADER_VALUES : Collections .enumeration (value ));
190
+ }
191
+
192
+ @ Override
193
+ public Enumeration <String > getHeaderNames () {
194
+ return Collections .enumeration (this .headers .keySet ());
190
195
}
191
196
}
192
197
0 commit comments