Skip to content

Prevent ThreadContext header leak when sending response #68649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/68649.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 68649
summary: Prevent `ThreadContext` header leak when sending response
area: Infra/Core
type: bug
issues:
- 68278
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.transport.Transports;

@ChannelHandler.Sharable
class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest> {
Expand All @@ -26,6 +27,8 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined

@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest httpRequest) {
assert Transports.assertDefaultThreadContext(serverTransport.getThreadPool().getThreadContext());
assert Transports.assertTransportThread();
final Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
boolean success = false;
try {
Expand All @@ -41,6 +44,8 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest http
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ExceptionsHelper.maybeDieOnAnotherThread(cause);
assert Transports.assertDefaultThreadContext(serverTransport.getThreadPool().getThreadContext());

Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
if (cause instanceof Error) {
serverTransport.onException(channel, new Exception(cause));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ protected HttpChannelHandler(final Netty4HttpServerTransport transport, final Ht
protected void initChannel(Channel ch) throws Exception {
Netty4HttpChannel nettyHttpChannel = new Netty4HttpChannel(ch);
ch.attr(HTTP_CHANNEL_KEY).set(nettyHttpChannel);
ch.pipeline().addLast("chunked_writer", new Netty4WriteThrottlingHandler());
ch.pipeline().addLast("chunked_writer", new Netty4WriteThrottlingHandler(transport.getThreadPool().getThreadContext()));
ch.pipeline().addLast("byte_buf_sizer", NettyByteBufSizer.INSTANCE);
ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS));
final HttpRequestDecoder decoder = new HttpRequestDecoder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ private void setupPipeline(Channel ch) {
ch.pipeline()
.addLast("byte_buf_sizer", NettyByteBufSizer.INSTANCE)
.addLast("logging", ESLoggingHandler.INSTANCE)
.addLast("chunked_writer", new Netty4WriteThrottlingHandler())
.addLast("chunked_writer", new Netty4WriteThrottlingHandler(getThreadPool().getThreadContext()))
.addLast("dispatcher", new Netty4MessageInboundHandler(this, recycler));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;

import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.transport.Transports;

import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Queue;
Expand All @@ -28,13 +31,18 @@ public final class Netty4WriteThrottlingHandler extends ChannelDuplexHandler {

private final Queue<WriteOperation> queuedWrites = new ArrayDeque<>();

private final ThreadContext threadContext;
private WriteOperation currentWrite;

public Netty4WriteThrottlingHandler() {}
public Netty4WriteThrottlingHandler(ThreadContext threadContext) {
this.threadContext = threadContext;
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
assert msg instanceof ByteBuf;
assert Transports.assertDefaultThreadContext(threadContext);
assert Transports.assertTransportThread();
final boolean queued = queuedWrites.offer(new WriteOperation((ByteBuf) msg, promise));
assert queued;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,8 @@ private static ActionListener<Void> earlyResponseListener(HttpRequest request, H
return ActionListener.noop();
}
}

public ThreadPool getThreadPool() {
return threadPool;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ public void sendResponse(RestResponse restResponse) {
addCookies(httpResponse);

ActionListener<Void> listener = ActionListener.wrap(() -> Releasables.close(toClose));
httpChannel.sendResponse(httpResponse, listener);
try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
httpChannel.sendResponse(httpResponse, listener);
}
success = true;
} finally {
if (success == false) {
Expand Down