From e75a044107bc8701b2e328b7ec633c982fb3a844 Mon Sep 17 00:00:00 2001 From: Riya Mehta <55350838+rmehta19@users.noreply.github.com> Date: Fri, 20 Sep 2024 12:32:54 -0700 Subject: [PATCH 1/2] s2a,netty: S2AHandshakerServiceChannel doesn't use custom event loop. (#11539) * S2AHandshakerServiceChannel doesn't use custom event loop. * use executorPool. * log when channel not shutdown. * use a cached threadpool. * update non-executor version. --- .../netty/InternalProtocolNegotiators.java | 18 ++++++- .../channel/S2AHandshakerServiceChannel.java | 50 +++++++------------ .../S2AProtocolNegotiatorFactory.java | 7 ++- .../S2AHandshakerServiceChannelTest.java | 46 +++++++---------- 4 files changed, 56 insertions(+), 65 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 0d309828c6d..b9c6a77982a 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -17,12 +17,14 @@ package io.grpc.netty; import io.grpc.ChannelLogger; +import io.grpc.internal.ObjectPool; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -35,9 +37,12 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, + ObjectPool executorPool) { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, + executorPool); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -58,6 +63,15 @@ public void close() { return new TlsNegotiator(); } + + /** + * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * may happen immediately, even before the TLS Handshake is complete. + */ + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { + return tls(sslContext, null); + } /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java index 75ec7347bb5..90956907bfe 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java @@ -29,13 +29,11 @@ import io.grpc.MethodDescriptor; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.util.concurrent.DefaultThreadFactory; import java.time.Duration; import java.util.Optional; import java.util.concurrent.ConcurrentMap; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.concurrent.ThreadSafe; /** @@ -61,7 +59,6 @@ public final class S2AHandshakerServiceChannel { private static final ConcurrentMap> SHARED_RESOURCE_CHANNELS = Maps.newConcurrentMap(); - private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2); private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); /** @@ -95,41 +92,34 @@ public ChannelResource(String targetAddress, Optional channe } /** - * Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code - * targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup} - * instance to avoid blocking. + * Creates a {@code HandshakerServiceChannel} instance to the service running at {@code + * targetAddress}. */ @Override public Channel create() { - EventLoopGroup eventLoopGroup = - new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true)); ManagedChannel channel = null; if (channelCredentials.isPresent()) { // Create a secure channel. channel = NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get()) - .channelType(NioSocketChannel.class) .directExecutor() - .eventLoopGroup(eventLoopGroup) .build(); } else { // Create a plaintext channel. channel = NettyChannelBuilder.forTarget(targetAddress) - .channelType(NioSocketChannel.class) .directExecutor() - .eventLoopGroup(eventLoopGroup) .usePlaintext() .build(); } - return EventLoopHoldingChannel.create(channel, eventLoopGroup); + return HandshakerServiceChannel.create(channel); } - /** Destroys a {@code EventLoopHoldingChannel} instance. */ + /** Destroys a {@code HandshakerServiceChannel} instance. */ @Override public void close(Channel instanceChannel) { checkNotNull(instanceChannel); - EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel; + HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel; channel.close(); } @@ -140,23 +130,21 @@ public String toString() { } /** - * Manages a channel using a {@link ManagedChannel} instance that belong to the {@code - * EventLoopGroup} thread pool. + * Manages a channel using a {@link ManagedChannel} instance. */ @VisibleForTesting - static class EventLoopHoldingChannel extends Channel { + static class HandshakerServiceChannel extends Channel { + private static final Logger logger = + Logger.getLogger(S2AHandshakerServiceChannel.class.getName()); private final ManagedChannel delegate; - private final EventLoopGroup eventLoopGroup; - static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + static HandshakerServiceChannel create(ManagedChannel delegate) { checkNotNull(delegate); - checkNotNull(eventLoopGroup); - return new EventLoopHoldingChannel(delegate, eventLoopGroup); + return new HandshakerServiceChannel(delegate); } - private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + private HandshakerServiceChannel(ManagedChannel delegate) { this.delegate = delegate; - this.eventLoopGroup = eventLoopGroup; } /** @@ -178,16 +166,12 @@ public ClientCall newCall( @SuppressWarnings("FutureReturnValueIgnored") public void close() { delegate.shutdownNow(); - boolean isDelegateTerminated; try { - isDelegateTerminated = - delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS); + delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } catch (InterruptedException e) { - isDelegateTerminated = false; + Thread.currentThread().interrupt(); + logger.log(Level.WARNING, "Channel to S2A was not shutdown."); } - long quietPeriodSeconds = isDelegateTerminated ? 0 : 1; - eventLoopGroup.shutdownGracefully( - quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } } diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index 25d1e325ea8..14bdc05238d 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -29,7 +29,9 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.ThreadSafe; import io.grpc.Channel; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; @@ -227,7 +229,10 @@ protected void handlerAdded0(ChannelHandlerContext ctx) { @Override public void onSuccess(SslContext sslContext) { ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContext, + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) + .newHandler(grpcHandler); // Remove the bufferReads handler and delegate the rest of the handshake to the TLS // handler. diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java index 57288be1b6f..dc5909442bf 100644 --- a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java @@ -18,11 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; -import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import io.grpc.CallOptions; import io.grpc.Channel; @@ -39,15 +35,13 @@ import io.grpc.benchmarks.Utils; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyServerBuilder; -import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; -import io.netty.channel.EventLoopGroup; import java.io.File; -import java.time.Duration; import java.util.Optional; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -60,8 +54,6 @@ @RunWith(JUnit4.class) public final class S2AHandshakerServiceChannelTest { @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); - private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); private Server mtlsServer; private Server plaintextServer; @@ -191,7 +183,7 @@ public void close_mtlsSuccess() throws Exception { } /** - * Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to + * Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to * perform a simple RPC. */ @Test @@ -201,7 +193,7 @@ public void newCall_performSimpleRpcSuccess() { "localhost:" + plaintextServer.getPort(), /* s2aChannelCredentials= */ Optional.empty()); Channel channel = resource.create(); - assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channel).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) .isEqualToDefaultInstance(); @@ -214,53 +206,49 @@ public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception { S2AHandshakerServiceChannel.getChannelResource( "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); Channel channel = resource.create(); - assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channel).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) .isEqualToDefaultInstance(); } - /** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */ + /** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */ @Test public void authority_success() throws Exception { ManagedChannel channel = new FakeManagedChannel(true); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel"); } /** - * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates - * successfully. + * Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} + * terminates successfully. */ @Test public void close_withDelegateTerminatedSuccess() throws Exception { ManagedChannel channel = new FakeManagedChannel(true); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); eventLoopHoldingChannel.close(); assertThat(channel.isShutdown()).isTrue(); - verify(mockEventLoopGroup, times(1)) - .shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } /** - * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not + * Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not * terminate successfully. */ @Test public void close_withDelegateTerminatedFailure() throws Exception { ManagedChannel channel = new FakeManagedChannel(false); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); eventLoopHoldingChannel.close(); assertThat(channel.isShutdown()).isTrue(); - verify(mockEventLoopGroup, times(1)) - .shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } /** - * Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same + * Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same * resource, and verifies that this second channel is useable. */ @Test @@ -273,7 +261,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception { resource.close(channelOne); Channel channelTwo = resource.create(); - assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channelTwo) .unaryRpc(SimpleRequest.getDefaultInstance())) @@ -291,7 +279,7 @@ public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { resource.close(channelOne); Channel channelTwo = resource.create(); - assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channelTwo) .unaryRpc(SimpleRequest.getDefaultInstance())) From d8f73e04566fa588889ca1a422e276d71724643c Mon Sep 17 00:00:00 2001 From: Riya Mehta <55350838+rmehta19@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:53:14 -0700 Subject: [PATCH 2/2] s2a: Address comments on PR#11113 (#11534) * Mark S2A public APIs as experimental. * Rename S2AChannelCredentials createBuilder API to newBuilder. * Remove usage of AdvancedTls. * Use InsecureChannelCredentials.create instead of Optional. * Invoke Thread.currentThread().interrupt() in a InterruptedException block. --- .../grpc/s2a/MtlsToS2AChannelCredentials.java | 21 ++++-------- .../io/grpc/s2a/S2AChannelCredentials.java | 12 ++++--- .../channel/S2AHandshakerServiceChannel.java | 27 +++++---------- .../grpc/s2a/handshaker/S2ATrustManager.java | 3 ++ .../s2a/MtlsToS2AChannelCredentialsTest.java | 34 +++++++++---------- .../grpc/s2a/S2AChannelCredentialsTest.java | 24 ++++++------- .../S2AHandshakerServiceChannelTest.java | 25 +++++++------- .../grpc/s2a/handshaker/IntegrationTest.java | 8 ++--- .../S2AProtocolNegotiatorFactoryTest.java | 2 +- .../io/grpc/s2a/handshaker/S2AStubTest.java | 4 +-- 10 files changed, 73 insertions(+), 87 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java index 56f612502bf..e8eb01628ed 100644 --- a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java @@ -21,17 +21,16 @@ import static com.google.common.base.Strings.isNullOrEmpty; import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; import io.grpc.TlsChannelCredentials; -import io.grpc.util.AdvancedTlsX509KeyManager; -import io.grpc.util.AdvancedTlsX509TrustManager; import java.io.File; import java.io.IOException; -import java.security.GeneralSecurityException; /** * Configures an {@code S2AChannelCredentials.Builder} instance with credentials used to establish a * connection with the S2A to support talking to the S2A over mTLS. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533") public final class MtlsToS2AChannelCredentials { /** * Creates a {@code S2AChannelCredentials.Builder} builder, that talks to the S2A over mTLS. @@ -42,7 +41,7 @@ public final class MtlsToS2AChannelCredentials { * @param trustBundlePath the path to the trust bundle PEM. * @return a {@code MtlsToS2AChannelCredentials.Builder} instance. */ - public static Builder createBuilder( + public static Builder newBuilder( String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) { checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); checkArgument(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); @@ -66,7 +65,7 @@ public static final class Builder { this.trustBundlePath = trustBundlePath; } - public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IOException { + public S2AChannelCredentials.Builder build() throws IOException { checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); checkState(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); checkState(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty."); @@ -75,19 +74,13 @@ public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IO File certChainFile = new File(certChainPath); File trustBundleFile = new File(trustBundlePath); - AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); - keyManager.updateIdentityCredentials(certChainFile, privateKeyFile); - - AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); - trustManager.updateTrustCredentials(trustBundleFile); - ChannelCredentials channelToS2ACredentials = TlsChannelCredentials.newBuilder() - .keyManager(keyManager) - .trustManager(trustManager) + .keyManager(certChainFile, privateKeyFile) + .trustManager(trustBundleFile) .build(); - return S2AChannelCredentials.createBuilder(s2aAddress) + return S2AChannelCredentials.newBuilder(s2aAddress) .setS2AChannelCredentials(channelToS2ACredentials); } } diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java index 8a5f1f51350..ba0f6d72de1 100644 --- a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -24,6 +24,8 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Channel; import io.grpc.ChannelCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.InsecureChannelCredentials; import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourcePool; import io.grpc.netty.InternalNettyChannelCredentials; @@ -31,7 +33,6 @@ import io.grpc.s2a.channel.S2AHandshakerServiceChannel; import io.grpc.s2a.handshaker.S2AIdentity; import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory; -import java.util.Optional; import javax.annotation.concurrent.NotThreadSafe; import org.checkerframework.checker.nullness.qual.Nullable; @@ -39,6 +40,7 @@ * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for * use on the client side of a gRPC connection. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11533") public final class S2AChannelCredentials { /** * Creates a channel credentials builder for establishing an S2A-secured connection. @@ -46,7 +48,7 @@ public final class S2AChannelCredentials { * @param s2aAddress the address of the S2A server used to secure the connection. * @return a {@code S2AChannelCredentials.Builder} instance. */ - public static Builder createBuilder(String s2aAddress) { + public static Builder newBuilder(String s2aAddress) { checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); return new Builder(s2aAddress); } @@ -56,13 +58,13 @@ public static Builder createBuilder(String s2aAddress) { public static final class Builder { private final String s2aAddress; private ObjectPool s2aChannelPool; - private Optional s2aChannelCredentials; + private ChannelCredentials s2aChannelCredentials; private @Nullable S2AIdentity localIdentity = null; Builder(String s2aAddress) { this.s2aAddress = s2aAddress; this.s2aChannelPool = null; - this.s2aChannelCredentials = Optional.empty(); + this.s2aChannelCredentials = InsecureChannelCredentials.create(); } /** @@ -107,7 +109,7 @@ public Builder setLocalUid(String localUid) { /** Sets the credentials to be used when connecting to the S2A. */ @CanIgnoreReturnValue public Builder setS2AChannelCredentials(ChannelCredentials s2aChannelCredentials) { - this.s2aChannelCredentials = Optional.of(s2aChannelCredentials); + this.s2aChannelCredentials = s2aChannelCredentials; return this; } diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java index 90956907bfe..443ea553e52 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java @@ -30,7 +30,6 @@ import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyChannelBuilder; import java.time.Duration; -import java.util.Optional; import java.util.concurrent.ConcurrentMap; import java.util.logging.Level; import java.util.logging.Logger; @@ -71,8 +70,9 @@ public final class S2AHandshakerServiceChannel { * running at {@code s2aAddress}. */ public static Resource getChannelResource( - String s2aAddress, Optional s2aChannelCredentials) { + String s2aAddress, ChannelCredentials s2aChannelCredentials) { checkNotNull(s2aAddress); + checkNotNull(s2aChannelCredentials); return SHARED_RESOURCE_CHANNELS.computeIfAbsent( s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials)); } @@ -84,9 +84,9 @@ public static Resource getChannelResource( */ private static class ChannelResource implements Resource { private final String targetAddress; - private final Optional channelCredentials; + private final ChannelCredentials channelCredentials; - public ChannelResource(String targetAddress, Optional channelCredentials) { + public ChannelResource(String targetAddress, ChannelCredentials channelCredentials) { this.targetAddress = targetAddress; this.channelCredentials = channelCredentials; } @@ -97,21 +97,10 @@ public ChannelResource(String targetAddress, Optional channe */ @Override public Channel create() { - ManagedChannel channel = null; - if (channelCredentials.isPresent()) { - // Create a secure channel. - channel = - NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get()) - .directExecutor() - .build(); - } else { - // Create a plaintext channel. - channel = - NettyChannelBuilder.forTarget(targetAddress) - .directExecutor() - .usePlaintext() - .build(); - } + ManagedChannel channel = + NettyChannelBuilder.forTarget(targetAddress, channelCredentials) + .directExecutor() + .build(); return HandshakerServiceChannel.create(channel); } diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java index fb113bb29cc..aafbb94c047 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java @@ -121,6 +121,9 @@ private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientC try { resp = stub.send(reqBuilder.build()); } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } throw new CertificateException("Failed to send request to S2A.", e); } if (resp.hasStatus() && resp.getStatus().getCode() != 0) { diff --git a/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java index 5ccc522292e..0fc4ecb3268 100644 --- a/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java +++ b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java @@ -26,11 +26,11 @@ @RunWith(JUnit4.class) public final class MtlsToS2AChannelCredentialsTest { @Test - public void createBuilder_nullAddress_throwsException() throws Exception { + public void newBuilder_nullAddress_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ null, /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -38,11 +38,11 @@ public void createBuilder_nullAddress_throwsException() throws Exception { } @Test - public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception { + public void newBuilder_nullPrivateKeyPath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ null, /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -50,11 +50,11 @@ public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception } @Test - public void createBuilder_nullCertChainPath_throwsException() throws Exception { + public void newBuilder_nullCertChainPath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ null, @@ -62,11 +62,11 @@ public void createBuilder_nullCertChainPath_throwsException() throws Exception { } @Test - public void createBuilder_nullTrustBundlePath_throwsException() throws Exception { + public void newBuilder_nullTrustBundlePath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -74,11 +74,11 @@ public void createBuilder_nullTrustBundlePath_throwsException() throws Exception } @Test - public void createBuilder_emptyAddress_throwsException() throws Exception { + public void newBuilder_emptyAddress_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -86,11 +86,11 @@ public void createBuilder_emptyAddress_throwsException() throws Exception { } @Test - public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception { + public void newBuilder_emptyPrivateKeyPath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -98,11 +98,11 @@ public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception } @Test - public void createBuilder_emptyCertChainPath_throwsException() throws Exception { + public void newBuilder_emptyCertChainPath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "", @@ -110,11 +110,11 @@ public void createBuilder_emptyCertChainPath_throwsException() throws Exception } @Test - public void createBuilder_emptyTrustBundlePath_throwsException() throws Exception { + public void newBuilder_emptyTrustBundlePath_throwsException() throws Exception { assertThrows( IllegalArgumentException.class, () -> - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -124,7 +124,7 @@ public void createBuilder_emptyTrustBundlePath_throwsException() throws Exceptio @Test public void build_s2AChannelCredentials_success() throws Exception { assertThat( - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ "s2a_address", /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", diff --git a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java index a6133ed0af8..e766aa3f145 100644 --- a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java +++ b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java @@ -30,40 +30,40 @@ @RunWith(JUnit4.class) public final class S2AChannelCredentialsTest { @Test - public void createBuilder_nullArgument_throwsException() throws Exception { - assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(null)); + public void newBuilder_nullArgument_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder(null)); } @Test - public void createBuilder_emptyAddress_throwsException() throws Exception { - assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder("")); + public void newBuilder_emptyAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.newBuilder("")); } @Test public void setLocalSpiffeId_nullArgument_throwsException() throws Exception { assertThrows( NullPointerException.class, - () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalSpiffeId(null)); + () -> S2AChannelCredentials.newBuilder("s2a_address").setLocalSpiffeId(null)); } @Test public void setLocalHostname_nullArgument_throwsException() throws Exception { assertThrows( NullPointerException.class, - () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalHostname(null)); + () -> S2AChannelCredentials.newBuilder("s2a_address").setLocalHostname(null)); } @Test public void setLocalUid_nullArgument_throwsException() throws Exception { assertThrows( NullPointerException.class, - () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalUid(null)); + () -> S2AChannelCredentials.newBuilder("s2a_address").setLocalUid(null)); } @Test public void build_withLocalSpiffeId_succeeds() throws Exception { assertThat( - S2AChannelCredentials.createBuilder("s2a_address") + S2AChannelCredentials.newBuilder("s2a_address") .setLocalSpiffeId("spiffe://test") .build()) .isNotNull(); @@ -72,7 +72,7 @@ public void build_withLocalSpiffeId_succeeds() throws Exception { @Test public void build_withLocalHostname_succeeds() throws Exception { assertThat( - S2AChannelCredentials.createBuilder("s2a_address") + S2AChannelCredentials.newBuilder("s2a_address") .setLocalHostname("local_hostname") .build()) .isNotNull(); @@ -80,20 +80,20 @@ public void build_withLocalHostname_succeeds() throws Exception { @Test public void build_withLocalUid_succeeds() throws Exception { - assertThat(S2AChannelCredentials.createBuilder("s2a_address").setLocalUid("local_uid").build()) + assertThat(S2AChannelCredentials.newBuilder("s2a_address").setLocalUid("local_uid").build()) .isNotNull(); } @Test public void build_withNoLocalIdentity_succeeds() throws Exception { - assertThat(S2AChannelCredentials.createBuilder("s2a_address").build()) + assertThat(S2AChannelCredentials.newBuilder("s2a_address").build()) .isNotNull(); } @Test public void build_withTlsChannelCredentials_succeeds() throws Exception { assertThat( - S2AChannelCredentials.createBuilder("s2a_address") + S2AChannelCredentials.newBuilder("s2a_address") .setLocalSpiffeId("spiffe://test") .setS2AChannelCredentials(getTlsChannelCredentials()) .build()) diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java index dc5909442bf..7845e7c3bcb 100644 --- a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java @@ -24,6 +24,7 @@ import io.grpc.Channel; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import io.grpc.Server; @@ -42,7 +43,6 @@ import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; import java.io.File; -import java.util.Optional; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.ClassRule; @@ -74,7 +74,7 @@ public void getChannelResource_success() { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); } @@ -96,11 +96,11 @@ public void getChannelResource_twoEqualChannels() { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); Resource resourceTwo = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); assertThat(resource).isEqualTo(resourceTwo); } @@ -125,10 +125,10 @@ public void getChannelResource_twoDistinctChannels() { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); Resource resourceTwo = S2AHandshakerServiceChannel.getChannelResource( - "localhost:" + Utils.pickUnusedPort(), /* s2aChannelCredentials= */ Optional.empty()); + "localhost:" + Utils.pickUnusedPort(), InsecureChannelCredentials.create()); assertThat(resourceTwo).isNotEqualTo(resource); } @@ -153,7 +153,7 @@ public void close_success() { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); Channel channel = resource.create(); resource.close(channel); StatusRuntimeException expected = @@ -191,7 +191,7 @@ public void newCall_performSimpleRpcSuccess() { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); Channel channel = resource.create(); assertThat(channel).isInstanceOf(HandshakerServiceChannel.class); assertThat( @@ -256,7 +256,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception { Resource resource = S2AHandshakerServiceChannel.getChannelResource( "localhost:" + plaintextServer.getPort(), - /* s2aChannelCredentials= */ Optional.empty()); + InsecureChannelCredentials.create()); Channel channelOne = resource.create(); resource.close(channelOne); @@ -308,15 +308,14 @@ private static Server createPlaintextServer() { ServerBuilder.forPort(Utils.pickUnusedPort()).addService(service).build()); } - private static Optional getTlsChannelCredentials() throws Exception { + private static ChannelCredentials getTlsChannelCredentials() throws Exception { File clientCert = new File("src/test/resources/client_cert.pem"); File clientKey = new File("src/test/resources/client_key.pem"); File rootCert = new File("src/test/resources/root_cert.pem"); - return Optional.of( - TlsChannelCredentials.newBuilder() + return TlsChannelCredentials.newBuilder() .keyManager(clientCert, clientKey) .trustManager(rootCert) - .build()); + .build(); } private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java index 19dda7a19e4..bae58f2f9ec 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java @@ -186,7 +186,7 @@ public void tearDown() throws Exception { @Test public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { ChannelCredentials credentials = - S2AChannelCredentials.createBuilder(s2aAddress).setLocalSpiffeId("test-spiffe-id").build(); + S2AChannelCredentials.newBuilder(s2aAddress).setLocalSpiffeId("test-spiffe-id").build(); ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); assertThat(doUnaryRpc(channel)).isTrue(); @@ -194,7 +194,7 @@ public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { @Test public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception { - ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aAddress).build(); + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aAddress).build(); ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); assertThat(doUnaryRpc(channel)).isTrue(); @@ -203,7 +203,7 @@ public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throw @Test public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { ChannelCredentials credentials = - MtlsToS2AChannelCredentials.createBuilder( + MtlsToS2AChannelCredentials.newBuilder( /* s2aAddress= */ mtlsS2AAddress, /* privateKeyPath= */ "src/test/resources/client_key.pem", /* certChainPath= */ "src/test/resources/client_cert.pem", @@ -218,7 +218,7 @@ public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Excepti @Test public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception { - ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aDelayAddress).build(); + ChannelCredentials credentials = S2AChannelCredentials.newBuilder(s2aDelayAddress).build(); ManagedChannel channel = Grpc.newChannelBuilder(serverAddress, credentials).build(); FutureTask rpc = new FutureTask<>(() -> doUnaryRpc(channel)); diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java index f130e52aac7..404910e8be0 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -115,7 +115,7 @@ public void createProtocolNegotiator_nullArgument() throws Exception { S2AGrpcChannelPool.create( SharedResourcePool.forResource( S2AHandshakerServiceChannel.getChannelResource( - "localhost:8080", /* s2aChannelCredentials= */ Optional.empty()))); + "localhost:8080", InsecureChannelCredentials.create()))); NullPointerTester tester = new NullPointerTester() diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java index bb90be12b6a..47fd154d949 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java @@ -21,13 +21,13 @@ import static org.junit.Assert.assertThrows; import com.google.common.truth.Expect; +import io.grpc.InsecureChannelCredentials; import io.grpc.internal.SharedResourcePool; import io.grpc.s2a.channel.S2AChannelPool; import io.grpc.s2a.channel.S2AGrpcChannelPool; import io.grpc.s2a.channel.S2AHandshakerServiceChannel; import io.grpc.stub.StreamObserver; import java.io.IOException; -import java.util.Optional; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -55,7 +55,7 @@ public void send_receiveOkStatus() throws Exception { S2AGrpcChannelPool.create( SharedResourcePool.forResource( S2AHandshakerServiceChannel.getChannelResource( - S2A_ADDRESS, /* s2aChannelCredentials= */ Optional.empty()))); + S2A_ADDRESS, InsecureChannelCredentials.create()))); S2AServiceGrpc.S2AServiceStub serviceStub = S2AServiceGrpc.newStub(channelPool.getChannel()); S2AStub newStub = S2AStub.newInstance(serviceStub);