Skip to content

Commit

Permalink
fix: Make sure outstanding RPCs count in ChannelPool can not go negat…
Browse files Browse the repository at this point in the history
…ive (#2185)

Add two flags wasClosed and wasReleased to ReleasingClientCall to check various scenarios. The combination of these two flags can make sure the count of outstanding RPCs can never go negative, and help us identify what exactly goes wrong next time it happens.
  • Loading branch information
blakeli0 authored Oct 19, 2023
1 parent 860ae76 commit d2d624d
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
* <p>Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
@VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);

private final ChannelPoolSettings settings;
Expand Down Expand Up @@ -421,9 +421,25 @@ private Entry getEntry(int affinity) {
}

/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
private static class Entry {
static class Entry {
private final ManagedChannel channel;
private final AtomicInteger outstandingRpcs = new AtomicInteger(0);

/**
* The primary purpose of keeping a count for outstanding RPCs is to track when a channel is
* safe to close. In grpc, initialization & starting of rpcs is split between 2 methods:
* Channel#newCall() and ClientCall#start. gRPC already has a mechanism to safely close channels
* that have rpcs that have been started. However, it does not protect calls that have been
* created but not started. In the sequence: Channel#newCall() Channel#shutdown()
* ClientCall#Start(), gRpc will error out the call telling the caller that the channel is
* shutdown.
*
* <p>Hence, the increment of outstanding RPCs has to happen when the ClientCall is initialized,
* as part of Channel#newCall(), not after the ClientCall is started. The decrement of
* outstanding RPCs has to happen when the ClientCall is closed or the ClientCall failed to
* start.
*/
@VisibleForTesting final AtomicInteger outstandingRpcs = new AtomicInteger(0);

private final AtomicInteger maxOutstanding = new AtomicInteger();

// Flag that the channel should be closed once all of the outstanding RPC complete.
Expand Down Expand Up @@ -470,7 +486,7 @@ private boolean retain() {
private void release() {
int newCount = outstandingRpcs.decrementAndGet();
if (newCount < 0) {
throw new IllegalStateException("Bug: reference count is negative!: " + newCount);
LOG.log(Level.WARNING, "Bug! Reference count is negative (" + newCount + ")!");
}

// Must check outstandingRpcs after shutdownRequested (in reverse order of retain()) to ensure
Expand Down Expand Up @@ -526,6 +542,8 @@ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
static class ReleasingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
@Nullable private CancellationException cancellationException;
final Entry entry;
private final AtomicBoolean wasClosed = new AtomicBoolean();
private final AtomicBoolean wasReleased = new AtomicBoolean();

public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {
super(delegate);
Expand All @@ -542,17 +560,35 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
new SimpleForwardingClientCallListener<RespT>(responseListener) {
@Override
public void onClose(Status status, Metadata trailers) {
if (!wasClosed.compareAndSet(false, true)) {
LOG.log(
Level.WARNING,
"Call is being closed more than once. Please make sure that onClose() is not being manually called.");
return;
}
try {
super.onClose(status, trailers);
} finally {
entry.release();
if (wasReleased.compareAndSet(false, true)) {
entry.release();
} else {
LOG.log(
Level.WARNING,
"Entry was released before the call is closed. This may be due to an exception on start of the call.");
}
}
}
},
headers);
} catch (Exception e) {
// In case start failed, make sure to release
entry.release();
if (wasReleased.compareAndSet(false, true)) {
entry.release();
} else {
LOG.log(
Level.WARNING,
"The entry is already released. This indicates that onClose() has already been called previously");
}
throw e;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@
*/
package com.google.api.gax.grpc;

import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE;
import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
import static com.google.common.truth.Truth.assertThat;

import com.google.api.core.ApiFuture;
import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand All @@ -63,6 +66,9 @@
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Handler;
import java.util.logging.LogRecord;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -117,7 +123,7 @@ public void testRoundRobin() throws IOException {

private void verifyTargetChannel(
ChannelPool pool, List<ManagedChannel> channels, ManagedChannel targetChannel) {
MethodDescriptor<Color, Money> methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
MethodDescriptor<Color, Money> methodDescriptor = METHOD_RECOGNIZE;
CallOptions callOptions = CallOptions.DEFAULT;
@SuppressWarnings("unchecked")
ClientCall<Color, Money> expectedClientCall = Mockito.mock(ClientCall.class);
Expand All @@ -143,7 +149,7 @@ public void ensureEvenDistribution() throws InterruptedException, IOException {
final ManagedChannel[] channels = new ManagedChannel[numChannels];
final AtomicInteger[] counts = new AtomicInteger[numChannels];

final MethodDescriptor<Color, Money> methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
final MethodDescriptor<Color, Money> methodDescriptor = METHOD_RECOGNIZE;
final CallOptions callOptions = CallOptions.DEFAULT;
@SuppressWarnings("unchecked")
final ClientCall<Color, Money> clientCall = Mockito.mock(ClientCall.class);
Expand Down Expand Up @@ -472,23 +478,21 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro
// Start the minimum number of
for (int i = 0; i < 2; i++) {
ClientCalls.futureUnaryCall(
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
Color.getDefaultInstance());
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
}
pool.resize();
assertThat(pool.entries.get()).hasSize(2);

// Add enough RPCs to be just at the brink of expansion
for (int i = startedCalls.size(); i < 4; i++) {
ClientCalls.futureUnaryCall(
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
Color.getDefaultInstance());
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
}
pool.resize();
assertThat(pool.entries.get()).hasSize(2);

// Add another RPC to push expansion
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT);
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT);
pool.resize();
assertThat(pool.entries.get()).hasSize(4); // += ChannelPool::MAX_RESIZE_DELTA
assertThat(startedCalls).hasSize(5);
Expand Down Expand Up @@ -593,8 +597,7 @@ public void removedActiveChannelsAreShutdown() throws Exception {
// Start 2 RPCs
for (int i = 0; i < 2; i++) {
ClientCalls.futureUnaryCall(
pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT),
Color.getDefaultInstance());
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
}
// Complete the first one
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -663,4 +666,74 @@ public void onComplete() {}
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
}

@Test
public void testDoubleRelease() throws Exception {
FakeLogHandler logHandler = new FakeLogHandler();
ChannelPool.LOG.addHandler(logHandler);

try {
// Create a fake channel pool thats backed by mock channels that simply record invocations
ClientCall mockClientCall = Mockito.mock(ClientCall.class);
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));

pool = ChannelPool.create(channelPoolSettings, factory);

// Construct a fake callable to use the channel pool
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.build();

UnaryCallSettings<Color, Money> settings =
UnaryCallSettings.<Color, Money>newUnaryCallSettingsBuilder().build();
UnaryCallable<Color, Money> callable =
GrpcCallableFactory.createUnaryCallable(
GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context);

// Start the RPC
ApiFuture<Money> rpcFuture =
callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext());

// Get the server side listener and intentionally close it twice
ArgumentCaptor<ClientCall.Listener<?>> clientCallListenerCaptor =
ArgumentCaptor.forClass(ClientCall.Listener.class);
Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any());
clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata());
clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata());

// Ensure that the channel pool properly logged the double call and kept the refCount correct
assertThat(logHandler.getAllMessages())
.contains(
"Call is being closed more than once. Please make sure that onClose() is not being manually called.");
assertThat(pool.entries.get()).hasSize(1);
ChannelPool.Entry entry = pool.entries.get().get(0);
assertThat(entry.outstandingRpcs.get()).isEqualTo(0);
} finally {
ChannelPool.LOG.removeHandler(logHandler);
}
}

private static class FakeLogHandler extends Handler {
List<LogRecord> records = new ArrayList<>();

@Override
public void publish(LogRecord record) {
records.add(record);
}

@Override
public void flush() {}

@Override
public void close() throws SecurityException {}

List<String> getAllMessages() {
return records.stream().map(LogRecord::getMessage).collect(Collectors.toList());
}
}
}

0 comments on commit d2d624d

Please sign in to comment.