Skip to content

Commit d25010d

Browse files
committed
Refactor the way BaseCluster.selectServer deals with the race condition
The new approach allows us to later refactor all other logic into one or more `ServerSelector`s. See the comment left in the code for more details on the new approach. JAVA-4254
1 parent 4d883c1 commit d25010d

12 files changed

+82
-46
lines changed

driver-core/src/main/com/mongodb/internal/connection/AbstractMultiServerCluster.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131

3232
import java.util.ArrayList;
3333
import java.util.Collection;
34+
import java.util.HashMap;
3435
import java.util.HashSet;
3536
import java.util.Iterator;
3637
import java.util.List;
38+
import java.util.Map;
3739
import java.util.Set;
3840
import java.util.concurrent.ConcurrentHashMap;
3941
import java.util.concurrent.ConcurrentMap;
@@ -122,14 +124,13 @@ public void close() {
122124
}
123125

124126
@Override
125-
public ClusterableServer getServer(final ServerAddress serverAddress) {
127+
public ServersSnapshot getServersSnapshot() {
126128
isTrue("is open", !isClosed());
127-
128-
ServerTuple serverTuple = addressToServerTupleMap.get(serverAddress);
129-
if (serverTuple == null) {
130-
return null;
131-
}
132-
return serverTuple.server;
129+
Map<ServerAddress, ServerTuple> nonAtomicSnapshot = new HashMap<>(addressToServerTupleMap);
130+
return serverAddress -> {
131+
ServerTuple serverTuple = nonAtomicSnapshot.get(serverAddress);
132+
return serverTuple == null ? null : serverTuple.server;
133+
};
133134
}
134135

135136
void onChange(final Collection<ServerAddress> newHosts) {

driver-core/src/main/com/mongodb/internal/connection/BaseCluster.java

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import java.util.concurrent.locks.ReentrantLock;
5959
import java.util.function.Function;
6060

61+
import static com.mongodb.assertions.Assertions.assertNotNull;
6162
import static com.mongodb.assertions.Assertions.isTrue;
6263
import static com.mongodb.assertions.Assertions.notNull;
6364
import static com.mongodb.connection.ServerDescription.MAX_DRIVER_WIRE_VERSION;
@@ -314,16 +315,35 @@ private boolean handleServerSelectionRequest(final ServerSelectionRequest reques
314315
@Nullable
315316
private ServerTuple selectServer(final ServerSelector serverSelector,
316317
final ClusterDescription clusterDescription) {
317-
return selectServer(serverSelector, clusterDescription, this::getServer);
318+
return selectServer(serverSelector, clusterDescription, getServersSnapshot());
318319
}
319320

320321
@Nullable
321322
@VisibleForTesting(otherwise = PRIVATE)
322323
static ServerTuple selectServer(final ServerSelector serverSelector, final ClusterDescription clusterDescription,
323-
final Function<ServerAddress, Server> serverCatalog) {
324-
return atMostNRandom(new ArrayList<>(serverSelector.select(clusterDescription)), 2, serverDescription -> {
325-
Server server = serverCatalog.apply(serverDescription.getAddress());
326-
return server == null ? null : new ServerTuple(server, serverDescription);
324+
final ServersSnapshot serversSnapshot) {
325+
// The set of `Server`s maintained by the `Cluster` is updated concurrently with `clusterDescription` being read.
326+
// Additionally, that set of servers continues to be concurrently updated while `serverSelector` selects.
327+
// This race condition means that we are not guaranteed to observe all the servers from `clusterDescription`
328+
// among the `Server`s maintained by the `Cluster`.
329+
// To deal with this race condition, we take `serversSnapshot` of that set of `Server`s
330+
// (the snapshot itself does not have to be atomic) non-atomically with reading `clusterDescription`
331+
// (this means, `serversSnapshot` and `clusterDescription` are not guaranteed to be consistent with each other),
332+
// and do pre-filtering to make sure that the only `ServerDescription`s we may select,
333+
// are of those `Server`s that are known to both `clusterDescription` and `serversSnapshot`.
334+
// This way we are guaranteed to successfully get `Server`s from `serversSnapshot` based on the selected `ServerDescription`s.
335+
//
336+
// The pre-filtering we do to deal with the race condition described above is achieved by this `ServerSelector`.
337+
ServerSelector raceConditionPreFiltering = clusterDescriptionPotentiallyInconsistentWithServerSnapshot ->
338+
clusterDescriptionPotentiallyInconsistentWithServerSnapshot.getServerDescriptions()
339+
.stream()
340+
.filter(serverDescription -> serversSnapshot.containsServer(serverDescription.getAddress()))
341+
.collect(toList());
342+
List<ServerDescription> intermediateResult = new CompositeServerSelector(asList(raceConditionPreFiltering, serverSelector))
343+
.select(clusterDescription);
344+
return atMostNRandom(new ArrayList<>(intermediateResult), 2, serverDescription -> {
345+
Server server = assertNotNull(serversSnapshot.getServer(serverDescription.getAddress()));
346+
return new ServerTuple(server, serverDescription);
327347
}).stream()
328348
.min(comparingInt(serverTuple -> serverTuple.getServer().operationCount()))
329349
.orElse(null);
@@ -345,18 +365,16 @@ private static List<ServerTuple> atMostNRandom(final ArrayList<ServerDescription
345365
List<ServerTuple> result = new ArrayList<>(n);
346366
for (int i = list.size() - 1; i >= 0 && result.size() < n; i--) {
347367
Collections.swap(list, i, random.nextInt(i + 1));
348-
ServerTuple serverTuple = transformer.apply(list.get(i));
349-
if (serverTuple != null) {
350-
result.add(serverTuple);
351-
}
368+
ServerTuple serverTuple = assertNotNull(transformer.apply(list.get(i)));
369+
result.add(serverTuple);
352370
}
353371
return result;
354372
}
355373

356374
private ServerSelector getCompleteServerSelector(final ServerSelector serverSelector, final ServerDeprioritization serverDeprioritization) {
357375
List<ServerSelector> selectors = Stream.of(
358376
serverSelector,
359-
settings.getServerSelector(),
377+
settings.getServerSelector(), // may be null
360378
new LatencyMinimizingServerSelector(settings.getLocalThreshold(MILLISECONDS), MILLISECONDS)
361379
).filter(Objects::nonNull).collect(toList());
362380
return serverDeprioritization.apply(new CompositeServerSelector(selectors));

driver-core/src/main/com/mongodb/internal/connection/Cluster.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919

2020
import com.mongodb.ServerAddress;
21+
import com.mongodb.annotations.ThreadSafe;
2122
import com.mongodb.connection.ClusterId;
2223
import com.mongodb.event.ServerDescriptionChangedEvent;
23-
import com.mongodb.internal.VisibleForTesting;
2424
import com.mongodb.internal.async.SingleResultCallback;
2525
import com.mongodb.connection.ClusterDescription;
2626
import com.mongodb.connection.ClusterSettings;
@@ -29,8 +29,6 @@
2929

3030
import java.io.Closeable;
3131

32-
import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
33-
3432
/**
3533
* Represents a cluster of MongoDB servers. Implementations can define the behaviour depending upon the type of cluster.
3634
*
@@ -43,9 +41,7 @@ public interface Cluster extends Closeable {
4341

4442
ClusterId getClusterId();
4543

46-
@Nullable
47-
@VisibleForTesting(otherwise = PRIVATE)
48-
ClusterableServer getServer(ServerAddress serverAddress);
44+
ServersSnapshot getServersSnapshot();
4945

5046
/**
5147
* Get the current description of this cluster.
@@ -89,4 +85,17 @@ void selectServerAsync(ServerSelector serverSelector, OperationContext operation
8985
* Server Discovery And Monitoring</a> specification.
9086
*/
9187
void onChange(ServerDescriptionChangedEvent event);
88+
89+
/**
90+
* A non-atomic snapshot of the servers in a {@link Cluster}.
91+
*/
92+
@ThreadSafe
93+
interface ServersSnapshot {
94+
@Nullable
95+
Server getServer(ServerAddress serverAddress);
96+
97+
default boolean containsServer(final ServerAddress serverAddress) {
98+
return getServer(serverAddress) != null;
99+
}
100+
}
92101
}

driver-core/src/main/com/mongodb/internal/connection/LoadBalancedCluster.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,11 @@ public ClusterId getClusterId() {
181181
}
182182

183183
@Override
184-
public ClusterableServer getServer(final ServerAddress serverAddress) {
184+
public ServersSnapshot getServersSnapshot() {
185185
isTrue("open", !isClosed());
186186
waitForSrv();
187-
return assertNotNull(server);
187+
ClusterableServer server = assertNotNull(this.server);
188+
return serverAddress -> server;
188189
}
189190

190191
@Override

driver-core/src/main/com/mongodb/internal/connection/SingleServerCluster.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package com.mongodb.internal.connection;
1818

1919
import com.mongodb.MongoConfigurationException;
20-
import com.mongodb.ServerAddress;
2120
import com.mongodb.connection.ClusterConnectionMode;
2221
import com.mongodb.connection.ClusterDescription;
2322
import com.mongodb.connection.ClusterId;
@@ -69,9 +68,10 @@ protected void connect() {
6968
}
7069

7170
@Override
72-
public ClusterableServer getServer(final ServerAddress serverAddress) {
71+
public ServersSnapshot getServersSnapshot() {
7372
isTrue("open", !isClosed());
74-
return assertNotNull(server.get());
73+
ClusterableServer server = assertNotNull(this.server.get());
74+
return serverAddress -> server;
7575
}
7676

7777
@Override

driver-core/src/test/unit/com/mongodb/internal/connection/AbstractServerDiscoveryAndMonitoringTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,13 @@ protected void applyResponse(final BsonArray response) {
8181
protected void applyApplicationError(final BsonDocument applicationError) {
8282
ServerAddress serverAddress = new ServerAddress(applicationError.getString("address").getValue());
8383
int errorGeneration = applicationError.getNumber("generation",
84-
new BsonInt32(((DefaultServer) getCluster().getServer(serverAddress)).getConnectionPool().getGeneration())).intValue();
84+
new BsonInt32(((DefaultServer) getCluster().getServersSnapshot().getServer(serverAddress))
85+
.getConnectionPool().getGeneration())).intValue();
8586
int maxWireVersion = applicationError.getNumber("maxWireVersion").intValue();
8687
String when = applicationError.getString("when").getValue();
8788
String type = applicationError.getString("type").getValue();
8889

89-
DefaultServer server = (DefaultServer) cluster.getServer(serverAddress);
90+
DefaultServer server = (DefaultServer) cluster.getServersSnapshot().getServer(serverAddress);
9091
RuntimeException exception;
9192

9293
switch (type) {

driver-core/src/test/unit/com/mongodb/internal/connection/BaseClusterSpecification.groovy

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ class BaseClusterSpecification extends Specification {
6767
}
6868

6969
@Override
70-
ClusterableServer getServer(final ServerAddress serverAddress) {
71-
throw new UnsupportedOperationException()
70+
Cluster.ServersSnapshot getServersSnapshot() {
71+
Cluster.ServersSnapshot result = serverAddress -> {
72+
throw new UnsupportedOperationException()
73+
}
74+
return result
7275
}
7376

7477
@Override

driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerSpecification.groovy

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,11 @@ class DefaultServerSpecification extends Specification {
394394
}
395395

396396
@Override
397-
ClusterableServer getServer(final ServerAddress serverAddress) {
398-
throw new UnsupportedOperationException()
397+
Cluster.ServersSnapshot getServersSnapshot() {
398+
Cluster.ServersSnapshot result = serverAddress -> {
399+
throw new UnsupportedOperationException()
400+
}
401+
return result
399402
}
400403

401404
@Override

driver-core/src/test/unit/com/mongodb/internal/connection/MultiServerClusterSpecification.groovy

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ class MultiServerClusterSpecification extends Specification {
8787
cluster.getCurrentDescription().connectionMode == MULTIPLE
8888
}
8989

90-
def 'should not get server when closed'() {
90+
def 'should not get servers snapshot when closed'() {
9191
given:
9292
def cluster = new MultiServerCluster(CLUSTER_ID, ClusterSettings.builder().hosts(Arrays.asList(firstServer)).mode(MULTIPLE).build(),
9393
factory)
9494
cluster.close()
9595

9696
when:
97-
cluster.getServer(firstServer)
97+
cluster.getServersSnapshot()
9898

9999
then:
100100
thrown(IllegalStateException)

driver-core/src/test/unit/com/mongodb/internal/connection/ServerDiscoveryAndMonitoringTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private void assertServer(final String serverName, final BsonDocument expectedSe
120120

121121
if (expectedServerDescriptionDocument.isDocument("pool")) {
122122
int expectedGeneration = expectedServerDescriptionDocument.getDocument("pool").getNumber("generation").intValue();
123-
DefaultServer server = (DefaultServer) getCluster().getServer(new ServerAddress(serverName));
123+
DefaultServer server = (DefaultServer) getCluster().getServersSnapshot().getServer(new ServerAddress(serverName));
124124
assertEquals(expectedGeneration, server.getConnectionPool().getGeneration());
125125
}
126126
}

driver-core/src/test/unit/com/mongodb/internal/connection/ServerSelectionWithinLatencyWindowTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
@RunWith(Parameterized.class)
5757
public class ServerSelectionWithinLatencyWindowTest {
5858
private final ClusterDescription clusterDescription;
59-
private final Map<ServerAddress, Server> serverCatalog;
59+
private final Cluster.ServersSnapshot serversSnapshot;
6060
private final int iterations;
6161
private final Outcome outcome;
6262

@@ -65,7 +65,7 @@ public ServerSelectionWithinLatencyWindowTest(
6565
@SuppressWarnings("unused") final String description,
6666
final BsonDocument definition) {
6767
clusterDescription = buildClusterDescription(definition.getDocument("topology_description"), null);
68-
serverCatalog = serverCatalog(definition.getArray("mocked_topology_state"));
68+
serversSnapshot = serverCatalog(definition.getArray("mocked_topology_state"));
6969
iterations = definition.getInt32("iterations").getValue();
7070
outcome = Outcome.parse(definition.getDocument("outcome"));
7171
}
@@ -74,8 +74,7 @@ public ServerSelectionWithinLatencyWindowTest(
7474
public void shouldPassAllOutcomes() {
7575
ServerSelector selector = new ReadPreferenceServerSelector(ReadPreference.nearest());
7676
Map<ServerAddress, List<ServerTuple>> selectionResultsGroupedByServerAddress = IntStream.range(0, iterations)
77-
.mapToObj(i -> BaseCluster.selectServer(selector, clusterDescription,
78-
address -> Assertions.assertNotNull(serverCatalog.get(address))))
77+
.mapToObj(i -> BaseCluster.selectServer(selector, clusterDescription, serversSnapshot))
7978
.collect(groupingBy(serverTuple -> serverTuple.getServerDescription().getAddress()));
8079
Map<ServerAddress, BigDecimal> selectionFrequencies = selectionResultsGroupedByServerAddress.entrySet()
8180
.stream()
@@ -97,8 +96,8 @@ public static Collection<Object[]> data() {
9796
.collect(toList());
9897
}
9998

100-
private static Map<ServerAddress, Server> serverCatalog(final BsonArray mockedTopologyState) {
101-
return mockedTopologyState.stream()
99+
private static Cluster.ServersSnapshot serverCatalog(final BsonArray mockedTopologyState) {
100+
Map<ServerAddress, Server> serverMap = mockedTopologyState.stream()
102101
.map(BsonValue::asDocument)
103102
.collect(toMap(
104103
el -> new ServerAddress(el.getString("address").getValue()),
@@ -108,6 +107,7 @@ private static Map<ServerAddress, Server> serverCatalog(final BsonArray mockedTo
108107
when(server.operationCount()).thenReturn(operationCount);
109108
return server;
110109
}));
110+
return serverAddress -> Assertions.assertNotNull(serverMap.get(serverAddress));
111111
}
112112

113113
private static final class Outcome {

driver-core/src/test/unit/com/mongodb/internal/connection/SingleServerClusterSpecification.groovy

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,21 @@ class SingleServerClusterSpecification extends Specification {
7676
sendNotification(firstServer, STANDALONE)
7777

7878
then:
79-
cluster.getServer(firstServer) == factory.getServer(firstServer)
79+
cluster.getServersSnapshot().getServer(firstServer) == factory.getServer(firstServer)
8080

8181
cleanup:
8282
cluster?.close()
8383
}
8484

8585

86-
def 'should not get server when closed'() {
86+
def 'should not get servers snapshot when closed'() {
8787
given:
8888
def cluster = new SingleServerCluster(CLUSTER_ID,
8989
ClusterSettings.builder().mode(SINGLE).hosts(Arrays.asList(firstServer)).build(), factory)
9090
cluster.close()
9191

9292
when:
93-
cluster.getServer(firstServer)
93+
cluster.getServersSnapshot()
9494

9595
then:
9696
thrown(IllegalStateException)

0 commit comments

Comments
 (0)