Skip to content

Commit

Permalink
Fix Inprocess memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
shivaspeaks committed Sep 10, 2024
1 parent 82ddea9 commit 85e5acb
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import io.grpc.Status;
import io.grpc.internal.testing.TestClientStreamTracer;
import io.grpc.internal.testing.TestServerStreamTracer;
import io.grpc.internal.testing.TestStreamTracer;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -161,9 +160,9 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {}
* tests in an indeterminate state.
*/
protected InternalServer server;
private ServerTransport serverTransport;
private ManagedClientTransport client;
private MethodDescriptor<String, String> methodDescriptor =
protected ServerTransport serverTransport;
protected ManagedClientTransport client;
protected MethodDescriptor<String, String> methodDescriptor =
MethodDescriptor.<String, String>newBuilder()
.setType(MethodDescriptor.MethodType.UNKNOWN)
.setFullMethodName("service/method")
Expand All @@ -180,20 +179,20 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {}
"tracer-key", Metadata.ASCII_STRING_MARSHALLER);
private final String tracerKeyValue = "tracer-key-value";

private ManagedClientTransport.Listener mockClientTransportListener
protected ManagedClientTransport.Listener mockClientTransportListener
= mock(ManagedClientTransport.Listener.class);
private MockServerListener serverListener = new MockServerListener();
protected MockServerListener serverListener = new MockServerListener();
private ArgumentCaptor<Throwable> throwableCaptor = ArgumentCaptor.forClass(Throwable.class);
private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer();
protected final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer();
private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer();
private final ClientStreamTracer[] tracers = new ClientStreamTracer[] {
protected final ClientStreamTracer[] tracers = new ClientStreamTracer[] {
clientStreamTracer1, clientStreamTracer2
};
private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] {
new ClientStreamTracer() {}
};

private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer();
protected final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer();
private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer();
private final ServerStreamTracer.Factory serverStreamTracerFactory = mock(
ServerStreamTracer.Factory.class,
Expand Down Expand Up @@ -250,15 +249,6 @@ protected long fakeCurrentTimeNanos() {
throw new UnsupportedOperationException();
}

/**
* Specific test for {@link InProcessTransport} to verify assumedMessageSize.
* For more goto: <a href="#11406">link</a>
*/
protected void assertInProcessTransportAssumedMessageSize(
TestStreamTracer streamTracerSender, TestStreamTracer streamTracerReceiver) {
// implemented by SizesReportedInProcessTransportTest
}

// TODO(ejona):
// multiple streams on same transport
// multiple client transports to same server
Expand Down Expand Up @@ -877,7 +867,6 @@ public void basicStream() throws Exception {
assertThat(serverStreamTracer1.nextInboundEvent())
.matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)");

assertInProcessTransportAssumedMessageSize(clientStreamTracer1, serverStreamTracer1);
Metadata serverHeaders = new Metadata();
serverHeaders.put(asciiKey, "server");
serverHeaders.put(asciiKey, "dupvalue");
Expand Down Expand Up @@ -914,7 +903,6 @@ public void basicStream() throws Exception {
.matches("inboundMessageRead\\(0, -?[0-9]+, -?[0-9]+\\)");
assertThat(clientStreamTracer1.getInboundWireSize()).isGreaterThan(0L);
assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L);
assertInProcessTransportAssumedMessageSize(serverStreamTracer1, clientStreamTracer1);

message.close();
assertNull("no additional message expected", clientStreamListener.messageQueue.poll());
Expand Down Expand Up @@ -1278,7 +1266,6 @@ public void onReady() {
assertThat(clientStreamTracer1.getInboundUncompressedSize()).isGreaterThan(0L);
assertThat(serverStreamTracer1.getOutboundWireSize()).isGreaterThan(0L);
assertThat(serverStreamTracer1.getOutboundUncompressedSize()).isGreaterThan(0L);
assertInProcessTransportAssumedMessageSize(serverStreamTracer1, clientStreamTracer1);
assertNull(clientStreamTracer1.getInboundTrailers());
assertSame(status, clientStreamTracer1.getStatus());
// There is a race between client cancelling and server closing. The final status seen by the
Expand Down Expand Up @@ -2167,7 +2154,7 @@ private static void runIfNotNull(Runnable runnable) {
}
}

private static void startTransport(
protected static void startTransport(
ManagedClientTransport clientTransport,
ManagedClientTransport.Listener listener) {
runIfNotNull(clientTransport.start(listener));
Expand All @@ -2185,7 +2172,7 @@ public void streamCreated(Attributes transportAttrs, Metadata metadata) {
}
}

private static class MockServerListener implements ServerListener {
public static class MockServerListener implements ServerListener {
public final BlockingQueue<MockServerTransportListener> listeners
= new LinkedBlockingQueue<>();
private final SettableFuture<?> shutdown = SettableFuture.create();
Expand Down Expand Up @@ -2216,7 +2203,7 @@ public MockServerTransportListener takeListenerOrFail(long timeout, TimeUnit uni
}
}

private static class MockServerTransportListener implements ServerTransportListener {
public static class MockServerTransportListener implements ServerTransportListener {
public final ServerTransport transport;
public final BlockingQueue<StreamCreation> streams = new LinkedBlockingQueue<>();
private final SettableFuture<?> terminated = SettableFuture.create();
Expand Down Expand Up @@ -2264,8 +2251,8 @@ public StreamCreation takeStreamOrFail(long timeout, TimeUnit unit)
}
}

private static class ServerStreamListenerBase implements ServerStreamListener {
private final BlockingQueue<InputStream> messageQueue = new LinkedBlockingQueue<>();
public static class ServerStreamListenerBase implements ServerStreamListener {
public final BlockingQueue<InputStream> messageQueue = new LinkedBlockingQueue<>();
// Would have used Void instead of Object, but null elements are not allowed
private final BlockingQueue<Object> readyQueue = new LinkedBlockingQueue<>();
private final CountDownLatch halfClosedLatch = new CountDownLatch(1);
Expand Down Expand Up @@ -2324,8 +2311,8 @@ public void closed(Status status) {
}
}

private static class ClientStreamListenerBase implements ClientStreamListener {
private final BlockingQueue<InputStream> messageQueue = new LinkedBlockingQueue<>();
public static class ClientStreamListenerBase implements ClientStreamListener {
public final BlockingQueue<InputStream> messageQueue = new LinkedBlockingQueue<>();
// Would have used Void instead of Object, but null elements are not allowed
private final BlockingQueue<Object> readyQueue = new LinkedBlockingQueue<>();
private final SettableFuture<Metadata> headers = SettableFuture.create();
Expand Down Expand Up @@ -2382,7 +2369,7 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) {
}
}

private static class StreamCreation {
public static class StreamCreation {
public final ServerStream stream;
public final String method;
public final Metadata headers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.inprocess;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import io.grpc.CallOptions;
Expand All @@ -34,15 +35,20 @@
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.AbstractTransportTest;
import io.grpc.internal.ClientStream;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.InternalServer;
import io.grpc.internal.ManagedClientTransport;
import io.grpc.internal.ServerStream;
import io.grpc.internal.testing.TestStreamTracer;
import io.grpc.stub.ClientCalls;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.TestMethodDescriptors;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -55,23 +61,26 @@ public class InProcessTransportTest extends AbstractTransportTest {
private static final String TRANSPORT_NAME = "perfect-for-testing";
private static final String AUTHORITY = "a-testing-authority";
protected static final String USER_AGENT = "a-testing-user-agent";
private static final int TIMEOUT_MS = 5000;
private static final long TEST_MESSAGE_LENGTH = 100;

@Rule
public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();

@Override
protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) {
InProcessServerBuilder builder = InProcessServerBuilder
.forName(TRANSPORT_NAME)
.maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE);
return new InProcessServer(builder, streamTracerFactories);
int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
return newServer(streamTracerFactories);
}

@Override
protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
return newServer(streamTracerFactories);
List<ServerStreamTracer.Factory> streamTracerFactories) {
InProcessServerBuilder builder = InProcessServerBuilder
.forName(TRANSPORT_NAME)
.maxInboundMetadataSize(GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE)
.assumedMessageSize(TEST_MESSAGE_LENGTH);
return new InProcessServer(builder, streamTracerFactories);
}

@Override
Expand All @@ -86,6 +95,12 @@ protected ManagedClientTransport newClientTransport(InternalServer server) {
testAuthority(server), USER_AGENT, eagAttrs(), false);
}

private ManagedClientTransport newClientTransportWithAssumedMessageSize(InternalServer server) {
return new InProcessTransport(
new InProcessSocketAddress(TRANSPORT_NAME), GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE,
testAuthority(server), USER_AGENT, eagAttrs(), false, TEST_MESSAGE_LENGTH);
}

@Test
@Ignore
@Override
Expand Down Expand Up @@ -163,11 +178,62 @@ public Listener<Void> startCall(ServerCall<Void, Void> call, Metadata headers) {
.build();
ClientCall<Void,Void> call = channel.newCall(nonMatchMethod, CallOptions.DEFAULT);
try {
ClientCalls.futureUnaryCall(call, null).get(5, TimeUnit.SECONDS);
ClientCalls.futureUnaryCall(call, null).get(TIMEOUT_MS, TimeUnit.MILLISECONDS);
fail("Call should fail.");
} catch (ExecutionException ex) {
StatusRuntimeException s = (StatusRuntimeException)ex.getCause();
assertEquals(Code.UNIMPLEMENTED, s.getStatus().getCode());
}
}

@Ignore
@Test
public void basicStreamInProcess() throws Exception {
server.start(serverListener);
client = newClientTransportWithAssumedMessageSize(server);
startTransport(client, mockClientTransportListener);
MockServerTransportListener serverTransportListener
= serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS);
serverTransport = serverTransportListener.transport;
// Set up client stream
ClientStream clientStream = client.newStream(
methodDescriptor, new Metadata(), CallOptions.DEFAULT, tracers);
ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase();
clientStream.start(clientStreamListener);
StreamCreation serverStreamCreation
= serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS);
ServerStream serverStream = serverStreamCreation.stream;
ServerStreamListenerBase serverStreamListener = serverStreamCreation.listener;
serverStream.request(1);
assertTrue(clientStream.isReady());
// Send message from client to server
clientStream.writeMessage(methodDescriptor.streamRequest("Hello from client"));
clientStream.flush();
// Verify server received the message and check its size
InputStream message =
serverStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS);
assertEquals("Hello from client", methodDescriptor.parseRequest(message));
message.close();
clientStream.halfClose();
assertAssumedMessageSize(clientStreamTracer1, serverStreamTracer1);

clientStream.request(1);
assertTrue(serverStream.isReady());
serverStream.writeMessage(methodDescriptor.streamResponse("Hi from server"));
serverStream.flush();
message = clientStreamListener.messageQueue.poll(TIMEOUT_MS, TimeUnit.MILLISECONDS);
assertEquals("Hi from server", methodDescriptor.parseResponse(message));
assertAssumedMessageSize(serverStreamTracer1, clientStreamTracer1);
message.close();
Status status = Status.OK.withDescription("That was normal");
serverStream.close(status, new Metadata());
}

private void assertAssumedMessageSize(
TestStreamTracer streamTracerSender, TestStreamTracer streamTracerReceiver) {
Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundWireSize());
Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerSender.getOutboundUncompressedSize());
Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerReceiver.getInboundWireSize());
Assert.assertEquals(TEST_MESSAGE_LENGTH, streamTracerReceiver.getInboundUncompressedSize());
}
}

This file was deleted.

0 comments on commit 85e5acb

Please sign in to comment.