Skip to content

Commit f66d7fc

Browse files
authored
netty: Fix ByteBuf leaks in tests (grpc#11593)
Part of grpc#3353
1 parent 7f9c1f3 commit f66d7fc

File tree

3 files changed

+39
-36
lines changed

3 files changed

+39
-36
lines changed

netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
217217
// Simulate receipt of initial remote settings.
218218
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
219219
channelRead(serializedSettings);
220+
channel().releaseOutbound();
220221
}
221222

222223
@Test
@@ -342,11 +343,12 @@ public void sendFrameShouldSucceed() throws Exception {
342343
createStream();
343344

344345
// Send a frame and verify that it was written.
346+
ByteBuf content = content();
345347
ChannelFuture future
346-
= enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true));
348+
= enqueue(new SendGrpcFrameCommand(streamTransportState, content, true));
347349

348350
assertTrue(future.isSuccess());
349-
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true),
351+
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(true),
350352
any(ChannelPromise.class));
351353
verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive
352354
verifyNoMoreInteractions(mockKeepAliveManager);

netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java

+24-17
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import io.grpc.internal.WritableBuffer;
3939
import io.netty.buffer.ByteBuf;
4040
import io.netty.buffer.ByteBufAllocator;
41-
import io.netty.buffer.ByteBufUtil;
4241
import io.netty.buffer.CompositeByteBuf;
4342
import io.netty.buffer.Unpooled;
4443
import io.netty.buffer.UnpooledByteBufAllocator;
@@ -68,6 +67,7 @@
6867
import java.nio.ByteBuffer;
6968
import java.util.concurrent.Delayed;
7069
import java.util.concurrent.TimeUnit;
70+
import org.junit.After;
7171
import org.junit.Assert;
7272
import org.junit.Test;
7373
import org.junit.runner.RunWith;
@@ -84,7 +84,6 @@
8484
public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
8585

8686
protected static final int STREAM_ID = 3;
87-
private ByteBuf content;
8887

8988
private EmbeddedChannel channel;
9089

@@ -106,18 +105,24 @@ protected void manualSetUp() throws Exception {}
106105
protected final TransportTracer transportTracer = new TransportTracer();
107106
protected int flowControlWindow = DEFAULT_WINDOW_SIZE;
108107
protected boolean autoFlowControl = false;
109-
110108
private final FakeClock fakeClock = new FakeClock();
111109

112110
FakeClock fakeClock() {
113111
return fakeClock;
114112
}
115113

114+
@After
115+
public void tearDown() throws Exception {
116+
if (channel() != null) {
117+
channel().releaseInbound();
118+
channel().releaseOutbound();
119+
}
120+
}
121+
116122
/**
117123
* Must be called by subclasses to initialize the handler and channel.
118124
*/
119125
protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
120-
content = Unpooled.copiedBuffer("hello world", UTF_8);
121126
frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
122127
frameReader = new DefaultHttp2FrameReader(headersDecoder);
123128

@@ -233,11 +238,11 @@ protected final Http2FrameReader frameReader() {
233238
}
234239

235240
protected final ByteBuf content() {
236-
return content;
241+
return Unpooled.copiedBuffer(contentAsArray());
237242
}
238243

239244
protected final byte[] contentAsArray() {
240-
return ByteBufUtil.getBytes(content());
245+
return "\000\000\000\000\rhello world".getBytes(UTF_8);
241246
}
242247

243248
protected final Http2FrameWriter verifyWrite() {
@@ -252,8 +257,8 @@ protected final void channelRead(Object obj) throws Exception {
252257
channel.writeInbound(obj);
253258
}
254259

255-
protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
256-
final ByteBuf compressionFrame = Unpooled.buffer(content.length);
260+
protected ByteBuf grpcFrame(byte[] message) {
261+
final ByteBuf compressionFrame = Unpooled.buffer(message.length);
257262
MessageFramer framer = new MessageFramer(
258263
new MessageFramer.Sink() {
259264
@Override
@@ -262,23 +267,22 @@ public void deliverFrame(
262267
if (frame != null) {
263268
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
264269
compressionFrame.writeBytes(bytebuf);
270+
bytebuf.release();
265271
}
266272
}
267273
},
268274
new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT),
269275
StatsTraceContext.NOOP);
270-
framer.writePayload(new ByteArrayInputStream(content));
271-
framer.flush();
272-
ChannelHandlerContext ctx = newMockContext();
273-
new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
274-
newPromise());
275-
return captureWrite(ctx);
276+
framer.writePayload(new ByteArrayInputStream(message));
277+
framer.close();
278+
return compressionFrame;
276279
}
277280

278-
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
279-
// Need to retain the content since the frameWriter releases it.
280-
content.retain();
281+
protected final ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
282+
return dataFrame(streamId, endStream, grpcFrame(content));
283+
}
281284

285+
protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
282286
ChannelHandlerContext ctx = newMockContext();
283287
new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
284288
return captureWrite(ctx);
@@ -410,6 +414,7 @@ public void dataSizeSincePingAccumulates() throws Exception {
410414
channelRead(dataFrame(3, false, buff.copy()));
411415

412416
assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
417+
buff.release();
413418
}
414419

415420
@Test
@@ -608,12 +613,14 @@ public void bdpPingWindowResizing() throws Exception {
608613

609614
private void readPingAck(long pingData) throws Exception {
610615
channelRead(pingFrame(true, pingData));
616+
channel().releaseOutbound();
611617
}
612618

613619
private void readXCopies(int copies, byte[] data) throws Exception {
614620
for (int i = 0; i < copies; i++) {
615621
channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it
616622
stream().request(1); // consume it
623+
channel().releaseOutbound();
617624
}
618625
}
619626

netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java

+11-17
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static org.mockito.ArgumentMatchers.anyString;
4444
import static org.mockito.ArgumentMatchers.eq;
4545
import static org.mockito.ArgumentMatchers.isA;
46+
import static org.mockito.ArgumentMatchers.same;
4647
import static org.mockito.Mockito.atLeastOnce;
4748
import static org.mockito.Mockito.doAnswer;
4849
import static org.mockito.Mockito.doThrow;
@@ -74,7 +75,6 @@
7475
import io.grpc.internal.testing.TestServerStreamTracer;
7576
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
7677
import io.netty.buffer.ByteBuf;
77-
import io.netty.buffer.ByteBufUtil;
7878
import io.netty.channel.ChannelFuture;
7979
import io.netty.channel.ChannelHandlerContext;
8080
import io.netty.channel.ChannelPromise;
@@ -120,23 +120,16 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
120120
public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10));
121121
@Rule
122122
public final MockitoRule mocks = MockitoJUnit.rule();
123-
124123
private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");
125-
126-
127124
@Mock
128125
private ServerStreamListener streamListener;
129-
130126
@Mock
131127
private ServerStreamTracer.Factory streamTracerFactory;
132-
133128
private final ServerTransportListener transportListener =
134129
mock(ServerTransportListener.class, delegatesTo(new ServerTransportListenerImpl()));
135130
private final TestServerStreamTracer streamTracer = new TestServerStreamTracer();
136-
137131
private NettyServerStream stream;
138132
private KeepAliveManager spyKeepAliveManager;
139-
140133
final Queue<InputStream> streamListenerMessageQueue = new LinkedList<>();
141134

142135
private int maxConcurrentStreams = Integer.MAX_VALUE;
@@ -208,6 +201,7 @@ protected void manualSetUp() throws Exception {
208201
// Simulate receipt of initial remote settings.
209202
ByteBuf serializedSettings = serializeSettings(new Http2Settings());
210203
channelRead(serializedSettings);
204+
channel().releaseOutbound();
211205
}
212206

213207
@Test
@@ -229,10 +223,11 @@ public void sendFrameShouldSucceed() throws Exception {
229223
createStream();
230224

231225
// Send a frame and verify that it was written.
226+
ByteBuf content = content();
232227
ChannelFuture future = enqueue(
233-
new SendGrpcFrameCommand(stream.transportState(), content(), false));
228+
new SendGrpcFrameCommand(stream.transportState(), content, false));
234229
assertTrue(future.isSuccess());
235-
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(false),
230+
verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), same(content), eq(0), eq(false),
236231
any(ChannelPromise.class));
237232
}
238233

@@ -267,10 +262,11 @@ private void inboundDataShouldForwardToStreamListener(boolean endStream) throws
267262
// Create a data frame and then trigger the handler to read it.
268263
ByteBuf frame = grpcDataFrame(STREAM_ID, endStream, contentAsArray());
269264
channelRead(frame);
265+
channel().releaseOutbound();
270266
verify(streamListener, atLeastOnce())
271267
.messagesAvailable(any(StreamListener.MessageProducer.class));
272268
InputStream message = streamListenerMessageQueue.poll();
273-
assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message));
269+
assertArrayEquals(contentAsArray(), ByteStreams.toByteArray(message));
274270
message.close();
275271
assertNull("no additional message expected", streamListenerMessageQueue.poll());
276272

@@ -870,7 +866,7 @@ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
870866
future.get();
871867
for (int i = 0; i < 10; i++) {
872868
future = enqueue(
873-
new SendGrpcFrameCommand(stream.transportState(), content().retainedSlice(), false));
869+
new SendGrpcFrameCommand(stream.transportState(), content(), false));
874870
future.get();
875871
channel().releaseOutbound();
876872
channelRead(pingFrame(false /* isAck */, 1L));
@@ -1293,6 +1289,7 @@ public void maxRstCount_withinLimit_succeeds() throws Exception {
12931289
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
12941290
manualSetUp();
12951291
rapidReset(maxRstCount);
1292+
12961293
assertTrue(channel().isOpen());
12971294
}
12981295

@@ -1302,6 +1299,7 @@ public void maxRstCount_exceedsLimit_fails() throws Exception {
13021299
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
13031300
manualSetUp();
13041301
assertThrows(ClosedChannelException.class, () -> rapidReset(maxRstCount + 1));
1302+
13051303
assertFalse(channel().isOpen());
13061304
}
13071305

@@ -1344,11 +1342,7 @@ private void createStream() throws Exception {
13441342

13451343
private ByteBuf emptyGrpcFrame(int streamId, boolean endStream) throws Exception {
13461344
ByteBuf buf = NettyTestUtil.messageFrame("");
1347-
try {
1348-
return dataFrame(streamId, endStream, buf);
1349-
} finally {
1350-
buf.release();
1351-
}
1345+
return dataFrame(streamId, endStream, buf);
13521346
}
13531347

13541348
@Override

0 commit comments

Comments
 (0)