Skip to content

Commit

Permalink
Very basic header validator
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
peternied committed Sep 29, 2023
1 parent 104c512 commit fbacd86
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 394 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,26 @@
import org.opensearch.test.OpenSearchIntegTestCase.Scope;

import java.util.Collection;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.IntStream;

import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasSize;

import io.netty.handler.codec.http2.HttpConversionUtil;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;

@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class Netty4Http2IT extends OpenSearchNetty4IntegTestCase {

Expand Down Expand Up @@ -56,6 +64,30 @@ public void testThatNettyHttpServerSupportsHttp2GetUpgrades() throws Exception {
}
}


public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
TransportAddress transportAddress = randomFrom(boundAddresses);

final FullHttpRequest blockedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
blockedRequest.headers().add("blockme", "Not Allowed");
blockedRequest.headers().add(HOST, "localhost");
blockedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http");


final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2() ) {
try {
FullHttpResponse blockedResponse = nettyHttpClient.send(transportAddress.address(), blockedRequest);
responses.add(blockedResponse);
assertThat(blockedResponse.status().code(), equalTo(401));
} finally {
responses.forEach(ReferenceCounted::release);
}
}
}

public void testThatNettyHttpServerSupportsHttp2PostUpgrades() throws Exception {
final List<Tuple<String, CharSequence>> requests = List.of(Tuple.tuple("/_search", "{\"query\":{ \"match_all\":{}}}"));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.http.netty4;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;

@ChannelHandler.Sharable
public class Netty4HeaderVerifier extends ChannelInboundHandlerAdapter {

final static Logger log = LogManager.getLogger(Netty4HeaderVerifier.class);

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof HttpRequest)) {
ctx.fireChannelRead(msg);
}

HttpRequest request = (HttpRequest) msg;
if (!isAuthenticated(request)) {
final FullHttpResponse response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.UNAUTHORIZED);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
ReferenceCountUtil.release(msg);
} else {
// Lets the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private boolean isAuthenticated(HttpRequest request) {
log.info("Checking if request is authenticated:\n" + request);

final boolean shouldBlock = request.headers().contains("blockme");

return !shouldBlock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.http.HttpPipelinedRequest;
import org.opensearch.rest.RestHandlerContext;
import org.opensearch.rest.RestResponse;

import io.netty.channel.ChannelHandler;
Expand All @@ -54,12 +53,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest httpRequest) {
final Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
final RestResponse earlyResponse = ctx.channel().attr(Netty4HttpServerTransport.EARLY_RESPONSE).get();
final ThreadContext.StoredContext contextToRestore = ctx.channel().attr(Netty4HttpServerTransport.CONTEXT_TO_RESTORE).get();
final RestHandlerContext requestContext = new RestHandlerContext(earlyResponse, contextToRestore);
boolean success = false;
try {
serverTransport.incomingRequest(httpRequest, channel, requestContext);
serverTransport.incomingRequest(httpRequest, channel);
success = true;
} finally {
if (success == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.common.util.net.NetUtils;
import org.opensearch.core.common.unit.ByteSizeUnit;
Expand All @@ -53,7 +52,6 @@
import org.opensearch.http.HttpHandlingSettings;
import org.opensearch.http.HttpReadTimeoutException;
import org.opensearch.http.HttpServerChannel;
import org.opensearch.rest.RestResponse;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.NettyAllocator;
Expand Down Expand Up @@ -336,16 +334,11 @@ public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings);
}

public static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel");
protected static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("opensearch-http-channel");
protected static final AttributeKey<Netty4HttpServerChannel> HTTP_SERVER_CHANNEL_KEY = AttributeKey.newInstance(
"opensearch-http-server-channel"
);

public static final AttributeKey<RestResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);

protected static class HttpChannelHandler extends ChannelInitializer<Channel> {

private final Netty4HttpServerTransport transport;
Expand Down Expand Up @@ -427,10 +420,14 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E
final ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), "handler", getRequestHandler());
pipeline.replace(this, "header_verifier", transport.createHeaderVerifier());
pipeline.addAfter("header_verifier", "decompress", transport.createDecompressor());
pipeline.addAfter("decompress", "aggregator", aggregator);
pipeline.addAfter("header_verifier", "decoder_compress", new HttpContentDecompressor());
pipeline.addAfter("decoder_compress", "aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addAfter("aggregator", "compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
pipeline.addAfter(
"aggregator",
"encoder_compress",
new HttpContentCompressor(handlingSettings.getCompressionLevel())
);
}
pipeline.addBefore("handler", "request_creator", requestCreator);
pipeline.addBefore("handler", "response_creator", responseCreator);
Expand All @@ -450,13 +447,13 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) {
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
pipeline.addLast("decoder", decoder);
pipeline.addLast("header_verifier", transport.createHeaderVerifier());
pipeline.addLast("decompress", transport.createDecompressor());
pipeline.addLast("decoder_compress", new HttpContentDecompressor());
pipeline.addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
pipeline.addLast("aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
pipeline.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}
pipeline.addLast("request_creator", requestCreator);
pipeline.addLast("response_creator", responseCreator);
Expand Down Expand Up @@ -491,16 +488,18 @@ protected void initChannel(Channel childChannel) throws Exception {

final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);

childChannel.pipeline()
.addLast(new LoggingHandler(LogLevel.DEBUG))
.addLast(new Http2StreamFrameToHttpObjectCodec(true))
.addLast("byte_buf_sizer", byteBufSizer)
.addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS))
.addLast("header_verifier", transport.createHeaderVerifier())
.addLast("decompress", transport.createDecompressor());
.addLast("decoder_decompress", new HttpContentDecompressor());

if (handlingSettings.isCompression()) {
childChannel.pipeline().addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
childChannel.pipeline()
.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}

childChannel.pipeline()
Expand Down Expand Up @@ -535,12 +534,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}

protected HttpContentDecompressor createDecompressor() {
return new HttpContentDecompressor();
}

protected ChannelInboundHandlerAdapter createHeaderVerifier() {
return new Netty4HeaderVerifier();
// pass-through
return new ChannelInboundHandlerAdapter();
// return new ChannelInboundHandlerAdapter();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.opensearch.nio.SocketChannelContext;
import org.opensearch.nio.TaskScheduler;
import org.opensearch.nio.WriteOperation;
import org.opensearch.rest.RestHandlerContext;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -173,7 +172,7 @@ private void handleRequest(Object msg) {
final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg;
boolean success = false;
try {
transport.incomingRequest(pipelinedRequest, nioHttpChannel, RestHandlerContext.EMPTY);
transport.incomingRequest(pipelinedRequest, nioHttpChannel);
success = true;
} finally {
if (success == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.opensearch.nio.InboundChannelBuffer;
import org.opensearch.nio.SocketChannelContext;
import org.opensearch.nio.TaskScheduler;
import org.opensearch.rest.RestHandlerContext;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.Before;
Expand Down Expand Up @@ -102,7 +101,7 @@ public void setMocks() {
doAnswer(invocation -> {
((HttpRequest) invocation.getArguments()[0]).releaseAndCopy();
return null;
}).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class), any(RestHandlerContext.class));
}).when(transport).incomingRequest(any(HttpRequest.class), any(HttpChannel.class));
Settings settings = Settings.builder().put(SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(1024)).build();
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
channel = mock(NioHttpChannel.class);
Expand All @@ -123,12 +122,12 @@ public void testSuccessfulDecodeHttpRequest() throws IOException {
try {
handler.consumeReads(toChannelBuffer(slicedBuf));

verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));

handler.consumeReads(toChannelBuffer(slicedBuf2));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
Expand All @@ -154,7 +153,7 @@ public void testDecodeHttpRequestError() throws IOException {
handler.consumeReads(toChannelBuffer(buf));

ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class), any(RestHandlerContext.class));
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));

assertNotNull(requestCaptor.getValue().getInboundException());
assertTrue(requestCaptor.getValue().getInboundException() instanceof IllegalArgumentException);
Expand All @@ -175,7 +174,7 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t
} finally {
buf.release();
}
verify(transport, times(0)).incomingRequest(any(), any(), any(RestHandlerContext.class));
verify(transport, times(0)).incomingRequest(any(), any());

List<FlushOperation> flushOperations = handler.pollFlushOperations();
assertFalse(flushOperations.isEmpty());
Expand Down Expand Up @@ -281,7 +280,7 @@ private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOEx
}

ArgumentCaptor<HttpPipelinedRequest> requestCaptor = ArgumentCaptor.forClass(HttpPipelinedRequest.class);
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class), any(RestHandlerContext.class));
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));

HttpRequest httpRequest = requestCaptor.getValue();
assertNotNull(httpRequest);
Expand Down
Loading

0 comments on commit fbacd86

Please sign in to comment.