Skip to content

Fail correctly when connecting to a server with an unknown server identifier #542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@

import static org.neo4j.driver.internal.async.ChannelAttributes.setConnectionId;
import static org.neo4j.driver.internal.async.ChannelAttributes.setServerVersion;
import static org.neo4j.driver.internal.util.MetadataExtractor.extractNeo4jServerVersion;

public class HelloResponseHandler implements ResponseHandler
{
private static final String SERVER_METADATA_KEY = "server";
private static final String CONNECTION_ID_METADATA_KEY = "connection_id";

private final ChannelPromise connectionInitializedPromise;
Expand All @@ -49,10 +49,10 @@ public void onSuccess( Map<String,Value> metadata )
{
try
{
ServerVersion serverVersion = ServerVersion.version( extractMetadataValue( SERVER_METADATA_KEY, metadata ) );
ServerVersion serverVersion = extractNeo4jServerVersion( metadata );
setServerVersion( channel, serverVersion );

String connectionId = extractMetadataValue( CONNECTION_ID_METADATA_KEY, metadata );
String connectionId = extractConnectionId( metadata );
setConnectionId( channel, connectionId );

connectionInitializedPromise.setSuccess();
Expand All @@ -76,12 +76,13 @@ public void onRecord( Value[] fields )
throw new UnsupportedOperationException();
}

private static String extractMetadataValue( String key, Map<String,Value> metadata )
private static String extractConnectionId( Map<String,Value> metadata )
{
Value value = metadata.get( key );
Value value = metadata.get( CONNECTION_ID_METADATA_KEY );
if ( value == null || value.isNull() )
{
throw new IllegalStateException( "Unable to extract " + key + " from a response to HELLO message. Metadata: " + metadata );
throw new IllegalStateException( "Unable to extract " + CONNECTION_ID_METADATA_KEY + " from a response to HELLO message. " +
"Received metadata: " + metadata );
}
return value.asString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.driver.v1.Value;

import static org.neo4j.driver.internal.async.ChannelAttributes.setServerVersion;
import static org.neo4j.driver.internal.util.MetadataExtractor.extractNeo4jServerVersion;

public class InitResponseHandler implements ResponseHandler
{
Expand All @@ -47,7 +48,7 @@ public void onSuccess( Map<String,Value> metadata )
{
try
{
ServerVersion serverVersion = extractServerVersion( metadata );
ServerVersion serverVersion = extractNeo4jServerVersion( metadata );
setServerVersion( channel, serverVersion );
updatePipelineIfNeeded( serverVersion, channel.pipeline() );
connectionInitializedPromise.setSuccess();
Expand All @@ -71,13 +72,6 @@ public void onRecord( Value[] fields )
throw new UnsupportedOperationException();
}

private static ServerVersion extractServerVersion( Map<String,Value> metadata )
{
Value versionValue = metadata.get( "server" );
boolean versionAbsent = versionValue == null || versionValue.isNull();
return versionAbsent ? ServerVersion.v3_0_0 : ServerVersion.version( versionValue.asString() );
}

private static void updatePipelineIfNeeded( ServerVersion serverVersion, ChannelPipeline pipeline )
{
if ( serverVersion.lessThan( ServerVersion.v3_2_0 ) )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.neo4j.driver.internal.summary.InternalSummaryCounters;
import org.neo4j.driver.v1.Statement;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.exceptions.UntrustedServerException;
import org.neo4j.driver.v1.summary.Notification;
import org.neo4j.driver.v1.summary.Plan;
import org.neo4j.driver.v1.summary.ProfiledPlan;
Expand Down Expand Up @@ -101,6 +102,27 @@ public Bookmarks extractBookmarks( Map<String,Value> metadata )
return Bookmarks.empty();
}

public static ServerVersion extractNeo4jServerVersion( Map<String,Value> metadata )
{
Value versionValue = metadata.get( "server" );
if ( versionValue == null || versionValue.isNull() )
{
throw new UntrustedServerException( "Server provides no product identifier" );
}
else
{
ServerVersion server = ServerVersion.version( versionValue.asString() );
if ( ServerVersion.NEO4J_PRODUCT.equalsIgnoreCase( server.product() ) )
{
return server;
}
else
{
throw new UntrustedServerException( "Server does not identify as a genuine Neo4j instance: '" + server.product() + "'" );
}
}
}

private static StatementType extractStatementType( Map<String,Value> metadata )
{
Value typeValue = metadata.get( "type" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.neo4j.driver.internal.util;

import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -28,28 +29,37 @@

public class ServerVersion
{
public static final ServerVersion v3_5_0 = new ServerVersion( 3, 5, 0 );
public static final ServerVersion v3_4_0 = new ServerVersion( 3, 4, 0 );
public static final ServerVersion v3_2_0 = new ServerVersion( 3, 2, 0 );
public static final ServerVersion v3_1_0 = new ServerVersion( 3, 1, 0 );
public static final ServerVersion v3_0_0 = new ServerVersion( 3, 0, 0 );
public static final ServerVersion vInDev = new ServerVersion( Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE );

private static final String NEO4J_IN_DEV_VERSION_STRING = "Neo4j/dev";
public static final String NEO4J_PRODUCT = "Neo4j";

public static final ServerVersion v3_5_0 = new ServerVersion( NEO4J_PRODUCT, 3, 5, 0 );
public static final ServerVersion v3_4_0 = new ServerVersion( NEO4J_PRODUCT, 3, 4, 0 );
public static final ServerVersion v3_2_0 = new ServerVersion( NEO4J_PRODUCT, 3, 2, 0 );
public static final ServerVersion v3_1_0 = new ServerVersion( NEO4J_PRODUCT, 3, 1, 0 );
public static final ServerVersion v3_0_0 = new ServerVersion( NEO4J_PRODUCT, 3, 0, 0 );
public static final ServerVersion vInDev = new ServerVersion( NEO4J_PRODUCT, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE );

private static final String NEO4J_IN_DEV_VERSION_STRING = NEO4J_PRODUCT + "/dev";
private static final Pattern PATTERN =
Pattern.compile( "(Neo4j/)?(\\d+)\\.(\\d+)(?:\\.)?(\\d*)(\\.|-|\\+)?([0-9A-Za-z-.]*)?" );
Pattern.compile( "([^/]+)/(\\d+)\\.(\\d+)(?:\\.)?(\\d*)(\\.|-|\\+)?([0-9A-Za-z-.]*)?" );

private final String product;
private final int major;
private final int minor;
private final int patch;
private final String stringValue;

private ServerVersion( int major, int minor, int patch )
private ServerVersion( String product, int major, int minor, int patch )
{
this.product = product;
this.major = major;
this.minor = minor;
this.patch = patch;
this.stringValue = stringValue( major, minor, patch );
this.stringValue = stringValue( product, major, minor, patch );
}

public String product()
{
return product;
}

public static ServerVersion version( Driver driver )
Expand All @@ -63,33 +73,27 @@ public static ServerVersion version( Driver driver )

public static ServerVersion version( String server )
{
if ( server == null )
Matcher matcher = PATTERN.matcher( server );
if ( matcher.matches() )
{
String product = matcher.group( 1 );
int major = Integer.valueOf( matcher.group( 2 ) );
int minor = Integer.valueOf( matcher.group( 3 ) );
String patchString = matcher.group( 4 );
int patch = 0;
if ( patchString != null && !patchString.isEmpty() )
{
patch = Integer.valueOf( patchString );
}
return new ServerVersion( product, major, minor, patch );
}
else if ( server.equalsIgnoreCase( NEO4J_IN_DEV_VERSION_STRING ) )
{
return v3_0_0;
return vInDev;
}
else
{
Matcher matcher = PATTERN.matcher( server );
if ( matcher.matches() )
{
int major = Integer.valueOf( matcher.group( 2 ) );
int minor = Integer.valueOf( matcher.group( 3 ) );
String patchString = matcher.group( 4 );
int patch = 0;
if ( patchString != null && !patchString.isEmpty() )
{
patch = Integer.valueOf( patchString );
}
return new ServerVersion( major, minor, patch );
}
else if ( server.equalsIgnoreCase( NEO4J_IN_DEV_VERSION_STRING ) )
{
return vInDev;
}
else
{
throw new IllegalArgumentException( "Cannot parse " + server );
}
throw new IllegalArgumentException( "Cannot parse " + server );
}
}

Expand All @@ -103,6 +107,8 @@ public boolean equals( Object o )

ServerVersion that = (ServerVersion) o;

if ( !product.equals( that.product ) )
{ return false; }
if ( major != that.major )
{ return false; }
if ( minor != that.minor )
Expand All @@ -113,10 +119,7 @@ public boolean equals( Object o )
@Override
public int hashCode()
{
int result = major;
result = 31 * result + minor;
result = 31 * result + patch;
return result;
return Objects.hash(product, major, minor, patch);
}

public boolean greaterThan(ServerVersion other)
Expand All @@ -141,6 +144,10 @@ public boolean lessThanOrEqual(ServerVersion other)

private int compareTo( ServerVersion o )
{
if ( !product.equals( o.product ) )
{
throw new IllegalArgumentException( "Comparing different products '" + product + "' with '" + o.product + "'" );
}
int c = compare( major, o.major );
if (c == 0)
{
Expand All @@ -160,12 +167,12 @@ public String toString()
return stringValue;
}

private static String stringValue( int major, int minor, int patch )
private static String stringValue( String product, int major, int minor, int patch )
{
if ( major == Integer.MAX_VALUE && minor == Integer.MAX_VALUE && patch == Integer.MAX_VALUE )
{
return NEO4J_IN_DEV_VERSION_STRING;
}
return String.format( "Neo4j/%s.%s.%s", major, minor, patch );
return String.format( "%s/%s.%s.%s", product, major, minor, patch );
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2002-2018 "Neo4j,"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.neo4j.driver.v1.exceptions;

/**
* Thrown if the remote server cannot be verified as Neo4j.
*/
public class UntrustedServerException extends RuntimeException
{
public UntrustedServerException(String message)
{
super(message);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2002-2018 "Neo4j,"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.neo4j.driver.internal;

import org.junit.jupiter.api.Test;

import org.neo4j.driver.v1.Config;
import org.neo4j.driver.v1.GraphDatabase;
import org.neo4j.driver.v1.exceptions.UntrustedServerException;
import org.neo4j.driver.v1.util.StubServer;

import static org.hamcrest.core.IsEqual.equalTo;
import static org.hamcrest.junit.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.neo4j.driver.v1.Logging.none;

class TrustedServerProductTest
{
private static final Config config = Config.build()
.withoutEncryption()
.withLogging( none() )
.toConfig();

@Test
void shouldRejectConnectionsToNonNeo4jServers() throws Exception
{
StubServer server = StubServer.start( "untrusted_server.script", 9001 );
assertThrows( UntrustedServerException.class, () -> GraphDatabase.driver( "bolt://127.0.0.1:9001", config ));
assertThat( server.exitStatus(), equalTo( 0 ) );
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,17 @@ void shouldFailToSetMessageDispatcherTwice()
@Test
void shouldSetAndGetServerVersion()
{
ServerVersion version = version( "3.2.1" );
ServerVersion version = version( "Neo4j/3.2.1" );
setServerVersion( channel, version );
assertEquals( version, serverVersion( channel ) );
}

@Test
void shouldFailToSetServerVersionTwice()
{
setServerVersion( channel, version( "3.2.2" ) );
setServerVersion( channel, version( "Neo4j/3.2.2" ) );

assertThrows( IllegalStateException.class, () -> setServerVersion( channel, version( "3.2.3" ) ) );
assertThrows( IllegalStateException.class, () -> setServerVersion( channel, version( "Neo4j/3.2.3" ) ) );
}

@Test
Expand Down
Loading