Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
build/
logs/
.idea/
*/out/*
*/out/*
*.iml
*.ipr
*.iws
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ libs += [
httpclient: "org.apache.httpcomponents:httpclient:4.3.1",
jacksonCodec: "com.fasterxml.jackson.core:jackson-core:2.5.3",
log4j: "log4j:log4j:1.2.17",
netty: "io.netty:netty-all:4.0.27.Final",
netty: "io.netty:netty-all:4.1.30.Final",
testng: "org.testng:testng:6.8.8",
restliServer: "com.linkedin.pegasus:restli-server:6.0.12",
restliNettyStandalone: "com.linkedin.pegasus:restli-netty-standalone:6.0.12"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.flush.FlushConsolidationHandler;
import io.netty.handler.timeout.IdleStateHandler;


Expand Down Expand Up @@ -42,5 +43,6 @@ public void initChannel(SocketChannel socketChannel) {
new ClientChannelHandler(channelMediator, _proxyServer.getConnectionFlowRegistry());

channelPipeline.addLast("handler", clientChannelHandler);
channelPipeline.addFirst(new FlushConsolidationHandler());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.flush.FlushConsolidationHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.net.ssl.SSLEngine;
import org.apache.log4j.Logger;

Expand All @@ -48,6 +54,7 @@ public class ChannelMediator {
private final Channel _clientChannel;
private final ChannelGroup _allChannelGroup;
private Channel _serverChannel;
private final Map<Channel, List<HttpObject>> _buffers;

public ChannelMediator(Channel clientChannel, final ProxyModeControllerFactory proxyModeControllerFactory,
final NioEventLoopGroup upstreamWorkerGroup, final int timeout, final ChannelGroup channelGroup) {
Expand All @@ -56,6 +63,7 @@ public ChannelMediator(Channel clientChannel, final ProxyModeControllerFactory p
_upstreamWorkerGroup = upstreamWorkerGroup;
_serverConnectionIdleTimeoutMsec = timeout;
_allChannelGroup = channelGroup;
_buffers = new HashMap<>();
}

public void initializeProxyModeController(HttpRequest initialRequest) {
Expand Down Expand Up @@ -137,6 +145,7 @@ public ChannelFuture connectToServer(final InetSocketAddress remoteAddress) {
protected void initChannel(Channel ch)
throws Exception {
initChannelPipeline(ch.pipeline(), serverChannelHandler, _serverConnectionIdleTimeoutMsec);
ch.pipeline().addFirst(new FlushConsolidationHandler());
_serverChannel = ch;
}
});
Expand Down Expand Up @@ -176,20 +185,36 @@ public Future<Channel> handshakeWithClient(SSLEngine sslEngine) {
return handshake(sslEngine, false, _clientChannel);
}

public ChannelFuture resumeReadingFromClientChannel() {
if (_clientChannel == null) {
public ChannelFuture changeReadingFromClientChannel(boolean isRead) {
return changeReadingFromChannel(_clientChannel, isRead);
}

public ChannelFuture changeReadingFromServerChannel(boolean isRead) {
return changeReadingFromChannel(_serverChannel, isRead);
}

private ChannelFuture changeReadingFromChannel(Channel channel, boolean isRead) {
if (channel == null) {
throw new IllegalStateException("Channel can't be null");
}
_clientChannel.config().setAutoRead(true);
return _clientChannel.newSucceededFuture();
channel.config().setAutoRead(isRead);
return channel.newSucceededFuture();
}

public ChannelFuture stopReadingFromClientChannel() {
if (_clientChannel == null) {
throw new IllegalStateException("Channel can't be null");
public void writeAllToServerIfPossible() {
writeBufferToChannelIfPossible(_serverChannel);
}

public void writeAllToClientIfPossible() {
writeBufferToChannelIfPossible(_clientChannel);
}

private void writeBufferToChannelIfPossible(Channel channel) {
List<HttpObject> channelBuffer = _buffers.getOrDefault(channel, new LinkedList<>());
while (channel.isWritable() && !channelBuffer.isEmpty()) {
HttpObject payload = channelBuffer.remove(0);
writeToChannel(channel, payload);
}
_clientChannel.config().setAutoRead(false);
return _clientChannel.newSucceededFuture();
}

/**
Expand Down Expand Up @@ -230,16 +255,24 @@ private void initChannelPipeline(ChannelPipeline pipeline, ServerChannelHandler
* */
private ChannelFuture writeToChannel(final Channel channel, final Object object) {
if (channel == null) {
throw new IllegalStateException("Failed to write to channel because channel is null");
ReferenceCountUtil.safeRelease(object);
new IllegalStateException("Failed to write to channel because channel is null");
}
if (object instanceof ReferenceCounted) {
LOG.debug("Retaining reference counted message");
((ReferenceCounted) object).retain();
if (!channel.isActive()) {
ReferenceCountUtil.safeRelease(object);
return channel.newFailedFuture(new IllegalStateException("Failed to write to channel because it's closed"));
}
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Writing in channel [%s]: %s", channel.toString(), object));
ReferenceCountUtil.retain(object);

LOG.debug(String.format("Writing in channel [%s]: %s", channel.toString(), object));

if (channel.isWritable()) {
return channel.writeAndFlush(object);
} else {
List<HttpObject> buffer = _buffers.getOrDefault(channel, new LinkedList<>());
buffer.add((HttpObject) object);
return channel.newSucceededFuture();
}
return channel.writeAndFlush(object);
}

private Future<Void> disconnect(final Channel channel) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ protected void channelRead0(ChannelHandlerContext channelHandlerContext, HttpObj
_channelHandlerDelegate.onRead(httpObject);
}

@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
_channelMediator.writeAllToClientIfPossible();
_channelMediator.changeReadingFromServerChannel(ctx.channel().isWritable());
super.channelWritabilityChanged(ctx);
}

@Override
public void channelRegistered(ChannelHandlerContext ctx)
throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.netty.handler.codec.http.HttpObject;
import org.apache.log4j.Logger;


/**
* Server channel handler that implemented read logic from server side.
* Note: It's stateful. Each {@link com.linkedin.mitm.proxy.channel.ClientChannelHandler} map to one
Expand All @@ -30,7 +31,7 @@ public ServerChannelHandler(ChannelMediator channelMediator) {
}

@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, HttpObject httpObject)
protected void channelRead0(ChannelHandlerContext ctx, HttpObject httpObject)
throws Exception {
_channelMediator.readFromServerChannel(httpObject);
if (httpObject instanceof DefaultLastHttpContent) {
Expand All @@ -48,6 +49,13 @@ public void channelRegistered(ChannelHandlerContext ctx)
super.channelRegistered(ctx);
}

@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
_channelMediator.writeAllToServerIfPossible();
_channelMediator.changeReadingFromClientChannel(ctx.channel().isWritable());
super.channelWritabilityChanged(ctx);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ public class ResumeReadingFromClient implements ConnectionFlowStep {
@Override
public Future execute(ChannelMediator channelMediator, InetSocketAddress remoteAddress) {
LOG.debug("Resume reading from client");
return channelMediator.resumeReadingFromClientChannel();
return channelMediator.changeReadingFromClientChannel(true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ public class StopReadingFromClient implements ConnectionFlowStep {
@Override
public Future execute(ChannelMediator channelMediator, InetSocketAddress remoteAddress) {
LOG.info("Stop reading from client");
return channelMediator.stopReadingFromClientChannel();
return channelMediator.changeReadingFromClientChannel(false);
}
}