From d2d624def3d72ecdaa3fdf7e2a92fd29d39ff1aa Mon Sep 17 00:00:00 2001 From: Blake Li Date: Thu, 19 Oct 2023 17:39:20 -0400 Subject: [PATCH] fix: Make sure outstanding RPCs count in ChannelPool can not go negative (#2185) Add two flags wasClosed and wasReleased to ReleasingClientCall to check various scenarios. The combination of these two flags can make sure the count of outstanding RPCs can never go negative, and help us identify what exactly goes wrong next time it happens. --- .../com/google/api/gax/grpc/ChannelPool.java | 48 ++++++++-- .../google/api/gax/grpc/ChannelPoolTest.java | 93 +++++++++++++++++-- 2 files changed, 125 insertions(+), 16 deletions(-) diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index df3888dfc9..c3e26dc4e2 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -68,7 +68,7 @@ *

Package-private for internal use. */ class ChannelPool extends ManagedChannel { - private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); + @VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); private final ChannelPoolSettings settings; @@ -421,9 +421,25 @@ private Entry getEntry(int affinity) { } /** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */ - private static class Entry { + static class Entry { private final ManagedChannel channel; - private final AtomicInteger outstandingRpcs = new AtomicInteger(0); + + /** + * The primary purpose of keeping a count for outstanding RPCs is to track when a channel is + * safe to close. In gRPC, initialization & starting of RPCs is split between 2 methods: + * Channel#newCall() and ClientCall#start(). gRPC already has a mechanism to safely close channels + * that have RPCs that have been started. However, it does not protect calls that have been + * created but not started. In the sequence: Channel#newCall(), Channel#shutdown(), + * ClientCall#start(), gRPC will error out the call telling the caller that the channel is + * shut down. + * + *

Hence, the increment of outstanding RPCs has to happen when the ClientCall is initialized, + * as part of Channel#newCall(), not after the ClientCall is started. The decrement of + * outstanding RPCs has to happen when the ClientCall is closed or the ClientCall failed to + * start. + */ + @VisibleForTesting final AtomicInteger outstandingRpcs = new AtomicInteger(0); + private final AtomicInteger maxOutstanding = new AtomicInteger(); // Flag that the channel should be closed once all of the outstanding RPC complete. @@ -470,7 +486,7 @@ private boolean retain() { private void release() { int newCount = outstandingRpcs.decrementAndGet(); if (newCount < 0) { - throw new IllegalStateException("Bug: reference count is negative!: " + newCount); + LOG.log(Level.WARNING, "Bug! Reference count is negative (" + newCount + ")!"); } // Must check outstandingRpcs after shutdownRequested (in reverse order of retain()) to ensure @@ -526,6 +542,8 @@ public ClientCall newCall( static class ReleasingClientCall extends SimpleForwardingClientCall { @Nullable private CancellationException cancellationException; final Entry entry; + private final AtomicBoolean wasClosed = new AtomicBoolean(); + private final AtomicBoolean wasReleased = new AtomicBoolean(); public ReleasingClientCall(ClientCall delegate, Entry entry) { super(delegate); @@ -542,17 +560,35 @@ public void start(Listener responseListener, Metadata headers) { new SimpleForwardingClientCallListener(responseListener) { @Override public void onClose(Status status, Metadata trailers) { + if (!wasClosed.compareAndSet(false, true)) { + LOG.log( + Level.WARNING, + "Call is being closed more than once. Please make sure that onClose() is not being manually called."); + return; + } try { super.onClose(status, trailers); } finally { - entry.release(); + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } else { + LOG.log( + Level.WARNING, + "Entry was released before the call is closed. 
This may be due to an exception on start of the call."); + } } } }, headers); } catch (Exception e) { // In case start failed, make sure to release - entry.release(); + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } else { + LOG.log( + Level.WARNING, + "The entry is already released. This indicates that onClose() has already been called previously"); + } throw e; } } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index c38d98e91f..173528d6e2 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -29,17 +29,20 @@ */ package com.google.api.gax.grpc; +import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE; import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE; import static com.google.common.truth.Truth.assertThat; +import com.google.api.core.ApiFuture; import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeMethodDescriptor; -import com.google.api.gax.grpc.testing.FakeServiceGrpc; import com.google.api.gax.rpc.ClientContext; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.api.gax.rpc.StreamController; +import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.api.gax.rpc.UnaryCallable; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -63,6 +66,9 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Handler; +import java.util.logging.LogRecord; +import 
java.util.stream.Collectors; import org.junit.After; import org.junit.Assert; import org.junit.Test; @@ -117,7 +123,7 @@ public void testRoundRobin() throws IOException { private void verifyTargetChannel( ChannelPool pool, List channels, ManagedChannel targetChannel) { - MethodDescriptor methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE; + MethodDescriptor methodDescriptor = METHOD_RECOGNIZE; CallOptions callOptions = CallOptions.DEFAULT; @SuppressWarnings("unchecked") ClientCall expectedClientCall = Mockito.mock(ClientCall.class); @@ -143,7 +149,7 @@ public void ensureEvenDistribution() throws InterruptedException, IOException { final ManagedChannel[] channels = new ManagedChannel[numChannels]; final AtomicInteger[] counts = new AtomicInteger[numChannels]; - final MethodDescriptor methodDescriptor = FakeServiceGrpc.METHOD_RECOGNIZE; + final MethodDescriptor methodDescriptor = METHOD_RECOGNIZE; final CallOptions callOptions = CallOptions.DEFAULT; @SuppressWarnings("unchecked") final ClientCall clientCall = Mockito.mock(ClientCall.class); @@ -472,8 +478,7 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro // Start the minimum number of for (int i = 0; i < 2; i++) { ClientCalls.futureUnaryCall( - pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), - Color.getDefaultInstance()); + pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance()); } pool.resize(); assertThat(pool.entries.get()).hasSize(2); @@ -481,14 +486,13 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro // Add enough RPCs to be just at the brink of expansion for (int i = startedCalls.size(); i < 4; i++) { ClientCalls.futureUnaryCall( - pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), - Color.getDefaultInstance()); + pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance()); } pool.resize(); assertThat(pool.entries.get()).hasSize(2); // Add another RPC to 
push expansion - pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT); + pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT); pool.resize(); assertThat(pool.entries.get()).hasSize(4); // += ChannelPool::MAX_RESIZE_DELTA assertThat(startedCalls).hasSize(5); @@ -593,8 +597,7 @@ public void removedActiveChannelsAreShutdown() throws Exception { // Start 2 RPCs for (int i = 0; i < 2; i++) { ClientCalls.futureUnaryCall( - pool.newCall(FakeServiceGrpc.METHOD_RECOGNIZE, CallOptions.DEFAULT), - Color.getDefaultInstance()); + pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance()); } // Complete the first one @SuppressWarnings("unchecked") @@ -663,4 +666,74 @@ public void onComplete() {} assertThat(e.getCause()).isInstanceOf(CancellationException.class); assertThat(e.getMessage()).isEqualTo("Call is already cancelled"); } + + @Test + public void testDoubleRelease() throws Exception { + FakeLogHandler logHandler = new FakeLogHandler(); + ChannelPool.LOG.addHandler(logHandler); + + try { + // Create a fake channel pool that's backed by mock channels that simply record invocations + ClientCall mockClientCall = Mockito.mock(ClientCall.class); + ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); + Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall); + ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1); + ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel)); + + pool = ChannelPool.create(channelPoolSettings, factory); + + // Construct a fake callable to use the channel pool + ClientContext context = + ClientContext.newBuilder() + .setTransportChannel(GrpcTransportChannel.create(pool)) + .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT)) + .build(); + + UnaryCallSettings settings = + UnaryCallSettings.newUnaryCallSettingsBuilder().build(); + UnaryCallable callable = + GrpcCallableFactory.createUnaryCallable( +
GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context); + + // Start the RPC + ApiFuture rpcFuture = + callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext()); + + // Capture the listener passed to ClientCall#start and intentionally close it twice + ArgumentCaptor> clientCallListenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any()); + clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata()); + clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata()); + + // Ensure that the channel pool properly logged the double call and kept the refCount correct + assertThat(logHandler.getAllMessages()) + .contains( + "Call is being closed more than once. Please make sure that onClose() is not being manually called."); + assertThat(pool.entries.get()).hasSize(1); + ChannelPool.Entry entry = pool.entries.get().get(0); + assertThat(entry.outstandingRpcs.get()).isEqualTo(0); + } finally { + ChannelPool.LOG.removeHandler(logHandler); + } + } + + private static class FakeLogHandler extends Handler { + List records = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + records.add(record); + } + + @Override + public void flush() {} + + @Override + public void close() throws SecurityException {} + + List getAllMessages() { + return records.stream().map(LogRecord::getMessage).collect(Collectors.toList()); + } + } }