@@ -62,11 +62,11 @@ public Task Invoke(HttpContext context)
62
62
{
63
63
// Detect if an opaque upgrade is available. If so, add a websocket upgrade.
64
64
var upgradeFeature = context . Features . Get < IHttpUpgradeFeature > ( ) ;
65
- if ( upgradeFeature != null && context . Features . Get < IHttpWebSocketFeature > ( ) == null )
65
+ var connectFeature = context . Features . Get < IHttpExtendedConnectFeature > ( ) ;
66
+ if ( ( upgradeFeature != null || connectFeature != null ) && context . Features . Get < IHttpWebSocketFeature > ( ) == null )
66
67
{
67
- var webSocketFeature = new UpgradeHandshake ( context , upgradeFeature , _options , _logger ) ;
68
+ var webSocketFeature = new WebSocketHandshake ( context , upgradeFeature , connectFeature , _options , _logger ) ;
68
69
context . Features . Set < IHttpWebSocketFeature > ( webSocketFeature ) ;
69
-
70
70
if ( ! _anyOriginAllowed )
71
71
{
72
72
// Check for Origin header
@@ -88,18 +88,21 @@ public Task Invoke(HttpContext context)
88
88
return _next ( context ) ;
89
89
}
90
90
91
- private sealed class UpgradeHandshake : IHttpWebSocketFeature
91
+ private sealed class WebSocketHandshake : IHttpWebSocketFeature
92
92
{
93
93
private readonly HttpContext _context ;
94
- private readonly IHttpUpgradeFeature _upgradeFeature ;
94
+ private readonly IHttpUpgradeFeature ? _upgradeFeature ;
95
+ private readonly IHttpExtendedConnectFeature ? _connectFeature ;
95
96
private readonly WebSocketOptions _options ;
96
97
private readonly ILogger _logger ;
97
98
private bool ? _isWebSocketRequest ;
99
+ private bool _isH2WebSocket ;
98
100
99
- public UpgradeHandshake ( HttpContext context , IHttpUpgradeFeature upgradeFeature , WebSocketOptions options , ILogger logger )
101
+ public WebSocketHandshake ( HttpContext context , IHttpUpgradeFeature ? upgradeFeature , IHttpExtendedConnectFeature ? connectFeature , WebSocketOptions options , ILogger logger )
100
102
{
101
103
_context = context ;
102
104
_upgradeFeature = upgradeFeature ;
105
+ _connectFeature = connectFeature ;
103
106
_options = options ;
104
107
_logger = logger ;
105
108
}
@@ -110,14 +113,19 @@ public bool IsWebSocketRequest
110
113
{
111
114
if ( _isWebSocketRequest == null )
112
115
{
113
- if ( ! _upgradeFeature . IsUpgradableRequest )
116
+ if ( _connectFeature ? . IsExtendedConnect == true )
114
117
{
115
- _isWebSocketRequest = false ;
118
+ _isH2WebSocket = CheckSupportedWebSocketRequestH2 ( _context . Request . Method , _connectFeature . Protocol , _context . Request . Headers ) ;
119
+ _isWebSocketRequest = _isH2WebSocket ;
116
120
}
117
- else
121
+ else if ( _upgradeFeature ? . IsUpgradableRequest == true )
118
122
{
119
123
_isWebSocketRequest = CheckSupportedWebSocketRequest ( _context . Request . Method , _context . Request . Headers ) ;
120
124
}
125
+ else
126
+ {
127
+ _isWebSocketRequest = false ;
128
+ }
121
129
}
122
130
return _isWebSocketRequest . Value ;
123
131
}
@@ -127,7 +135,7 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
127
135
{
128
136
if ( ! IsWebSocketRequest )
129
137
{
130
- throw new InvalidOperationException ( "Not a WebSocket request." ) ; // TODO: LOC
138
+ throw new InvalidOperationException ( "Not a WebSocket request." ) ;
131
139
}
132
140
133
141
string ? subProtocol = null ;
@@ -154,8 +162,7 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
154
162
}
155
163
}
156
164
157
- var key = _context . Request . Headers . SecWebSocketKey . ToString ( ) ;
158
- HandshakeHelpers . GenerateResponseHeaders ( key , subProtocol , _context . Response . Headers ) ;
165
+ HandshakeHelpers . GenerateResponseHeaders ( ! _isH2WebSocket , _context . Request . Headers , subProtocol , _context . Response . Headers ) ;
159
166
160
167
WebSocketDeflateOptions ? deflateOptions = null ;
161
168
if ( enableCompression )
@@ -187,7 +194,18 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
187
194
}
188
195
}
189
196
190
- Stream opaqueTransport = await _upgradeFeature . UpgradeAsync ( ) ; // Sets status code to 101
197
+ Stream opaqueTransport ;
198
+ // HTTP/2
199
+ if ( _isH2WebSocket )
200
+ {
201
+ // Send the response headers
202
+ opaqueTransport = await _connectFeature ! . AcceptAsync ( ) ;
203
+ }
204
+ // HTTP/1.1
205
+ else
206
+ {
207
+ opaqueTransport = await _upgradeFeature ! . UpgradeAsync ( ) ; // Sets status code to 101
208
+ }
191
209
192
210
return WebSocket . CreateFromStream ( opaqueTransport , new WebSocketCreationOptions ( )
193
211
{
@@ -205,17 +223,22 @@ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictiona
205
223
return false ;
206
224
}
207
225
226
+ if ( ! CheckWebSocketVersion ( requestHeaders ) )
227
+ {
228
+ return false ;
229
+ }
230
+
208
231
var foundHeader = false ;
209
232
210
- var values = requestHeaders . GetCommaSeparatedValues ( HeaderNames . SecWebSocketVersion ) ;
233
+ var values = requestHeaders . GetCommaSeparatedValues ( HeaderNames . Upgrade ) ;
211
234
foreach ( var value in values )
212
235
{
213
- if ( string . Equals ( value , Constants . Headers . SupportedVersion , StringComparison . OrdinalIgnoreCase ) )
236
+ if ( string . Equals ( value , Constants . Headers . UpgradeWebSocket , StringComparison . OrdinalIgnoreCase ) )
214
237
{
215
238
// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
216
239
if ( values . Length == 1 )
217
240
{
218
- requestHeaders . SecWebSocketVersion = Constants . Headers . SupportedVersion ;
241
+ requestHeaders . Upgrade = Constants . Headers . UpgradeWebSocket ;
219
242
}
220
243
foundHeader = true ;
221
244
break ;
@@ -245,28 +268,43 @@ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictiona
245
268
{
246
269
return false ;
247
270
}
248
- foundHeader = false ;
249
271
250
- values = requestHeaders . GetCommaSeparatedValues ( HeaderNames . Upgrade ) ;
272
+ return HandshakeHelpers . IsRequestKeyValid ( requestHeaders . SecWebSocketKey . ToString ( ) ) ;
273
+ }
274
+
275
+ // https://datatracker.ietf.org/doc/html/rfc8441
276
+ // :method = CONNECT
277
+ // :protocol = websocket
278
+ // :scheme = https
279
+ // :path = /chat
280
+ // :authority = server.example.com
281
+ // sec-websocket-protocol = chat, superchat
282
+ // sec-websocket-extensions = permessage-deflate
283
+ // sec-websocket-version = 13
284
+ // origin = http://www.example.com
285
+ public static bool CheckSupportedWebSocketRequestH2 ( string method , string ? protocol , IHeaderDictionary requestHeaders )
286
+ {
287
+ return HttpMethods . IsConnect ( method )
288
+ && string . Equals ( protocol , Constants . Headers . UpgradeWebSocket , StringComparison . OrdinalIgnoreCase )
289
+ && CheckWebSocketVersion ( requestHeaders ) ;
290
+ }
291
+
292
+ public static bool CheckWebSocketVersion ( IHeaderDictionary requestHeaders )
293
+ {
294
+ var values = requestHeaders . GetCommaSeparatedValues ( HeaderNames . SecWebSocketVersion ) ;
251
295
foreach ( var value in values )
252
296
{
253
- if ( string . Equals ( value , Constants . Headers . UpgradeWebSocket , StringComparison . OrdinalIgnoreCase ) )
297
+ if ( string . Equals ( value , Constants . Headers . SupportedVersion , StringComparison . OrdinalIgnoreCase ) )
254
298
{
255
299
// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
256
300
if ( values . Length == 1 )
257
301
{
258
- requestHeaders . Upgrade = Constants . Headers . UpgradeWebSocket ;
302
+ requestHeaders . SecWebSocketVersion = Constants . Headers . SupportedVersion ;
259
303
}
260
- foundHeader = true ;
261
- break ;
304
+ return true ;
262
305
}
263
306
}
264
- if ( ! foundHeader )
265
- {
266
- return false ;
267
- }
268
-
269
- return HandshakeHelpers . IsRequestKeyValid ( requestHeaders . SecWebSocketKey . ToString ( ) ) ;
307
+ return false ;
270
308
}
271
309
}
272
310
0 commit comments