Skip to content

Move outbound message handling to OutboundHandler #40336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public ExceptionThrowingNetty4Transport(
}

@Override
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
super.handleRequest(channel, request, messageLengthBytes);
channelProfileName = TransportSettings.DEFAULT_PROFILE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public Map<String, Supplier<Transport>> getTransports(Settings settings, ThreadP
}

@Override
protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
super.handleRequest(channel, request, messageLengthBytes);
channelProfileName = TransportSettings.DEFAULT_PROFILE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ InboundMessage deserialize(BytesReference reference) throws IOException {
if (TransportStatus.isRequest(status)) {
final Set<String> features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray())));
final String action = streamInput.readString();
message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput);
message = new Request(threadContext, remoteVersion, status, requestId, action, features, streamInput);
} else {
message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput);
message = new Response(threadContext, remoteVersion, status, requestId, streamInput);
}
success = true;
return message;
Expand Down Expand Up @@ -133,13 +133,13 @@ private static void ensureVersionCompatibility(Version version, Version currentV
}
}

public static class RequestMessage extends InboundMessage {
public static class Request extends InboundMessage {

private final String actionName;
private final Set<String> features;

RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
StreamInput streamInput) {
Request(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
StreamInput streamInput) {
super(threadContext, version, status, requestId, streamInput);
this.actionName = actionName;
this.features = features;
Expand All @@ -154,9 +154,9 @@ Set<String> getFeatures() {
}
}

public static class ResponseMessage extends InboundMessage {
public static class Response extends InboundMessage {

ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
Response(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
super(threadContext, version, status, requestId, streamInput);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
Expand All @@ -32,49 +34,100 @@
import org.elasticsearch.common.metrics.MeanMetric;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.transport.NetworkExceptionHelper;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Set;

final class OutboundHandler {

private static final Logger logger = LogManager.getLogger(OutboundHandler.class);

private final MeanMetric transmittedBytesMetric = new MeanMetric();

private final String nodeName;
private final Version version;
private final String[] features;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private final TransportLogger transportLogger;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) {
OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays,
TransportLogger transportLogger) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
this.transportLogger = transportLogger;
}

void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
SendContext sendContext = new SendContext(channel, () -> bytes, listener);
try {
internalSendMessage(channel, sendContext);
internalSend(channel, sendContext);
} catch (IOException e) {
// This should not happen as the bytes are already serialized
throw new AssertionError(e);
}
}

void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSendMessage(channel, sendContext);
/**
* Sends the request to the given channel. This method should be used to send {@link TransportRequest}
* objects back to the caller.
*/
void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, final TransportRequestOptions options, final Version channelVersion,
final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
requestId, isHandshake, compressRequest);
ActionListener<Void> listener = ActionListener.wrap(() ->
messageListener.onRequestSent(node, requestId, action, request, options));
sendMessage(channel, message, listener);
}

/**
* Sends the response to the given channel. This method should be used to send {@link TransportResponse}
* objects back to the caller.
*
* @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
*/
void sendResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel,
final long requestId, final String action, final TransportResponse response,
final boolean compress, final boolean isHandshake) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
requestId, isHandshake, compress);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, message, listener);
}

/**
* sends a message to the given channel, using the given callbacks.
* Sends back an error response to the caller via the given channel
*/
private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException {
void sendErrorResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel, final long requestId,
final String action, final Exception error) throws IOException {
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
false, false);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, message, listener);
}

private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
}

private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
BytesReference reference = sendContext.get();
try {
Expand All @@ -91,6 +144,14 @@ MeanMetric getTransmittedBytes() {
return transmittedBytesMetric;
}

void setMessageListener(TransportMessageListener listener) {
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
throw new IllegalStateException("Cannot set message listener twice");
}
}

private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {

private final OutboundMessage message;
Expand Down
Loading