Skip to content

Commit 24e6af1

Browse files
authored
Merge pull request #373 from lutovich/1.3-direct-driver-auth-errors
Fail early on auth errors in direct driver
2 parents 4116f53 + a81b238 commit 24e6af1

File tree

15 files changed

+293
-146
lines changed

15 files changed

+293
-146
lines changed

driver/src/main/java/org/neo4j/driver/internal/DirectConnectionProvider.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ public class DirectConnectionProvider implements ConnectionProvider
3737
{
3838
this.address = address;
3939
this.pool = pool;
40+
41+
verifyConnectivity();
4042
}
4143

4244
@Override
@@ -55,4 +57,13 @@ public BoltServerAddress getAddress()
5557
{
5658
return address;
5759
}
60+
61+
/**
62+
* Acquires and releases a connection to verify connectivity so this connection provider fails fast. This is
63+
* especially valuable when driver was created with incorrect credentials.
64+
*/
65+
private void verifyConnectivity()
66+
{
67+
acquireConnection( AccessMode.READ ).close();
68+
}
5869
}

driver/src/test/java/org/neo4j/driver/internal/DirectConnectionProviderTest.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@
2525
import org.neo4j.driver.internal.spi.PooledConnection;
2626

2727
import static org.junit.Assert.assertEquals;
28+
import static org.junit.Assert.assertNotNull;
2829
import static org.junit.Assert.assertSame;
30+
import static org.junit.Assert.fail;
2931
import static org.mockito.Matchers.any;
32+
import static org.mockito.Mockito.RETURNS_MOCKS;
33+
import static org.mockito.Mockito.doThrow;
3034
import static org.mockito.Mockito.mock;
31-
import static org.mockito.Mockito.only;
3235
import static org.mockito.Mockito.verify;
3336
import static org.mockito.Mockito.when;
3437
import static org.neo4j.driver.v1.AccessMode.READ;
@@ -42,7 +45,7 @@ public void acquiresConnectionsFromThePool()
4245
ConnectionPool pool = mock( ConnectionPool.class );
4346
PooledConnection connection1 = mock( PooledConnection.class );
4447
PooledConnection connection2 = mock( PooledConnection.class );
45-
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection1 ).thenReturn( connection2 );
48+
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection1, connection1, connection2 );
4649

4750
DirectConnectionProvider provider = newConnectionProvider( pool );
4851

@@ -53,12 +56,12 @@ public void acquiresConnectionsFromThePool()
5356
@Test
5457
public void closesPool() throws Exception
5558
{
56-
ConnectionPool pool = mock( ConnectionPool.class );
59+
ConnectionPool pool = mock( ConnectionPool.class, RETURNS_MOCKS );
5760
DirectConnectionProvider provider = newConnectionProvider( pool );
5861

5962
provider.close();
6063

61-
verify( pool, only() ).close();
64+
verify( pool ).close();
6265
}
6366

6467
@Test
@@ -71,9 +74,42 @@ public void returnsCorrectAddress()
7174
assertEquals( address, provider.getAddress() );
7275
}
7376

77+
@Test
78+
public void testsConnectivityOnCreation()
79+
{
80+
ConnectionPool pool = mock( ConnectionPool.class );
81+
PooledConnection connection = mock( PooledConnection.class );
82+
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection );
83+
84+
assertNotNull( newConnectionProvider( pool ) );
85+
86+
verify( pool ).acquire( BoltServerAddress.LOCAL_DEFAULT );
87+
verify( connection ).close();
88+
}
89+
90+
@Test
91+
public void throwsWhenTestConnectionThrows()
92+
{
93+
ConnectionPool pool = mock( ConnectionPool.class );
94+
PooledConnection connection = mock( PooledConnection.class );
95+
RuntimeException error = new RuntimeException();
96+
doThrow( error ).when( connection ).close();
97+
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection );
98+
99+
try
100+
{
101+
newConnectionProvider( pool );
102+
fail( "Exception expected" );
103+
}
104+
catch ( Exception e )
105+
{
106+
assertSame( error, e );
107+
}
108+
}
109+
74110
private static DirectConnectionProvider newConnectionProvider( BoltServerAddress address )
75111
{
76-
return new DirectConnectionProvider( address, mock( ConnectionPool.class ) );
112+
return new DirectConnectionProvider( address, mock( ConnectionPool.class, RETURNS_MOCKS ) );
77113
}
78114

79115
private static DirectConnectionProvider newConnectionProvider( ConnectionPool pool )

driver/src/test/java/org/neo4j/driver/internal/DirectDriverTest.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919
package org.neo4j.driver.internal;
2020

21+
import org.junit.After;
22+
import org.junit.ClassRule;
2123
import org.junit.Test;
2224

2325
import java.net.URI;
@@ -28,6 +30,7 @@
2830
import org.neo4j.driver.v1.Record;
2931
import org.neo4j.driver.v1.Session;
3032
import org.neo4j.driver.v1.util.StubServer;
33+
import org.neo4j.driver.v1.util.TestNeo4j;
3134

3235
import static org.hamcrest.Matchers.is;
3336
import static org.hamcrest.core.IsEqual.equalTo;
@@ -40,14 +43,28 @@
4043

4144
public class DirectDriverTest
4245
{
46+
@ClassRule
47+
public static final TestNeo4j neo4j = new TestNeo4j();
48+
49+
private Driver driver;
50+
51+
@After
52+
public void closeDriver() throws Exception
53+
{
54+
if ( driver != null )
55+
{
56+
driver.close();
57+
}
58+
}
59+
4360
@Test
4461
public void shouldUseDefaultPortIfMissing()
4562
{
4663
// Given
4764
URI uri = URI.create( "bolt://localhost" );
4865

4966
// When
50-
Driver driver = GraphDatabase.driver( uri );
67+
driver = GraphDatabase.driver( uri, neo4j.authToken() );
5168

5269
// Then
5370
assertThat( driver, is( directDriverWithAddress( LOCAL_DEFAULT ) ) );
@@ -61,7 +78,7 @@ public void shouldAllowIPv6Address()
6178
BoltServerAddress address = BoltServerAddress.from( uri );
6279

6380
// When
64-
Driver driver = GraphDatabase.driver( uri );
81+
driver = GraphDatabase.driver( uri, neo4j.authToken() );
6582

6683
// Then
6784
assertThat( driver, is( directDriverWithAddress( address ) ) );
@@ -76,7 +93,7 @@ public void shouldRejectInvalidAddress()
7693
// When & Then
7794
try
7895
{
79-
Driver driver = GraphDatabase.driver( uri );
96+
driver = GraphDatabase.driver( uri, neo4j.authToken() );
8097
fail("Expecting error for wrong uri");
8198
}
8299
catch( IllegalArgumentException e )
@@ -93,7 +110,7 @@ public void shouldRegisterSingleServer()
93110
BoltServerAddress address = BoltServerAddress.from( uri );
94111

95112
// When
96-
Driver driver = GraphDatabase.driver( uri );
113+
driver = GraphDatabase.driver( uri, neo4j.authToken() );
97114

98115
// Then
99116
assertThat( driver, is( directDriverWithAddress( address ) ) );

driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.neo4j.driver.internal.security.SecurityPlan;
3737
import org.neo4j.driver.internal.spi.ConnectionPool;
3838
import org.neo4j.driver.internal.spi.ConnectionProvider;
39+
import org.neo4j.driver.internal.spi.PooledConnection;
3940
import org.neo4j.driver.v1.AuthToken;
4041
import org.neo4j.driver.v1.AuthTokens;
4142
import org.neo4j.driver.v1.Config;
@@ -45,9 +46,11 @@
4546
import static org.junit.Assert.assertArrayEquals;
4647
import static org.junit.Assert.assertThat;
4748
import static org.junit.Assert.fail;
49+
import static org.mockito.Mockito.any;
4850
import static org.mockito.Mockito.doThrow;
4951
import static org.mockito.Mockito.mock;
5052
import static org.mockito.Mockito.verify;
53+
import static org.mockito.Mockito.when;
5154
import static org.neo4j.driver.v1.AccessMode.READ;
5255
import static org.neo4j.driver.v1.Config.defaultConfig;
5356

@@ -69,7 +72,7 @@ public static List<URI> uris()
6972
@Test
7073
public void connectionPoolClosedWhenDriverCreationFails() throws Exception
7174
{
72-
ConnectionPool connectionPool = mock( ConnectionPool.class );
75+
ConnectionPool connectionPool = connectionPoolMock();
7376
DriverFactory factory = new ThrowingDriverFactory( connectionPool );
7477

7578
try
@@ -87,7 +90,7 @@ public void connectionPoolClosedWhenDriverCreationFails() throws Exception
8790
@Test
8891
public void connectionPoolCloseExceptionIsSupressedWhenDriverCreationFails() throws Exception
8992
{
90-
ConnectionPool connectionPool = mock( ConnectionPool.class );
93+
ConnectionPool connectionPool = connectionPoolMock();
9194
RuntimeException poolCloseError = new RuntimeException( "Pool close error" );
9295
doThrow( poolCloseError ).when( connectionPool ).close();
9396

@@ -142,6 +145,13 @@ private Driver createDriver( DriverFactory driverFactory, Config config )
142145
return driverFactory.newInstance( uri, auth, routingSettings, RetrySettings.DEFAULT, config );
143146
}
144147

148+
private static ConnectionPool connectionPoolMock()
149+
{
150+
ConnectionPool pool = mock( ConnectionPool.class );
151+
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( mock( PooledConnection.class ) );
152+
return pool;
153+
}
154+
145155
private static class ThrowingDriverFactory extends DriverFactory
146156
{
147157
final ConnectionPool connectionPool;
@@ -196,5 +206,11 @@ protected SessionFactory createSessionFactory( ConnectionProvider connectionProv
196206
capturedSessionFactory = sessionFactory;
197207
return sessionFactory;
198208
}
209+
210+
@Override
211+
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config )
212+
{
213+
return connectionPoolMock();
214+
}
199215
}
200216
}

driver/src/test/java/org/neo4j/driver/v1/GraphDatabaseTest.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,18 @@
3939
public class GraphDatabaseTest
4040
{
4141
@Test
42-
public void boltSchemeShouldInstantiateDirectDriver()
42+
public void boltSchemeShouldInstantiateDirectDriver() throws Exception
4343
{
4444
// Given
45-
URI uri = URI.create( "bolt://localhost:7687" );
45+
StubServer server = StubServer.start( "dummy_connection.script", 9001 );
46+
URI uri = URI.create( "bolt://localhost:9001" );
4647

4748
// When
48-
Driver driver = GraphDatabase.driver( uri );
49+
Driver driver = GraphDatabase.driver( uri, INSECURE_CONFIG );
4950

5051
// Then
5152
assertThat( driver, is( directDriver() ) );
53+
server.exit();
5254
}
5355

5456
@Test

driver/src/test/java/org/neo4j/driver/v1/integration/ConnectionHandlingIT.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.neo4j.driver.internal.spi.PooledConnection;
4141
import org.neo4j.driver.internal.util.Clock;
4242
import org.neo4j.driver.v1.AuthToken;
43-
import org.neo4j.driver.v1.AuthTokens;
4443
import org.neo4j.driver.v1.Config;
4544
import org.neo4j.driver.v1.Driver;
4645
import org.neo4j.driver.v1.Logging;
@@ -85,6 +84,7 @@ public void createDriver()
8584
RetrySettings retrySettings = RetrySettings.DEFAULT;
8685
driver = driverFactory.newInstance( neo4j.uri(), auth, routingSettings, retrySettings, defaultConfig() );
8786
connectionPool = driverFactory.connectionPool;
87+
connectionPool.startMemorizing(); // start memorizing connections after driver creation
8888
}
8989

9090
@After
@@ -370,23 +370,34 @@ protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan
370370
private static class MemorizingConnectionPool extends SocketConnectionPool
371371
{
372372
PooledConnection lastAcquiredConnectionSpy;
373+
boolean memorize;
373374

374375
MemorizingConnectionPool( PoolSettings poolSettings, Connector connector, Clock clock, Logging logging )
375376
{
376377
super( poolSettings, connector, clock, logging );
377378
}
378379

380+
void startMemorizing()
381+
{
382+
memorize = true;
383+
}
384+
379385
@Override
380386
public PooledConnection acquire( BoltServerAddress address )
381387
{
382388
PooledConnection connection = super.acquire( address );
383-
// this connection pool returns spies so spies will be returned to the pool
384-
// prevent spying on spies...
385-
if ( !Mockito.mockingDetails( connection ).isSpy() )
389+
390+
if ( memorize )
386391
{
387-
connection = spy( connection );
392+
// this connection pool returns spies so spies will be returned to the pool
393+
// prevent spying on spies...
394+
if ( !Mockito.mockingDetails( connection ).isSpy() )
395+
{
396+
connection = spy( connection );
397+
}
398+
lastAcquiredConnectionSpy = connection;
388399
}
389-
lastAcquiredConnectionSpy = connection;
400+
390401
return connection;
391402
}
392403
}

0 commit comments

Comments
 (0)