Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2a: Add S2AStub cleanup handler. #11600

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.Optional;
import java.util.concurrent.Executor;

/**
Expand All @@ -40,9 +41,10 @@
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
ObjectPool<? extends Executor> executorPool,
Optional<Runnable> handshakeCompleteRunnable) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool);
executorPool, handshakeCompleteRunnable);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
Expand Down Expand Up @@ -70,7 +72,7 @@
* may happen immediately, even before the TLS Handshake is complete.
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
return tls(sslContext, null, Optional.empty());
}

/**
Expand Down Expand Up @@ -167,7 +169,8 @@
public static ChannelHandler clientTlsHandler(
ChannelHandler next, SslContext sslContext, String authority,
ChannelLogger negotiationLogger) {
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
Optional.empty());

Check warning on line 173 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L172-L173

Added lines #L172 - L173 were not covered by tests
}

public static class ProtocolNegotiationHandler
Expand Down
3 changes: 2 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -604,7 +605,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
case PLAINTEXT_UPGRADE:
return ProtocolNegotiators.plaintextUpgrade();
case TLS:
return ProtocolNegotiators.tls(sslContext, executorPool);
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
}
Expand Down
24 changes: 18 additions & 6 deletions netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.logging.Level;
Expand Down Expand Up @@ -543,16 +544,18 @@
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {

public ClientTlsProtocolNegotiator(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
this.sslContext = checkNotNull(sslContext, "sslContext");
this.executorPool = executorPool;
if (this.executorPool != null) {
this.executor = this.executorPool.getObject();
}
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
}

private final SslContext sslContext;
private final ObjectPool<? extends Executor> executorPool;
private final Optional<Runnable> handshakeCompleteRunnable;
private Executor executor;

@Override
Expand All @@ -565,7 +568,7 @@
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
this.executor, negotiationLogger);
this.executor, negotiationLogger, handshakeCompleteRunnable);
return new WaitUntilActiveHandler(cth, negotiationLogger);
}

Expand All @@ -583,15 +586,18 @@
private final String host;
private final int port;
private Executor executor;
private final Optional<Runnable> handshakeCompleteRunnable;

ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
Executor executor, ChannelLogger negotiationLogger) {
Executor executor, ChannelLogger negotiationLogger,
Optional<Runnable> handshakeCompleteRunnable) {
super(next, negotiationLogger);
this.sslContext = checkNotNull(sslContext, "sslContext");
HostPort hostPort = parseAuthority(authority);
this.host = hostPort.host;
this.port = hostPort.port;
this.executor = executor;
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
}

@Override
Expand Down Expand Up @@ -634,6 +640,9 @@
.withCause(t)
.asRuntimeException();
}
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();

Check warning on line 644 in netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java#L644

Added line #L644 was not covered by tests
}
ctx.fireExceptionCaught(t);
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
Expand All @@ -649,6 +658,9 @@
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.build();
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();

Check warning on line 662 in netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java#L662

Added line #L662 was not covered by tests
}
fireProtocolNegotiationEvent(ctx);
}
}
Expand Down Expand Up @@ -683,8 +695,8 @@
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
}

/**
Expand All @@ -693,7 +705,7 @@
* may happen immediately, even before the TLS Handshake is complete.
*/
public static ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
return tls(sslContext, null, Optional.empty());
}

public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -766,7 +767,8 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
.trustManager(caCert)
.keyManager(clientCert, clientKey)
.build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
Optional.empty());
// after starting the client, the Executor in the client pool should be used
assertEquals(true, clientExecutorPool.isInUse());
final NettyClientTransport transport = newTransport(negotiator);
Expand Down
12 changes: 7 additions & 5 deletions netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -876,7 +877,7 @@ public String applicationProtocol() {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
Expand Down Expand Up @@ -914,7 +915,7 @@ public String applicationProtocol() {
.applicationProtocolConfig(apn).build();

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
Expand All @@ -938,7 +939,7 @@ public String applicationProtocol() {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);

final AtomicReference<Throwable> error = new AtomicReference<>();
Expand Down Expand Up @@ -966,7 +967,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
@Test
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", null, noopLogger);
"authority", null, noopLogger, Optional.empty());
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);

Expand Down Expand Up @@ -1228,7 +1229,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
}
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
null, Optional.empty());
WriteBufferingAndExceptionHandler clientWbaeh =
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));

Expand Down
14 changes: 13 additions & 1 deletion s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.internal.handshaker.S2AIdentity;
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
import io.grpc.s2a.internal.handshaker.S2AStub;
import javax.annotation.concurrent.NotThreadSafe;
import org.checkerframework.checker.nullness.qual.Nullable;

Expand Down Expand Up @@ -59,6 +60,7 @@ public static final class Builder {
private final String s2aAddress;
private final ChannelCredentials s2aChannelCredentials;
private @Nullable S2AIdentity localIdentity = null;
private @Nullable S2AStub stub = null;

Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
this.s2aAddress = s2aAddress;
Expand Down Expand Up @@ -104,6 +106,16 @@ public Builder setLocalUid(String localUid) {
return this;
}

/**
* Sets the stub to use to communicate with S2A. This is only used for testing that the
* stream to S2A gets closed.
*/
public Builder setStub(S2AStub stub) {
checkNotNull(stub);
this.stub = stub;
return this;
}

public ChannelCredentials build() {
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
}
Expand All @@ -113,7 +125,7 @@ InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() {
SharedResourcePool.forResource(
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
checkNotNull(s2aChannelPool, "s2aChannelPool");
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
}
}

Expand Down
Loading