Skip to content

Commit

Permalink
Add SslConfig to onSecurityHandshake (#3074)
Browse files Browse the repository at this point in the history
Motivation
----------
This changeset adds ConnectionInfo if available to the
onSecurityHandshake, similar to a change made to
onTransportHandshakeComplete. This allows to get
information about the SSL config from the argument,
without having to get it as an instance variable
from onTransportHandshakeComplete.
  • Loading branch information
daschl authored Oct 3, 2024
1 parent a3a4212 commit 94c1bbb
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.servicetalk.transport.api.ConnectionObserver.ReadObserver;
import io.servicetalk.transport.api.ConnectionObserver.StreamObserver;
import io.servicetalk.transport.api.ConnectionObserver.WriteObserver;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.NoopTransportObserver;
import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopProxyConnectObserver;
Expand Down Expand Up @@ -206,7 +207,7 @@ public ProxyConnectObserver onProxyConnect(final Object connectMsg) {
}

@Override
public SecurityHandshakeObserver onSecurityHandshake() {
public SecurityHandshakeObserver onSecurityHandshake(final SslConfig config) {
// AsyncContext is unknown at this point because this event is triggered by network
return NoopSecurityHandshakeObserver.INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import io.servicetalk.transport.api.RetryableException;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.ExecutionContextExtension;
import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopDataObserver;
Expand Down Expand Up @@ -161,7 +162,8 @@ private void initMocks() {
securityHandshakeObserver = mock(SecurityHandshakeObserver.class, "securityHandshakeObserver");
when(transportObserver.onNewConnection(any(), any())).thenReturn(connectionObserver);
when(connectionObserver.onProxyConnect(any())).thenReturn(proxyConnectObserver);
lenient().when(connectionObserver.onSecurityHandshake()).thenReturn(securityHandshakeObserver);
lenient().when(connectionObserver.onSecurityHandshake(any(SslConfig.class)))
.thenReturn(securityHandshakeObserver);
lenient().when(connectionObserver.connectionEstablished(any())).thenReturn(NoopDataObserver.INSTANCE);
lenient().when(connectionObserver.multiplexedConnectionEstablished(any()))
.thenReturn(NoopMultiplexedObserver.INSTANCE);
Expand Down Expand Up @@ -230,7 +232,7 @@ private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expec
assertTargetAddress();

verifyProxyConnectComplete();
order.verify(connectionObserver).onSecurityHandshake();
order.verify(connectionObserver).onSecurityHandshake(any(SslConfig.class));
order.verify(securityHandshakeObserver).handshakeComplete(any());
if (expectedVersion.major() > 1) {
order.verify(connectionObserver).multiplexedConnectionEstablished(any());
Expand Down Expand Up @@ -398,7 +400,7 @@ void testHandshakeFailed(List<HttpProtocol> protocols) throws Exception {
assertTargetAddress();

verifyProxyConnectComplete();
order.verify(connectionObserver).onSecurityHandshake();
order.verify(connectionObserver).onSecurityHandshake(any(SslConfig.class));
order.verify(securityHandshakeObserver).handshakeFailed(e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.servicetalk.transport.api.ConnectionObserver.SecurityHandshakeObserver;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.ExecutionContextExtension;
import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopDataObserver;
Expand Down Expand Up @@ -94,7 +95,8 @@ class SecurityHandshakeObserverTest {
clientConnectionObserver = mock(ConnectionObserver.class, "clientConnectionObserver");
clientSecurityHandshakeObserver = mock(SecurityHandshakeObserver.class, "clientSecurityHandshakeObserver");
when(clientTransportObserver.onNewConnection(any(), any())).thenReturn(clientConnectionObserver);
when(clientConnectionObserver.onSecurityHandshake()).thenReturn(clientSecurityHandshakeObserver);
when(clientConnectionObserver.onSecurityHandshake(any(SslConfig.class)))
.thenReturn(clientSecurityHandshakeObserver);
when(clientConnectionObserver.connectionEstablished(any(ConnectionInfo.class)))
.thenReturn(NoopDataObserver.INSTANCE);
when(clientConnectionObserver.multiplexedConnectionEstablished(any(ConnectionInfo.class)))
Expand All @@ -107,7 +109,8 @@ class SecurityHandshakeObserverTest {
serverConnectionObserver = mock(ConnectionObserver.class, "serverConnectionObserver");
serverSecurityHandshakeObserver = mock(SecurityHandshakeObserver.class, "serverSecurityHandshakeObserver");
when(serverTransportObserver.onNewConnection(any(), any())).thenReturn(serverConnectionObserver);
when(serverConnectionObserver.onSecurityHandshake()).thenReturn(serverSecurityHandshakeObserver);
when(serverConnectionObserver.onSecurityHandshake(any(SslConfig.class)))
.thenReturn(serverSecurityHandshakeObserver);
when(serverConnectionObserver.connectionEstablished(any(ConnectionInfo.class)))
.thenReturn(NoopDataObserver.INSTANCE);
when(serverConnectionObserver.multiplexedConnectionEstablished(any(ConnectionInfo.class)))
Expand Down Expand Up @@ -177,7 +180,7 @@ void optionalSslWithPlaintextDoesNotTriggerSecurityHandshake() throws Exception
assertThat(client.request(client.get(SVC_ECHO)).status(), is(OK));
}

verify(serverConnectionObserver, never()).onSecurityHandshake();
verify(serverConnectionObserver, never()).onSecurityHandshake(any(SslConfig.class));
verifyNoMoreInteractions(serverSecurityHandshakeObserver);
}

Expand Down Expand Up @@ -221,7 +224,7 @@ private static void verifyObservers(InOrder order, TransportObserver transportOb
HttpProtocol expectedProtocol, boolean failHandshake) {
order.verify(transportObserver).onNewConnection(any(), any());
order.verify(connectionObserver).onTransportHandshakeComplete(any());
order.verify(connectionObserver).onSecurityHandshake();
order.verify(connectionObserver).onSecurityHandshake(any(SslConfig.class));
if (failHandshake) {
ArgumentCaptor<Throwable> exceptionCaptor = ArgumentCaptor.forClass(Throwable.class);
order.verify(securityHandshakeObserver).handshakeFailed(exceptionCaptor.capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.SslProvider;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.NoopTransportObserver;
Expand Down Expand Up @@ -155,7 +156,7 @@ public void onDataWrite(final int size) {
}

@Override
public SecurityHandshakeObserver onSecurityHandshake() {
public SecurityHandshakeObserver onSecurityHandshake(final SslConfig config) {
inHandshake = true;
return new SecurityHandshakeObserver() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ public TcpClientChannelInitializer(final ReadOnlyTcpClientConfig config,
channel -> new TcpConnectionInfo(channel,
// ExecutionContext can be null if users used deprecated ctor
executionContext == null ? null : channelExecutionContext(channel, executionContext),
sslConfig, config.idleTimeoutMs()),
sslConfig != null && !deferSslHandler, true));
sslConfig, config.idleTimeoutMs()), true, deferSslHandler ? null : sslConfig));
}

if (config.idleTimeoutMs() > 0L) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.servicetalk.transport.api.ConnectionInfo;
import io.servicetalk.transport.api.ConnectionObserver;
import io.servicetalk.transport.api.ExecutionContext;
import io.servicetalk.transport.api.ServerSslConfig;
import io.servicetalk.transport.netty.internal.ChannelInitializer;
import io.servicetalk.transport.netty.internal.ConnectionObserverInitializer;
import io.servicetalk.transport.netty.internal.IdleTimeoutInitializer;
Expand Down Expand Up @@ -71,12 +72,12 @@ public TcpServerChannelInitializer(final ReadOnlyTcpServerConfig config,
delegate = delegate.andThen(new TransportConfigInitializer(config.transportConfig()));

if (observer != NoopConnectionObserver.INSTANCE) {
final ServerSslConfig sslConfig = config.sslConfig();
delegate = delegate.andThen(new ConnectionObserverInitializer(observer,
channel -> new TcpConnectionInfo(channel,
// ExecutionContext can be null if users used deprecated ctor
executionContext == null ? null : channelExecutionContext(channel, executionContext),
config.sslConfig(), config.idleTimeoutMs()),
config.sslConfig() != null, false));
sslConfig, config.idleTimeoutMs()), false, sslConfig));
}

if (config.idleTimeoutMs() > 0L) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.servicetalk.transport.api.ConnectionObserver.WriteObserver;
import io.servicetalk.transport.api.ServerSslConfig;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.SslProvider;
import io.servicetalk.transport.api.TransportObserver;

Expand Down Expand Up @@ -68,7 +69,8 @@ protected AbstractTransportObserverTest() {
clientReadObserver = mock(ReadObserver.class, "clientReadObserver");
clientWriteObserver = mock(WriteObserver.class, "clientWriteObserver");
when(clientTransportObserver.onNewConnection(any(), any())).thenReturn(clientConnectionObserver);
when(clientConnectionObserver.onSecurityHandshake()).thenReturn(clientSecurityHandshakeObserver);
when(clientConnectionObserver.onSecurityHandshake(any(SslConfig.class)))
.thenReturn(clientSecurityHandshakeObserver);
when(clientConnectionObserver.connectionEstablished(any(ConnectionInfo.class))).thenReturn(clientDataObserver);
when(clientDataObserver.onNewRead()).thenReturn(clientReadObserver);
when(clientDataObserver.onNewWrite()).thenReturn(clientWriteObserver);
Expand All @@ -80,7 +82,8 @@ protected AbstractTransportObserverTest() {
serverReadObserver = mock(ReadObserver.class, "serverReadObserver");
serverWriteObserver = mock(WriteObserver.class, "serverWriteObserver");
when(serverTransportObserver.onNewConnection(any(), any())).thenReturn(serverConnectionObserver);
when(serverConnectionObserver.onSecurityHandshake()).thenReturn(serverSecurityHandshakeObserver);
when(serverConnectionObserver.onSecurityHandshake(any(SslConfig.class)))
.thenReturn(serverSecurityHandshakeObserver);
when(serverConnectionObserver.connectionEstablished(any(ConnectionInfo.class))).thenReturn(serverDataObserver);
when(serverDataObserver.onNewRead()).thenReturn(serverReadObserver);
when(serverDataObserver.onNewWrite()).thenReturn(serverWriteObserver);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.SslProvider;
import io.servicetalk.transport.netty.internal.NettyConnection;

Expand Down Expand Up @@ -178,7 +179,7 @@ void testSslErrors(ErrorReason errorReason,
verify(serverConnectionObserver, await()).onTransportHandshakeComplete(any());
switch (errorReason) {
case SECURE_CLIENT_TO_PLAIN_SERVER:
verify(clientConnectionObserver, await()).onSecurityHandshake();
verify(clientConnectionObserver, await()).onSecurityHandshake(any(SslConfig.class));
if (clientProvider == JDK) {
verify(clientSecurityHandshakeObserver, await()).handshakeFailed(any(SSLProtocolException.class));
verify(clientConnectionObserver, await()).connectionClosed(any(SSLProtocolException.class));
Expand All @@ -189,7 +190,7 @@ void testSslErrors(ErrorReason errorReason,
serverConnectionClosed.await();
break;
case PLAIN_CLIENT_TO_SECURE_SERVER:
verify(serverConnectionObserver, await()).onSecurityHandshake();
verify(serverConnectionObserver, await()).onSecurityHandshake(any(SslConfig.class));
clientConnected.await();
connection.get().write(from(DEFAULT_ALLOCATOR.fromAscii("Hello"))).toFuture().get();
verify(serverSecurityHandshakeObserver, await()).handshakeFailed(any(NotSslRecordException.class));
Expand All @@ -202,8 +203,8 @@ void testSslErrors(ErrorReason errorReason,
case MISSED_CLIENT_CERTIFICATE:
case NOT_MATCHING_PROTOCOLS:
case NOT_MATCHING_CIPHERS:
verify(clientConnectionObserver, await()).onSecurityHandshake();
verify(serverConnectionObserver, await()).onSecurityHandshake();
verify(clientConnectionObserver, await()).onSecurityHandshake(any(SslConfig.class));
verify(serverConnectionObserver, await()).onSecurityHandshake(any(SslConfig.class));
verify(clientSecurityHandshakeObserver, await()).handshakeFailed(any(SSLException.class));
verify(clientConnectionObserver, await()).connectionClosed(any(SSLException.class));
verify(serverSecurityHandshakeObserver, await()).handshakeFailed(any(SSLException.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import io.servicetalk.buffer.api.Buffer;
import io.servicetalk.transport.api.ConnectionInfo;
import io.servicetalk.transport.api.SslConfig;
import io.servicetalk.transport.api.SslProvider;
import io.servicetalk.transport.netty.internal.NettyConnection;

Expand Down Expand Up @@ -82,8 +83,8 @@ void testConnectionObserverEvents(SslProvider clientProvider, SslProvider server
verify(serverConnectionObserver, await()).connectionEstablished(any(ConnectionInfo.class));

// handshake starts
verify(clientConnectionObserver).onSecurityHandshake();
verify(serverConnectionObserver).onSecurityHandshake();
verify(clientConnectionObserver).onSecurityHandshake(any(SslConfig.class));
verify(serverConnectionObserver).onSecurityHandshake(any(SslConfig.class));
// handshake progress
verify(clientConnectionObserver, atLeastOnce()).onDataRead(anyInt());
verify(clientConnectionObserver, atLeastOnce()).onDataWrite(anyInt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,17 @@ public ProxyConnectObserver onProxyConnect(final Object connectMsg) {
}

@Override
@SuppressWarnings("deprecation")
public SecurityHandshakeObserver onSecurityHandshake() {
return new BiSecurityHandshakeObserver(first.onSecurityHandshake(), second.onSecurityHandshake());
}

@Override
public SecurityHandshakeObserver onSecurityHandshake(final SslConfig sslConfig) {
return new BiSecurityHandshakeObserver(first.onSecurityHandshake(sslConfig),
second.onSecurityHandshake(sslConfig));
}

@Override
public DataObserver connectionEstablished(final ConnectionInfo info) {
return new BiDataObserver(first.connectionEstablished(info), second.connectionEstablished(info));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,18 @@ public ProxyConnectObserver onProxyConnect(final Object connectMsg) {
}

@Override
@SuppressWarnings("deprecation")
public SecurityHandshakeObserver onSecurityHandshake() {
return safeReport(observer::onSecurityHandshake, observer, "security handshake",
CatchAllSecurityHandshakeObserver::new, NoopSecurityHandshakeObserver.INSTANCE);
}

@Override
public SecurityHandshakeObserver onSecurityHandshake(final SslConfig sslConfig) {
return safeReport(() -> observer.onSecurityHandshake(sslConfig), observer, "security handshake",
CatchAllSecurityHandshakeObserver::new, NoopSecurityHandshakeObserver.INSTANCE);
}

@Override
public DataObserver connectionEstablished(final ConnectionInfo info) {
return safeReport(() -> observer.connectionEstablished(info), observer, "connection established",
Expand Down
Loading

0 comments on commit 94c1bbb

Please sign in to comment.