Skip to content

Mock connections more accurately in DisruptableMockTransport #37296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ public Join handleStartJoin(StartJoinRequest startJoinRequest) {
* @throws CoordinationStateRejectedException if the arguments were incompatible with the current state of this object.
*/
public boolean handleJoin(Join join) {
assert join.getTargetNode().equals(localNode) : "handling join " + join + " for the wrong node " + localNode;
assert join.targetMatches(localNode) : "handling join " + join + " for the wrong node " + localNode;

if (join.getTerm() != getCurrentTerm()) {
logger.debug("handleJoin: ignored join due to term mismatch (expected: [{}], actual: [{}])",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ PublishWithJoinResponse handlePublishRequest(PublishRequest publishRequest) {

private static Optional<Join> joinWithDestination(Optional<Join> lastJoin, DiscoveryNode leader, long term) {
if (lastJoin.isPresent()
&& lastJoin.get().getTargetNode().getId().equals(leader.getId())
&& lastJoin.get().targetMatches(leader)
&& lastJoin.get().getTerm() == term) {
return lastJoin;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ public DiscoveryNode getTargetNode() {
return targetNode;
}

/**
 * Tests whether the given node has the same ID as this object's target node.
 * Only the node IDs are compared; no other attribute of either node is consulted.
 *
 * @param matchingNode the node whose ID is compared with the target node's ID
 * @return {@code true} if both IDs are equal, {@code false} otherwise
 */
public boolean targetMatches(DiscoveryNode matchingNode) {
    final String targetId = targetNode.getId();
    final String candidateId = matchingNode.getId();
    return targetId.equals(candidateId);
}

/**
 * @return the value of the {@code lastAcceptedVersion} field held by this object
 */
public long getLastAcceptedVersion() {
return lastAcceptedVersion;
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
Expand Down Expand Up @@ -114,15 +115,13 @@
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.elasticsearch.env.Environment.PATH_HOME_SETTING;
import static org.elasticsearch.node.Node.NODE_NAME_SETTING;
import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;
import static org.elasticsearch.transport.TransportService.NOOP_TRANSPORT_INTERCEPTOR;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -388,41 +387,26 @@ protected PrioritizedEsThreadPoolExecutor createThreadPoolExecutor() {
return new MockSinglePrioritizingExecutor(node.getName(), deterministicTaskQueue);
}
});
mockTransport = new DisruptableMockTransport(logger) {
mockTransport = new DisruptableMockTransport(node, logger) {
@Override
protected DiscoveryNode getLocalNode() {
return node;
}

@Override
protected ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination) {
protected ConnectionStatus getConnectionStatus(DiscoveryNode destination) {
return ConnectionStatus.CONNECTED;
}

@Override
protected Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode node, String action) {
final Predicate<TestClusterNode> matchesDestination;
if (action.equals(HANDSHAKE_ACTION_NAME)) {
matchesDestination = n -> n.transportService.getLocalNode().getAddress().equals(node.getAddress());
} else {
matchesDestination = n -> n.transportService.getLocalNode().equals(node);
}
return testClusterNodes.nodes.values().stream().filter(matchesDestination).findAny().map(cn -> cn.mockTransport);
/**
 * Looks up the mock transport of a test cluster node by transport address.
 * Scans the mock transports of all nodes in {@code testClusterNodes} and returns one whose
 * local node's address equals {@code address}, or an empty Optional if no such transport exists.
 */
protected Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address) {
return testClusterNodes.nodes.values().stream().map(cn -> cn.mockTransport)
.filter(transport -> transport.getLocalNode().getAddress().equals(address))
.findAny();
}

@Override
protected void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
// handshake needs to run inline as the caller blockingly waits on the result
final Runnable runnable = CoordinatorTests.onNode(destination, doDelivery);
if (action.equals(HANDSHAKE_ACTION_NAME)) {
runnable.run();
} else {
deterministicTaskQueue.scheduleNow(runnable);
}
protected void execute(Runnable runnable) {
deterministicTaskQueue.scheduleNow(CoordinatorTests.onNodeLog(getLocalNode(), runnable));
}
};
transportService = mockTransport.createTransportService(
settings, deterministicTaskQueue.getThreadPool(runnable -> CoordinatorTests.onNode(node, runnable)),
settings, deterministicTaskQueue.getThreadPool(runnable -> CoordinatorTests.onNodeLog(node, runnable)),
NOOP_TRANSPORT_INTERCEPTOR,
a -> node, null, emptySet()
);
Expand Down Expand Up @@ -544,7 +528,16 @@ public void start(ClusterState initialState) {
coordinator.start();
masterService.start();
clusterService.getClusterApplierService().setNodeConnectionsService(
new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService));
new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService) {
/**
 * Connects to every node in {@code discoveryNodes} through {@code transportService}
 * and then delegates to the superclass implementation.
 */
@Override
public void connectToNodes(DiscoveryNodes discoveryNodes) {
// override this method as it does blocking calls
// NOTE(review): presumably connecting up front here means the superclass has no
// blocking connection work left to do — confirm against NodeConnectionsService.
for (final DiscoveryNode node : discoveryNodes) {
transportService.connectToNode(node);
}
super.connectToNodes(discoveryNodes);
}
});
clusterService.getClusterApplierService().start();
indicesService.start();
indicesClusterStateService.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,80 +20,123 @@

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterModule;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.test.transport.MockTransport;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.CloseableConnection;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.ConnectionProfile;
import org.elasticsearch.transport.RequestHandlerRegistry;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import static org.elasticsearch.test.ESTestCase.copyWriteable;
import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME;

public abstract class DisruptableMockTransport extends MockTransport {
private final DiscoveryNode localNode;
private final Logger logger;

public DisruptableMockTransport(Logger logger) {
public DisruptableMockTransport(DiscoveryNode localNode, Logger logger) {
this.localNode = localNode;
this.logger = logger;
}

protected abstract DiscoveryNode getLocalNode();
protected abstract ConnectionStatus getConnectionStatus(DiscoveryNode destination);

protected abstract ConnectionStatus getConnectionStatus(DiscoveryNode sender, DiscoveryNode destination);
protected abstract Optional<DisruptableMockTransport> getDisruptableMockTransport(TransportAddress address);

protected abstract Optional<DisruptableMockTransport> getDisruptedCapturingTransport(DiscoveryNode node, String action);
protected abstract void execute(Runnable runnable);

protected abstract void handle(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery);
protected final void execute(String action, Runnable runnable) {
// handshake needs to run inline as the caller blockingly waits on the result
if (action.equals(HANDSHAKE_ACTION_NAME)) {
runnable.run();
} else {

protected final void sendFromTo(DiscoveryNode sender, DiscoveryNode destination, String action, Runnable doDelivery) {
handle(sender, destination, action, new Runnable() {
@Override
public void run() {
if (getDisruptedCapturingTransport(destination, action).isPresent()) {
doDelivery.run();
} else {
logger.trace("unknown destination in {}", this);
}
}
execute(runnable);
}
}

@Override
public String toString() {
return doDelivery.toString();
}
});
/**
 * @return the local node that was passed to this transport's constructor
 */
public DiscoveryNode getLocalNode() {
return localNode;
}

/**
 * Creates a {@link TransportService} that uses this mock transport as its transport.
 * Every argument is handed unchanged to the {@link TransportService} constructor.
 */
@Override
public TransportService createTransportService(Settings settings, ThreadPool threadPool, TransportInterceptor interceptor,
Function<BoundTransportAddress, DiscoveryNode> localNodeFactory,
@Nullable ClusterSettings clusterSettings, Set<String> taskHeaders) {
return new TransportService(settings, this, threadPool, interceptor, localNodeFactory, clusterSettings, taskHeaders);
}

@Override
protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener<Connection> listener) {
final Optional<DisruptableMockTransport> matchingTransport = getDisruptableMockTransport(node.getAddress());
if (matchingTransport.isPresent()) {
listener.onResponse(new CloseableConnection() {
@Override
public DiscoveryNode getNode() {
return node;
}

assert destination.equals(getLocalNode()) == false : "non-local message from " + getLocalNode() + " to itself";
@Override
public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options)
throws TransportException {
onSendRequest(requestId, action, request, matchingTransport.get());
}
});
return () -> {};
} else {
throw new ConnectTransportException(node, "node " + node + " does not exist");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK for now, but in future (hoho) we will want this to be async and/or to time out on an unknown node.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This depends on the connection manager becoming async. Right now there's a Future.get() waiting for us behind this call.

}
}

sendFromTo(getLocalNode(), destination, action, new Runnable() {
protected void onSendRequest(long requestId, String action, TransportRequest request,
DisruptableMockTransport destinationTransport) {

assert destinationTransport.getLocalNode().equals(getLocalNode()) == false :
"non-local message from " + getLocalNode() + " to itself";

execute(action, new Runnable() {
@Override
public void run() {
switch (getConnectionStatus(getLocalNode(), destination)) {
switch (getConnectionStatus(destinationTransport.getLocalNode())) {
case BLACK_HOLE:
onBlackholedDuringSend(requestId, action, destination);
onBlackholedDuringSend(requestId, action, destinationTransport);
break;

case DISCONNECTED:
onDisconnectedDuringSend(requestId, action, destination);
onDisconnectedDuringSend(requestId, action, destinationTransport);
break;

case CONNECTED:
onConnectedDuringSend(requestId, action, request, destination);
onConnectedDuringSend(requestId, action, request, destinationTransport);
break;
}
}

@Override
public String toString() {
return getRequestDescription(requestId, action, destination);
return getRequestDescription(requestId, action, destinationTransport.getLocalNode());
}
});
}
Expand All @@ -117,20 +160,27 @@ protected String getRequestDescription(long requestId, String action, DiscoveryN
requestId, action, getLocalNode(), destination).getFormattedMessage();
}

protected void onBlackholedDuringSend(long requestId, String action, DiscoveryNode destination) {
logger.trace("dropping {}", getRequestDescription(requestId, action, destination));
/**
 * Called when a request is sent while the connection status towards the destination is BLACK_HOLE.
 * Requests for any action other than the handshake are only logged and then dropped.
 * A handshake request is instead answered by running the disconnect-exception task on the
 * destination transport, because the handshake sender waits for a result in a blocking fashion.
 *
 * @param requestId            the ID of the request being sent
 * @param action               the transport action name of the request
 * @param destinationTransport the mock transport of the node the request is addressed to
 */
protected void onBlackholedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
    final DiscoveryNode destination = destinationTransport.getLocalNode();

    if (action.equals(HANDSHAKE_ACTION_NAME) == false) {
        logger.trace("dropping {}", getRequestDescription(requestId, action, destination));
        return;
    }

    logger.trace("ignoring blackhole and delivering {}", getRequestDescription(requestId, action, destination));
    // handshakes always have a timeout, and are sent in a blocking fashion, so we must respond with an exception.
    final Runnable disconnectResponse = getDisconnectException(requestId, action, destination);
    destinationTransport.execute(action, disconnectResponse);
}

protected void onDisconnectedDuringSend(long requestId, String action, DiscoveryNode destination) {
sendFromTo(destination, getLocalNode(), action, getDisconnectException(requestId, action, destination));
/**
 * Called when a request is sent while the connection status towards the destination is DISCONNECTED.
 * Obtains the task from {@code getDisconnectException} for this request and runs it through the
 * destination transport's {@code execute(action, ...)}.
 *
 * @param requestId            the ID of the request being sent
 * @param action               the transport action name of the request
 * @param destinationTransport the mock transport of the node the request is addressed to
 */
protected void onDisconnectedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) {
    final DiscoveryNode destination = destinationTransport.getLocalNode();
    final Runnable disconnectResponse = getDisconnectException(requestId, action, destination);
    destinationTransport.execute(action, disconnectResponse);
}

protected void onConnectedDuringSend(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
Optional<DisruptableMockTransport> destinationTransport = getDisruptedCapturingTransport(destination, action);
assert destinationTransport.isPresent();

protected void onConnectedDuringSend(long requestId, String action, TransportRequest request,
DisruptableMockTransport destinationTransport) {
final RequestHandlerRegistry<TransportRequest> requestHandler =
destinationTransport.get().getRequestHandler(action);
destinationTransport.getRequestHandler(action);

final DiscoveryNode destination = destinationTransport.getLocalNode();

final String requestDescription = getRequestDescription(requestId, action, destination);

Expand All @@ -147,10 +197,10 @@ public String getChannelType() {

@Override
public void sendResponse(final TransportResponse response) {
sendFromTo(destination, getLocalNode(), action, new Runnable() {
execute(action, new Runnable() {
@Override
public void run() {
if (getConnectionStatus(destination, getLocalNode()) != ConnectionStatus.CONNECTED) {
if (destinationTransport.getConnectionStatus(getLocalNode()) != ConnectionStatus.CONNECTED) {
logger.trace("dropping response to {}: channel is not CONNECTED",
requestDescription);
} else {
Expand All @@ -167,10 +217,10 @@ public String toString() {

@Override
public void sendResponse(Exception exception) {
sendFromTo(destination, getLocalNode(), action, new Runnable() {
execute(action, new Runnable() {
@Override
public void run() {
if (getConnectionStatus(destination, getLocalNode()) != ConnectionStatus.CONNECTED) {
if (destinationTransport.getConnectionStatus(getLocalNode()) != ConnectionStatus.CONNECTED) {
logger.trace("dropping response to {}: channel is not CONNECTED",
requestDescription);
} else {
Expand Down
Loading