Skip to content

Commit 08a71c9

Browse files
committed
[HTTPDecoder] Decode informal headers correctly
1 parent 6975036 commit 08a71c9

File tree

4 files changed

+129
-14
lines changed

4 files changed

+129
-14
lines changed

Sources/NIOHTTP1/HTTPDecoder.swift

+61-10
Original file line numberDiff line numberDiff line change
@@ -293,16 +293,25 @@ private class BetterHTTPParser {
293293
// does not meet the requirement of RFC 7230. This is an outstanding http_parser issue:
294294
// https://github.com/nodejs/http-parser/issues/251. As a result, we check for these status
295295
// codes and override http_parser's handling as well.
296-
guard let method = self.requestHeads.popFirst()?.method else {
296+
guard !self.requestHeads.isEmpty else {
297297
self.richerError = NIOHTTPDecoderError.unsolicitedResponse
298298
return .error(HPE_UNKNOWN)
299299
}
300-
301-
if method == .HEAD || method == .CONNECT {
302-
skipBody = true
303-
} else if statusCode / 100 == 1 || // 1XX codes
304-
statusCode == 204 || statusCode == 304 {
300+
301+
if (HTTPResponseStatus.continue.code <= statusCode && statusCode < HTTPResponseStatus.ok.code)
302+
&& statusCode != HTTPResponseStatus.switchingProtocols.code {
303+
// if the response status is in the range of 100..<200 but not 101 we don't want to
304+
// pop the request method. The actual request head is expected with the next HTTP
305+
// head.
305306
skipBody = true
307+
} else {
308+
let method = self.requestHeads.removeFirst().method
309+
if method == .HEAD || method == .CONNECT {
310+
skipBody = true
311+
} else if statusCode / 100 == 1 || // 1XX codes
312+
statusCode == 204 || statusCode == 304 {
313+
skipBody = true
314+
}
306315
}
307316
}
308317

@@ -473,15 +482,36 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
473482
// the actual state
474483
private let parser: BetterHTTPParser
475484
private let leftOverBytesStrategy: RemoveAfterUpgradeStrategy
485+
private let informalResponseStrategy: InformalResponseStrategy
476486
private let kind: HTTPDecoderKind
477487
private var stopParsing = false // set on upgrade or HTTP version error
488+
private var lastHeaderWasInformal = false
478489

479490
/// Creates a new instance of `HTTPDecoder`.
480491
///
481492
/// - parameters:
482493
/// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was,
483494
/// detected. Note that this does not affect what happens on EOF.
484-
public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) {
495+
public convenience init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) {
496+
self.init(leftOverBytesStrategy: leftOverBytesStrategy, informalResponseStrategy: .drop)
497+
}
498+
499+
/// Strategy to use when a HTTPDecoder receives an informal HTTP response (1xx except 101)
500+
public enum InformalResponseStrategy {
501+
/// Drop the informal response and only forward the "real" response
502+
case drop
503+
/// Forward the informal response and then forward the "real" response
504+
case forward
505+
}
506+
507+
/// Creates a new instance of `HTTPDecoder`.
508+
///
509+
/// - parameters:
510+
/// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was,
511+
/// detected. Note that this does not affect what happens on EOF.
512+
/// - supportInformalResponses: Should informal responses (like http status 100) be forwarded or dropped. Default is `.drop`
513+
/// This property is only respected when decoding responses.
514+
public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, informalResponseStrategy: InformalResponseStrategy = .drop) {
485515
self.headers.reserveCapacity(16)
486516
if In.self == HTTPServerRequestPart.self {
487517
self.kind = .request
@@ -492,6 +522,7 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
492522
}
493523
self.parser = BetterHTTPParser(kind: kind)
494524
self.leftOverBytesStrategy = leftOverBytesStrategy
525+
self.informalResponseStrategy = informalResponseStrategy
495526
}
496527

497528
func didReceiveBody(_ bytes: UnsafeRawBufferPointer) {
@@ -545,7 +576,7 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
545576
method: http_method,
546577
statusCode: Int,
547578
keepAliveState: KeepAliveState) -> Bool {
548-
let message: NIOAny
579+
let message: NIOAny?
549580

550581
guard versionMajor == 1 else {
551582
self.stopParsing = true
@@ -561,7 +592,23 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
561592
headers: HTTPHeaders(self.headers,
562593
keepAliveState: keepAliveState))
563594
message = NIOAny(HTTPServerRequestPart.head(reqHead))
595+
596+
case .response where (100..<200).contains(statusCode) && statusCode != 101:
597+
self.lastHeaderWasInformal = true
598+
599+
switch self.informalResponseStrategy {
600+
case .forward:
601+
let resHead: HTTPResponseHead = HTTPResponseHead(version: .init(major: versionMajor, minor: versionMinor),
602+
status: .init(statusCode: statusCode),
603+
headers: HTTPHeaders(self.headers,
604+
keepAliveState: keepAliveState))
605+
message = NIOAny(HTTPClientResponsePart.head(resHead))
606+
case .drop:
607+
message = nil
608+
}
609+
564610
case .response:
611+
self.lastHeaderWasInformal = false
565612
let resHead: HTTPResponseHead = HTTPResponseHead(version: .init(major: versionMajor, minor: versionMinor),
566613
status: .init(statusCode: statusCode),
567614
headers: HTTPHeaders(self.headers,
@@ -570,7 +617,9 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
570617
}
571618
self.url = nil
572619
self.headers.removeAll(keepingCapacity: true)
573-
self.context!.fireChannelRead(message)
620+
if let message = message {
621+
self.context!.fireChannelRead(message)
622+
}
574623
self.isUpgrade = isUpgrade
575624
return true
576625
}
@@ -582,7 +631,9 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
582631
case .request:
583632
self.context!.fireChannelRead(NIOAny(HTTPServerRequestPart.end(trailers.map(HTTPHeaders.init))))
584633
case .response:
585-
self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.end(trailers.map(HTTPHeaders.init))))
634+
if !self.lastHeaderWasInformal {
635+
self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.end(trailers.map(HTTPHeaders.init))))
636+
}
586637
}
587638
self.stopParsing = self.isUpgrade!
588639
self.isUpgrade = nil

Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift

+14-4
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ class HTTPDecoderLengthTest: XCTestCase {
184184
responseStatus: HTTPResponseStatus,
185185
responseFramingField: FramingField) throws {
186186
XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPRequestEncoder()).wait())
187-
XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder())).wait())
187+
let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informalResponseStrategy: .forward)
188+
XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(decoder)).wait())
188189

189190
let handler = MessageEndHandler<HTTPResponseHead, ByteBuffer>()
190191
XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait())
@@ -214,9 +215,18 @@ class HTTPDecoderLengthTest: XCTestCase {
214215

215216
// We should have a response, no body, and immediately see EOF.
216217
XCTAssert(handler.seenHead)
217-
XCTAssertFalse(handler.seenBody)
218-
XCTAssert(handler.seenEnd)
219-
218+
switch responseStatus.code {
219+
case 100, 102..<200:
220+
// If an informal header is tested, we expect another "real" header to follow. For this
221+
// reason, we don't expect an `.end` here.
222+
XCTAssertFalse(handler.seenBody)
223+
XCTAssertFalse(handler.seenEnd)
224+
225+
default:
226+
XCTAssertFalse(handler.seenBody)
227+
XCTAssert(handler.seenEnd)
228+
}
229+
220230
XCTAssertTrue(try channel.finish().isClean)
221231
}
222232

Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ extension HTTPDecoderTest {
5454
("testAppropriateErrorWhenReceivingUnsolicitedResponse", testAppropriateErrorWhenReceivingUnsolicitedResponse),
5555
("testAppropriateErrorWhenReceivingUnsolicitedResponseDoesNotRecover", testAppropriateErrorWhenReceivingUnsolicitedResponseDoesNotRecover),
5656
("testOneRequestTwoResponses", testOneRequestTwoResponses),
57+
("testForwardContinueThanResponse", testForwardContinueThanResponse),
58+
("testDropContinueThanForwardResponse", testDropContinueThanForwardResponse),
5759
("testRefusesRequestSmugglingAttempt", testRefusesRequestSmugglingAttempt),
5860
("testTrimsTrailingOWS", testTrimsTrailingOWS),
5961
("testMassiveChunkDoesNotBufferAndGivesUsHoweverMuchIsAvailable", testMassiveChunkDoesNotBufferAndGivesUsHoweverMuchIsAvailable),

Tests/NIOHTTP1Tests/HTTPDecoderTest.swift

+52
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,58 @@ class HTTPDecoderTest: XCTestCase {
792792
XCTAssertEqual(["channelReadComplete", "write", "flush", "channelRead", "errorCaught"], eventCounter.allTriggeredEvents())
793793
XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean))
794794
}
795+
796+
func testForwardContinueThanResponse() {
797+
let eventCounter = EventCounterHandler()
798+
let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informalResponseStrategy: .forward))
799+
let channel = EmbeddedChannel(handler: responseDecoder)
800+
XCTAssertNoThrow(try channel.pipeline.addHandler(eventCounter).wait())
801+
802+
let requestHead: HTTPClientRequestPart = .head(.init(version: .http1_1, method: .POST, uri: "/"))
803+
XCTAssertNoThrow(try channel.writeOutbound(requestHead))
804+
var buffer = channel.allocator.buffer(capacity: 128)
805+
buffer.writeString("HTTP/1.1 100 continue\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n")
806+
XCTAssertNoThrow(try channel.writeInbound(buffer))
807+
808+
XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .continue)))
809+
XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"])))
810+
XCTAssertEqual(.end(nil), try channel.readInbound(as: HTTPClientResponsePart.self))
811+
XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self))
812+
XCTAssertNotNil(try channel.readOutbound())
813+
814+
XCTAssertEqual(1, eventCounter.writeCalls)
815+
XCTAssertEqual(1, eventCounter.flushCalls)
816+
XCTAssertEqual(3, eventCounter.channelReadCalls) // .head, .head & .end
817+
XCTAssertEqual(1, eventCounter.channelReadCompleteCalls)
818+
XCTAssertEqual(["channelReadComplete", "channelRead", "write", "flush"], eventCounter.allTriggeredEvents())
819+
XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean))
820+
}
821+
822+
func testDropContinueThanForwardResponse() {
823+
let eventCounter = EventCounterHandler()
824+
let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informalResponseStrategy: .drop))
825+
let channel = EmbeddedChannel(handler: responseDecoder)
826+
XCTAssertNoThrow(try channel.pipeline.addHandler(eventCounter).wait())
827+
828+
let requestHead: HTTPClientRequestPart = .head(.init(version: .http1_1, method: .POST, uri: "/"))
829+
XCTAssertNoThrow(try channel.writeOutbound(requestHead))
830+
var buffer = channel.allocator.buffer(capacity: 128)
831+
buffer.writeString("HTTP/1.1 100 continue\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n")
832+
XCTAssertNoThrow(try channel.writeInbound(buffer))
833+
834+
XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"])))
835+
XCTAssertEqual(.end(nil), try channel.readInbound(as: HTTPClientResponsePart.self))
836+
XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self))
837+
XCTAssertNotNil(try channel.readOutbound())
838+
839+
XCTAssertEqual(1, eventCounter.writeCalls)
840+
XCTAssertEqual(1, eventCounter.flushCalls)
841+
XCTAssertEqual(2, eventCounter.channelReadCalls) // .head & .end
842+
XCTAssertEqual(1, eventCounter.channelReadCompleteCalls)
843+
XCTAssertEqual(["channelReadComplete", "channelRead", "write", "flush"], eventCounter.allTriggeredEvents())
844+
XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean))
845+
}
846+
795847

796848
func testRefusesRequestSmugglingAttempt() throws {
797849
XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder())).wait())

0 commit comments

Comments
 (0)