Skip to content

Simplify InboundPipeline and make it more obviously correct #91350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public InboundAggregator(
Function<String, RequestHandlerRegistry<TransportRequest>> registryFunction,
boolean ignoreDeserializationErrors
) {
this(circuitBreaker, (Predicate<String>) actionName -> {
this(circuitBreaker, actionName -> {
final RequestHandlerRegistry<TransportRequest> reg = registryFunction.apply(actionName);
if (reg == null) {
assert ignoreDeserializationErrors : actionName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -18,7 +19,6 @@
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.util.function.Consumer;

public class InboundDecoder implements Releasable {

Expand All @@ -38,23 +38,31 @@ public InboundDecoder(Version version, Recycler<BytesRef> recycler) {
this.recycler = recycler;
}

public int decode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
public int decode(
ReleasableBytesReference reference,
TcpChannel channel,
CheckedBiConsumer<TcpChannel, Object, IOException> fragmentConsumer
) throws IOException {
ensureOpen();
try {
return internalDecode(reference, fragmentConsumer);
return internalDecode(reference, channel, fragmentConsumer);
} catch (Exception e) {
cleanDecodeState();
throw e;
}
}

public int internalDecode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
private int internalDecode(
ReleasableBytesReference reference,
TcpChannel channel,
CheckedBiConsumer<TcpChannel, Object, IOException> fragmentConsumer
) throws IOException {
if (isOnHeader()) {
int messageLength = TcpTransport.readMessageLength(reference);
if (messageLength == -1) {
return 0;
} else if (messageLength == 0) {
fragmentConsumer.accept(PING);
fragmentConsumer.accept(channel, PING);
return 6;
} else {
int headerBytesToRead = headerBytesToRead(reference);
Expand All @@ -68,10 +76,10 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> f
if (header.isCompressed()) {
isCompressed = true;
}
fragmentConsumer.accept(header);
fragmentConsumer.accept(channel, header);

if (isDone()) {
finishMessage(fragmentConsumer);
finishMessage(channel, fragmentConsumer);
}
return headerBytesToRead;
}
Expand All @@ -84,7 +92,7 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> f
return 0;
} else {
this.decompressor = decompressor;
fragmentConsumer.accept(this.decompressor.getScheme());
fragmentConsumer.accept(channel, this.decompressor.getScheme());
}
}
int remainingToConsume = totalNetworkSize - bytesConsumed;
Expand All @@ -102,15 +110,15 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> f
bytesConsumed += bytesConsumedThisDecode;
ReleasableBytesReference decompressed;
while ((decompressed = decompressor.pollDecompressedPage(isDone())) != null) {
fragmentConsumer.accept(decompressed);
fragmentConsumer.accept(channel, decompressed);
}
} else {
bytesConsumedThisDecode += maxBytesToConsume;
bytesConsumed += maxBytesToConsume;
fragmentConsumer.accept(retainedContent);
fragmentConsumer.accept(channel, retainedContent);
}
if (isDone()) {
finishMessage(fragmentConsumer);
finishMessage(channel, fragmentConsumer);
}

return bytesConsumedThisDecode;
Expand All @@ -123,9 +131,9 @@ public void close() {
cleanDecodeState();
}

private void finishMessage(Consumer<Object> fragmentConsumer) {
private void finishMessage(TcpChannel channel, CheckedBiConsumer<TcpChannel, Object, IOException> fragmentConsumer) throws IOException {
cleanDecodeState();
fragmentConsumer.accept(END_CONTENT);
fragmentConsumer.accept(channel, END_CONTENT);
}

private void cleanDecodeState() {
Expand Down
108 changes: 40 additions & 68 deletions server/src/main/java/org/elasticsearch/transport/InboundPipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
Expand All @@ -19,26 +20,25 @@

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongSupplier;
import java.util.function.Supplier;

public class InboundPipeline implements Releasable {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);
private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true);

private final LongSupplier relativeTimeInMillis;
private final StatsTracker statsTracker;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final BiConsumer<TcpChannel, InboundMessage> messageHandler;
private Exception uncaughtException;
private final ArrayDeque<ReleasableBytesReference> pending = new ArrayDeque<>(2);
private boolean isClosed = false;

private final CheckedBiConsumer<TcpChannel, Object, IOException> fragmentConsumer;

public InboundPipeline(
Version version,
StatsTracker statsTracker,
Expand Down Expand Up @@ -69,7 +69,37 @@ public InboundPipeline(
this.statsTracker = statsTracker;
this.decoder = decoder;
this.aggregator = aggregator;
this.messageHandler = messageHandler;
this.fragmentConsumer = (c, fragment) -> {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment instanceof Compression.Scheme) {
assert aggregator.isAggregating();
aggregator.updateCompressionScheme((Compression.Scheme) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(c, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(c, aggregated);
} finally {
aggregated.decRef();
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
// fragment will be released by the aggregator
final var bytes = (ReleasableBytesReference) fragment;
try {
aggregator.aggregate(bytes);
} finally {
bytes.decRef();
}
}
};
}

@Override
Expand All @@ -94,75 +124,17 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference
channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
statsTracker.markBytesRead(reference.length());
pending.add(reference.retain());

final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
while (pending.isEmpty() == false && isClosed == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, channel, fragmentConsumer);
if (bytesDecoded == 0) {
break;
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment instanceof Compression.Scheme) {
assert aggregator.isAggregating();
aggregator.updateCompressionScheme((Compression.Scheme) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
} finally {
aggregated.decRef();
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
releasePendingBytes(bytesDecoded);
}
}
}

private static boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public void testDecode() throws IOException {
InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
int bytesConsumed = decoder.decode(releasable1, null, (c, f) -> fragments.add(f));
assertEquals(totalHeaderSize, bytesConsumed);
assertTrue(releasable1.hasReferences());

Expand All @@ -100,7 +100,7 @@ public void testDecode() throws IOException {

final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2);
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
int bytesConsumed2 = decoder.decode(releasable2, null, (c, f) -> fragments.add(f));
assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2);

final Object content = fragments.get(0);
Expand Down Expand Up @@ -142,7 +142,7 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {
InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
int bytesConsumed = decoder.decode(releasable1, null, (c, f) -> fragments.add(f));
assertEquals(partialHeaderSize, bytesConsumed);
assertTrue(releasable1.hasReferences());

Expand All @@ -161,7 +161,7 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {

final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2);
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
int bytesConsumed2 = decoder.decode(releasable2, null, (c, f) -> fragments.add(f));
if (compressionScheme == null) {
assertEquals(2, fragments.size());
} else {
Expand Down Expand Up @@ -199,7 +199,7 @@ public void testDecodeHandshakeCompatibility() throws IOException {
InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
int bytesConsumed = decoder.decode(releasable1, null, (c, f) -> fragments.add(f));
assertEquals(totalHeaderSize, bytesConsumed);
assertTrue(releasable1.hasReferences());

Expand Down Expand Up @@ -248,7 +248,7 @@ public void testCompressedDecode() throws IOException {
InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
int bytesConsumed = decoder.decode(releasable1, null, (c, f) -> fragments.add(f));
assertEquals(totalHeaderSize, bytesConsumed);
assertTrue(releasable1.hasReferences());

Expand All @@ -270,7 +270,7 @@ public void testCompressedDecode() throws IOException {

final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2);
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
int bytesConsumed2 = decoder.decode(releasable2, null, (c, f) -> fragments.add(f));
assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2);

final Object compressionScheme = fragments.get(0);
Expand Down Expand Up @@ -312,7 +312,7 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException {
InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
int bytesConsumed = decoder.decode(releasable1, null, (c, f) -> fragments.add(f));
assertEquals(totalHeaderSize, bytesConsumed);
assertTrue(releasable1.hasReferences());

Expand Down Expand Up @@ -350,7 +350,7 @@ public void testVersionIncompatibilityDecodeException() throws IOException {
final ArrayList<Object> fragments = new ArrayList<>();
try (ReleasableBytesReference r = ReleasableBytesReference.wrap(bytes)) {
releasable1 = r;
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add));
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, null, (c, f) -> fragments.add(f)));
}
}
// No bytes are retained
Expand Down