Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.ktor.server.engine.*
import io.ktor.utils.io.*
import io.netty.channel.*
import io.netty.handler.codec.http.*
import java.util.concurrent.CancellationException
import kotlin.coroutines.*

public abstract class NettyApplicationResponse(
Expand Down Expand Up @@ -60,15 +61,37 @@ public abstract class NettyApplicationResponse(
responseMessage = message
responseReady.setSuccess()
responseMessageSent = true

awaitProcessingResponseIfInfoOrNoContent()
}

override suspend fun responseChannel(): ByteWriteChannel {
val channel = ByteChannel()
val chunked = headers[HttpHeaders.TransferEncoding] == "chunked"
sendResponse(chunked, content = channel)

awaitProcessingResponseIfInfoOrNoContent()

return channel
}

/**
* Await the [NettyApplicationCall.responseWriteJob] to complete if the response status is informational or 204.
* Netty discards certain headers of such responses, so we have to wait for Netty to finish that,
* in order to avoid the race condition.
*/
private suspend fun awaitProcessingResponseIfInfoOrNoContent() {
val status = status()

if (status != null) {
val infoOrNoContent = status == HttpStatusCode.NoContent || (status.value >= 100 && status.value < 200)

if (infoOrNoContent && call is NettyApplicationCall) {
(call as NettyApplicationCall).responseWriteJob.join()
}
}
}

override suspend fun respondNoContent(content: OutgoingContent.NoContent) {
respondFromBytes(EmptyByteArray)
}
Expand Down Expand Up @@ -122,7 +145,7 @@ public abstract class NettyApplicationResponse(
public fun cancel() {
if (!responseMessageSent) {
responseChannel = ByteReadChannel.Empty
responseReady.setFailure(java.util.concurrent.CancellationException("Response was cancelled"))
responseReady.setFailure(CancellationException("Response was cancelled"))
responseMessageSent = true
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,27 @@

package io.ktor.tests.server.netty

import io.ktor.client.HttpClient
import io.ktor.client.engine.cio.CIO
import io.ktor.client.plugins.DefaultRequest
import io.ktor.client.request.get
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.content.OutgoingContent
import io.ktor.server.application.*
import io.ktor.server.application.hooks.ResponseSent
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.response.respond
import io.ktor.server.response.respondBytesWriter
import io.ktor.server.routing.get
import io.ktor.server.routing.routing
import io.ktor.utils.io.ByteReadChannel
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import java.net.*
import java.util.concurrent.*
import kotlin.test.*
Expand Down Expand Up @@ -68,4 +86,87 @@ class NettySpecificTest {
assertTrue(server.engine.bootstraps.all { (it.config().group() as ExecutorService).isTerminated })
}
}

@Test
fun contentLengthAndTransferEncodingAreSafelyRemoved() = runTest {
val appStarted = CompletableDeferred<Application>()
val testScope = CoroutineScope(coroutineContext)
val earlyHints = HttpStatusCode(103, "Early Hints")

val serverJob = launch(Dispatchers.IO) {
val server = embeddedServer(Netty, port = 0) {
install(
createApplicationPlugin("CallLogging") {
on(ResponseSent) { call ->
testScope.launch {
val headers = call.response.headers.allValues()
assertNull(headers[HttpHeaders.ContentLength])
assertNull(headers[HttpHeaders.TransferEncoding])
}
}
},
)

routing {
get("/no-content") {
call.respond(HttpStatusCode.NoContent)
}

get("/no-content-channel-writer") {
call.respondBytesWriter(status = HttpStatusCode.NoContent) {}
}

get("/no-content-read-channel") {
call.respond(object : OutgoingContent.ReadChannelContent() {
override val status: HttpStatusCode = HttpStatusCode.NoContent
override fun readFrom(): ByteReadChannel = ByteReadChannel.Empty
})
}

get("/info") {
call.respond(earlyHints)
}

get("/info-channel-writer") {
call.respondBytesWriter(status = earlyHints) {}
}

get("/info-read-channel") {
call.respond(object : OutgoingContent.ReadChannelContent() {
override val status: HttpStatusCode = earlyHints
override fun readFrom(): ByteReadChannel = ByteReadChannel.Empty
})
}
}
}

server.monitor.subscribe(ApplicationStarted) { app ->
appStarted.complete(app)
}

server.start(wait = true)
}

try {
val serverApp = appStarted.await()
val connector = serverApp.engine.resolvedConnectors()[0]
val host = connector.host
val port = connector.port

HttpClient(CIO) {
install(DefaultRequest) {
url("http://$host:$port/")
}
}.use { client ->
assertEquals(HttpStatusCode.NoContent, client.get("/no-content").status)
assertEquals(HttpStatusCode.NoContent, client.get("/no-content-channel-writer").status)
assertEquals(HttpStatusCode.NoContent, client.get("/no-content-read-channel").status)
assertEquals(earlyHints, client.get("/info").status)
assertEquals(earlyHints, client.get("/info-channel-writer").status)
assertEquals(earlyHints, client.get("/info-read-channel").status)
}
} finally {
serverJob.cancel()
}
}
}