Skip to content

Commit 58d4b7c

Browse files
committed
review commits
1 parent a8c5ad0 commit 58d4b7c

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
import io.netty.channel.ChannelOutboundHandlerAdapter;
3232
import io.netty.channel.ChannelPromise;
3333
import io.netty.channel.FileRegion;
34-
import io.netty.handler.codec.ByteToMessageDecoder;
35-
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
3634
import io.netty.handler.codec.MessageToMessageDecoder;
3735
import io.netty.util.AbstractReferenceCounted;
3836

3937
import org.apache.spark.network.util.ByteArrayWritableChannel;
38+
import org.apache.spark.network.util.TransportFrameDecoder;
39+
4040
/**
4141
* Provides SASL-based encription for transport channels. The single method exposed by this
4242
* class installs the needed channel handlers on a connected channel.
@@ -61,21 +61,12 @@ static void addToChannel(
6161
channel.pipeline()
6262
.addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
6363
.addFirst("saslDecryption", new DecryptionHandler(backend))
64-
// Each frame does not exceed 8 + maxOutboundBlockSize bytes
6564
.addFirst("saslFrameDecoder", createFrameDecoder());
6665
}
6766

68-
/**
69-
* Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
70-
* This is used before all decoders.
71-
*/
72-
static ByteToMessageDecoder createFrameDecoder() {
73-
// maxFrameLength = 2G
74-
// lengthFieldOffset = 0
75-
// lengthFieldLength = 8
76-
// lengthAdjustment = -8, i.e. exclude the 8 byte length itself
77-
// initialBytesToStrip = 8, i.e. strip out the length field itself
78-
return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
67+
// Each frame does not exceed 8 + maxOutboundBlockSize bytes
68+
private static TransportFrameDecoder createFrameDecoder() {
69+
return new TransportFrameDecoder(false);
7970
}
8071

8172
private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {

common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,25 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
4646

4747
public static final String HANDLER_NAME = "frameDecoder";
4848
private static final int LENGTH_SIZE = 8;
49+
private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
4950
private static final int UNKNOWN_FRAME_SIZE = -1;
5051

5152
private final LinkedList<ByteBuf> buffers = new LinkedList<>();
5253
private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
54+
private final boolean isSupportLargeData;
5355

5456
private long totalSize = 0;
5557
private long nextFrameSize = UNKNOWN_FRAME_SIZE;
5658
private volatile Interceptor interceptor;
5759

60+
public TransportFrameDecoder() {
61+
this(true);
62+
}
63+
64+
public TransportFrameDecoder(boolean isSupportLargeData) {
65+
this.isSupportLargeData = isSupportLargeData;
66+
}
67+
5868
@Override
5969
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
6070
ByteBuf in = (ByteBuf) data;
@@ -77,7 +87,13 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception
7787
totalSize -= read;
7888
} else {
7989
// Interceptor is not active, so try to decode one frame.
80-
LinkedList<ByteBuf> frame = decodeNext();
90+
Object frame ;
91+
if (isSupportLargeData) {
92+
frame = decodeList();
93+
} else {
94+
frame = decodeByteBuf();
95+
}
96+
8197
if (frame == null) {
8298
break;
8399
}
@@ -120,7 +136,36 @@ private long decodeFrameSize() {
120136
return nextFrameSize;
121137
}
122138

123-
private LinkedList<ByteBuf> decodeNext() throws Exception {
139+
private ByteBuf decodeByteBuf() throws Exception {
140+
long frameSize = decodeFrameSize();
141+
if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
142+
return null;
143+
}
144+
145+
// Reset size for next frame.
146+
nextFrameSize = UNKNOWN_FRAME_SIZE;
147+
148+
Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
149+
Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
150+
151+
// If the first buffer holds the entire frame, return it.
152+
int remaining = (int) frameSize;
153+
if (buffers.getFirst().readableBytes() >= remaining) {
154+
return nextBufferForFrame(remaining);
155+
}
156+
157+
// Otherwise, create a composite buffer.
158+
CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
159+
while (remaining > 0) {
160+
ByteBuf next = nextBufferForFrame(remaining);
161+
remaining -= next.readableBytes();
162+
frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
163+
}
164+
assert remaining == 0;
165+
return frame;
166+
}
167+
168+
private LinkedList<ByteBuf> decodeList() throws Exception {
124169
long frameSize = decodeFrameSize();
125170
if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
126171
return null;

0 commit comments

Comments
 (0)