Skip to content

4.1 add routing context to hello message #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,18 @@ public final Driver newInstance ( URI uri, AuthToken authToken, RoutingSettings
RetryLogic retryLogic = createRetryLogic( retrySettings, eventExecutorGroup, config.logging() );

MetricsProvider metricsProvider = createDriverMetrics( config, createClock() );
ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, bootstrap, metricsProvider, config, ownsEventLoopGroup );
ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, bootstrap, metricsProvider, config,
ownsEventLoopGroup, newRoutingSettings.routingContext() );

return createDriver( uri, securityPlan, address, connectionPool, eventExecutorGroup, newRoutingSettings, retryLogic, metricsProvider, config );
}

protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsProvider metricsProvider, Config config, boolean ownsEventLoopGroup )
MetricsProvider metricsProvider, Config config, boolean ownsEventLoopGroup, RoutingContext routingContext )
{
Clock clock = createClock();
ConnectionSettings settings = new ConnectionSettings( authToken, config.connectionTimeoutMillis() );
ChannelConnector connector = createConnector( settings, securityPlan, config, clock );
ChannelConnector connector = createConnector( settings, securityPlan, config, clock, routingContext );
PoolSettings poolSettings = new PoolSettings( config.maxConnectionPoolSize(),
config.connectionAcquisitionTimeoutMillis(), config.maxConnectionLifetimeMillis(),
config.idleTimeBeforeConnectionTest()
Expand All @@ -124,9 +125,9 @@ protected static MetricsProvider createDriverMetrics( Config config, Clock clock
}

protected ChannelConnector createConnector( ConnectionSettings settings, SecurityPlan securityPlan,
Config config, Clock clock )
Config config, Clock clock, RoutingContext routingContext )
{
return new ChannelConnectorImpl( settings, securityPlan, config.logging(), clock );
return new ChannelConnectorImpl( settings, securityPlan, config.logging(), clock, routingContext );
}

private InternalDriver createDriver( URI uri, SecurityPlan securityPlan, BoltServerAddress address, ConnectionPool connectionPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.ConnectionSettings;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.Clock;
Expand All @@ -44,24 +45,26 @@
public class ChannelConnectorImpl implements ChannelConnector
{
private final String userAgent;
private final Map<String,Value> authToken;
private final AuthToken authToken;
private final RoutingContext routingContext;
private final SecurityPlan securityPlan;
private final ChannelPipelineBuilder pipelineBuilder;
private final int connectTimeoutMillis;
private final Logging logging;
private final Clock clock;

public ChannelConnectorImpl( ConnectionSettings connectionSettings, SecurityPlan securityPlan, Logging logging,
Clock clock )
Clock clock, RoutingContext routingContext )
{
this( connectionSettings, securityPlan, new ChannelPipelineBuilderImpl(), logging, clock );
this( connectionSettings, securityPlan, new ChannelPipelineBuilderImpl(), logging, clock, routingContext );
}

public ChannelConnectorImpl( ConnectionSettings connectionSettings, SecurityPlan securityPlan,
ChannelPipelineBuilder pipelineBuilder, Logging logging, Clock clock )
ChannelPipelineBuilder pipelineBuilder, Logging logging, Clock clock, RoutingContext routingContext )
{
this.userAgent = connectionSettings.userAgent();
this.authToken = tokenAsMap( connectionSettings.authToken() );
this.authToken = requireValidAuthToken( connectionSettings.authToken() );
this.routingContext = routingContext;
this.connectTimeoutMillis = connectionSettings.connectTimeoutMillis();
this.securityPlan = requireNonNull( securityPlan );
this.pipelineBuilder = pipelineBuilder;
Expand Down Expand Up @@ -113,14 +116,14 @@ private void installHandshakeCompletedListeners( ChannelPromise handshakeComplet

// add listener that sends an INIT message. connection is now fully established. channel pipeline if fully
// set to send/receive messages for a selected protocol version
handshakeCompleted.addListener( new HandshakeCompletedListener( userAgent, authToken, connectionInitialized ) );
handshakeCompleted.addListener( new HandshakeCompletedListener( userAgent, authToken, routingContext, connectionInitialized ) );
}

private static Map<String,Value> tokenAsMap( AuthToken token )
private static AuthToken requireValidAuthToken( AuthToken token )
{
if ( token instanceof InternalAuthToken )
{
return ((InternalAuthToken) token).toMap();
return token;
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import java.util.Map;

import org.neo4j.driver.AuthToken;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.Value;

Expand All @@ -32,14 +34,16 @@
public class HandshakeCompletedListener implements ChannelFutureListener
{
private final String userAgent;
private final Map<String,Value> authToken;
private final AuthToken authToken;
private final RoutingContext routingContext;
private final ChannelPromise connectionInitializedPromise;

public HandshakeCompletedListener( String userAgent, Map<String,Value> authToken,
ChannelPromise connectionInitializedPromise )
public HandshakeCompletedListener( String userAgent, AuthToken authToken,
RoutingContext routingContext, ChannelPromise connectionInitializedPromise )
{
this.userAgent = requireNonNull( userAgent );
this.authToken = requireNonNull( authToken );
this.routingContext = routingContext;
this.connectionInitializedPromise = requireNonNull( connectionInitializedPromise );
}

Expand All @@ -49,7 +53,7 @@ public void operationComplete( ChannelFuture future )
if ( future.isSuccess() )
{
BoltProtocol protocol = BoltProtocol.forChannel( future.channel() );
protocol.initializeChannel( userAgent, authToken, connectionInitializedPromise );
protocol.initializeChannel( userAgent, authToken, routingContext, connectionInitializedPromise );
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ BookmarkHolder bookmarkHolder( Bookmark bookmark )
Query procedureQuery(ServerVersion serverVersion, DatabaseName databaseName )
{
HashMap<String,Value> map = new HashMap<>();
map.put( ROUTING_CONTEXT, value( context.asMap() ) );
map.put( ROUTING_CONTEXT, value( context.toMap() ) );
map.put( DATABASE_NAME, value( (Object) databaseName.databaseName().orElse( null ) ) );
return new Query( MULTI_DB_GET_ROUTING_TABLE, value( map ) );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.Scheme;

import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;
Expand All @@ -33,14 +34,17 @@ public class RoutingContext
private static final String ROUTING_ADDRESS_KEY = "address";

private final Map<String,String> context;
private final boolean isServerRoutingEnabled;

private RoutingContext()
{
this.isServerRoutingEnabled = true;
this.context = emptyMap();
}

public RoutingContext( URI uri )
{
this.isServerRoutingEnabled = Scheme.isRoutingScheme( uri.getScheme() );
this.context = unmodifiableMap( parseParameters( uri ) );
}

Expand All @@ -49,15 +53,20 @@ public boolean isDefined()
return context.size() > 1;
}

public Map<String,String> asMap()
public Map<String,String> toMap()
{
return context;
}

public boolean isServerRoutingEnabled()
{
return isServerRoutingEnabled;
}

@Override
public String toString()
{
return "RoutingContext" + context;
return "RoutingContext" + context + " isServerRoutingEnabled=" + isServerRoutingEnabled;
}

private static Map<String,String> parseParameters( URI uri )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Query procedureQuery(ServerVersion serverVersion, DatabaseName databaseName )
"Refreshing routing table for multi-databases is not supported in server version lower than 4.0. " +
"Current server version: %s. Database name: '%s'", serverVersion, databaseName.description() ) );
}
return new Query( GET_ROUTING_TABLE, parameters( ROUTING_CONTEXT, context.asMap() ) );
return new Query( GET_ROUTING_TABLE, parameters( ROUTING_CONTEXT, context.toMap() ) );
}

BookmarkHolder bookmarkHolder( Bookmark ignored )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.concurrent.CompletionStage;

import org.neo4j.driver.AuthToken;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.Session;
Expand All @@ -34,6 +35,7 @@
import org.neo4j.driver.internal.BookmarkHolder;
import org.neo4j.driver.internal.InternalBookmark;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cursor.ResultCursorFactory;
import org.neo4j.driver.internal.messaging.v1.BoltProtocolV1;
import org.neo4j.driver.internal.messaging.v2.BoltProtocolV2;
Expand All @@ -58,9 +60,10 @@ public interface BoltProtocol
*
* @param userAgent the user agent string.
* @param authToken the authentication token.
* @param routingContext the configured routing context
* @param channelInitializedPromise the promise to be notified when initialization is completed.
*/
void initializeChannel( String userAgent, Map<String,Value> authToken, ChannelPromise channelInitializedPromise );
void initializeChannel( String userAgent, AuthToken authToken, RoutingContext routingContext, ChannelPromise channelInitializedPromise );

/**
* Prepare to close channel before it is closed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ public class HelloMessage extends MessageWithMetadata
public final static byte SIGNATURE = 0x01;

private static final String USER_AGENT_METADATA_KEY = "user_agent";
private static final String ROUTING_CONTEXT_METADATA_KEY = "routing";

public HelloMessage( String userAgent, Map<String,Value> authToken )
public HelloMessage( String userAgent, Map<String,Value> authToken, Map<String,String> routingContext )
{
super( buildMetadata( userAgent, authToken ) );
super( buildMetadata( userAgent, authToken, routingContext ) );
}

@Override
Expand Down Expand Up @@ -73,10 +74,11 @@ public String toString()
return "HELLO " + metadataCopy;
}

private static Map<String,Value> buildMetadata( String userAgent, Map<String,Value> authToken )
private static Map<String,Value> buildMetadata( String userAgent, Map<String,Value> authToken, Map<String,String> routingContext )
{
Map<String,Value> result = new HashMap<>( authToken );
result.put( USER_AGENT_METADATA_KEY, value( userAgent ) );
result.put( ROUTING_CONTEXT_METADATA_KEY, value( routingContext ) );
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

import org.neo4j.driver.AuthToken;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.TransactionConfig;
Expand All @@ -35,6 +36,7 @@
import org.neo4j.driver.internal.DatabaseName;
import org.neo4j.driver.internal.InternalBookmark;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cursor.AsyncResultCursorOnlyFactory;
import org.neo4j.driver.internal.cursor.ResultCursorFactory;
import org.neo4j.driver.internal.handlers.BeginTxResponseHandler;
Expand All @@ -52,6 +54,7 @@
import org.neo4j.driver.internal.messaging.request.InitMessage;
import org.neo4j.driver.internal.messaging.request.PullAllMessage;
import org.neo4j.driver.internal.messaging.request.RunMessage;
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ResponseHandler;
import org.neo4j.driver.internal.util.Futures;
Expand Down Expand Up @@ -84,11 +87,12 @@ public MessageFormat createMessageFormat()
}

@Override
public void initializeChannel( String userAgent, Map<String,Value> authToken, ChannelPromise channelInitializedPromise )
public void initializeChannel( String userAgent, AuthToken authToken, RoutingContext routingContext,
ChannelPromise channelInitializedPromise )
{
Channel channel = channelInitializedPromise.channel();

InitMessage message = new InitMessage( userAgent, authToken );
InitMessage message = new InitMessage( userAgent, ((InternalAuthToken) authToken).toMap() );
InitResponseHandler handler = new InitResponseHandler( channelInitializedPromise );

messageDispatcher( channel ).enqueue( handler );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

import org.neo4j.driver.AuthToken;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.TransactionConfig;
import org.neo4j.driver.Value;
import org.neo4j.driver.internal.BookmarkHolder;
import org.neo4j.driver.internal.DatabaseName;
import org.neo4j.driver.internal.InternalBookmark;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cursor.AsyncResultCursorOnlyFactory;
import org.neo4j.driver.internal.cursor.ResultCursorFactory;
import org.neo4j.driver.internal.handlers.BeginTxResponseHandler;
Expand All @@ -49,6 +50,7 @@
import org.neo4j.driver.internal.messaging.request.GoodbyeMessage;
import org.neo4j.driver.internal.messaging.request.HelloMessage;
import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage;
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.internal.util.MetadataExtractor;
Expand Down Expand Up @@ -76,11 +78,20 @@ public MessageFormat createMessageFormat()
}

@Override
public void initializeChannel( String userAgent, Map<String,Value> authToken, ChannelPromise channelInitializedPromise )
public void initializeChannel( String userAgent, AuthToken authToken, RoutingContext routingContext, ChannelPromise channelInitializedPromise )
{
Channel channel = channelInitializedPromise.channel();
HelloMessage message;

if ( routingContext.isServerRoutingEnabled() )
{
message = new HelloMessage( userAgent, ( ( InternalAuthToken ) authToken ).toMap(), routingContext.toMap() );
}
else
{
message = new HelloMessage( userAgent, ( ( InternalAuthToken ) authToken ).toMap(), null );
}

HelloMessage message = new HelloMessage( userAgent, authToken );
HelloResponseHandler handler = new HelloResponseHandler( channelInitializedPromise, version() );

messageDispatcher( channel ).enqueue( handler );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.neo4j.driver.internal.async.connection.ChannelConnector;
import org.neo4j.driver.internal.async.connection.ChannelConnectorImpl;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.security.SecurityPlanImpl;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.FakeClock;
Expand Down Expand Up @@ -231,7 +232,7 @@ private ChannelConnectorImpl newConnector( AuthToken authToken, SecurityPlan sec
int connectTimeoutMillis )
{
ConnectionSettings settings = new ConnectionSettings( authToken, connectTimeoutMillis );
return new ChannelConnectorImpl( settings, securityPlan, DEV_NULL_LOGGING, new FakeClock() );
return new ChannelConnectorImpl( settings, securityPlan, DEV_NULL_LOGGING, new FakeClock(), RoutingContext.EMPTY );
}

private static SecurityPlan trustAllCertificates() throws GeneralSecurityException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.neo4j.driver.internal.async.connection.ChannelConnector;
import org.neo4j.driver.internal.async.pool.ConnectionPoolImpl;
import org.neo4j.driver.internal.async.pool.PoolSettings;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.metrics.MetricsProvider;
import org.neo4j.driver.internal.retry.RetrySettings;
Expand Down Expand Up @@ -445,14 +446,15 @@ private static class DriverFactoryWithConnectionPool extends DriverFactory

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsProvider ignored, Config config, boolean ownsEventLoopGroup )
MetricsProvider ignored, Config config, boolean ownsEventLoopGroup,
RoutingContext routingContext )
{
ConnectionSettings connectionSettings = new ConnectionSettings( authToken, 1000 );
PoolSettings poolSettings = new PoolSettings( config.maxConnectionPoolSize(),
config.connectionAcquisitionTimeoutMillis(), config.maxConnectionLifetimeMillis(),
config.idleTimeBeforeConnectionTest() );
Clock clock = createClock();
ChannelConnector connector = super.createConnector( connectionSettings, securityPlan, config, clock );
ChannelConnector connector = super.createConnector( connectionSettings, securityPlan, config, clock, routingContext );
connectionPool = new MemorizingConnectionPool( connector, bootstrap, poolSettings, config.logging(), clock, ownsEventLoopGroup );
return connectionPool;
}
Expand Down
Loading