Skip to content

Delegate Ref Counting to ByteBuf in Netty Transport #81096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/nio/src/main/java/org/elasticsearch/nio/Page.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public Page(ByteBuffer byteBuffer, Releasable closeable) {
}

private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
assert refCountedCloseable.refCount() > 0;
assert refCountedCloseable.hasReferences();
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
Expand All @@ -51,7 +51,7 @@ public Page duplicate() {
* @return the byte buffer
*/
public ByteBuffer byteBuffer() {
assert refCountedCloseable.refCount() > 0;
assert refCountedCloseable.hasReferences();
return byteBuffer;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundPipeline;
Expand Down Expand Up @@ -68,7 +69,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
final ByteBuf buffer = (ByteBuf) msg;
Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get();
final BytesReference wrapped = Netty4Utils.toBytesReference(buffer);
try (ReleasableBytesReference reference = new ReleasableBytesReference(wrapped, buffer::release)) {
try (ReleasableBytesReference reference = new ReleasableBytesReference(wrapped, new ByteBufRefCounted(buffer))) {
pipeline.handleBytes(channel, reference);
}
}
Expand Down Expand Up @@ -211,4 +212,43 @@ void failAsClosedChannel() {
buf.release();
}
}

private static final class ByteBufRefCounted implements RefCounted {

private final ByteBuf buffer;

ByteBufRefCounted(ByteBuf buffer) {
this.buffer = buffer;
}

@Override
public void incRef() {
buffer.retain();
}

@Override
public boolean tryIncRef() {
if (hasReferences() == false) {
return false;
}
try {
buffer.retain();
} catch (RuntimeException e) {
assert hasReferences() == false;
return false;
}
return true;
}

@Override
public boolean decRef() {
return buffer.release();
}

@Override
public boolean hasReferences() {
return buffer.refCnt() > 0;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public final class ReleasableBytesReference implements RefCounted, Releasable, B
private static final ReleasableBytesReference EMPTY = new ReleasableBytesReference(BytesArray.EMPTY, NO_OP);

private final BytesReference delegate;
private final AbstractRefCounted refCounted;
private final RefCounted refCounted;

public static ReleasableBytesReference empty() {
EMPTY.incRef();
Expand All @@ -42,21 +42,17 @@ public ReleasableBytesReference(BytesReference delegate, Releasable releasable)
this(delegate, new RefCountedReleasable(releasable));
}

public ReleasableBytesReference(BytesReference delegate, AbstractRefCounted refCounted) {
public ReleasableBytesReference(BytesReference delegate, RefCounted refCounted) {
this.delegate = delegate;
this.refCounted = refCounted;
assert refCounted.refCount() > 0;
assert refCounted.hasReferences();
}

public static ReleasableBytesReference wrap(BytesReference reference) {
assert reference instanceof ReleasableBytesReference == false : "use #retain() instead of #wrap() on a " + reference.getClass();
return reference.length() == 0 ? empty() : new ReleasableBytesReference(reference, NO_OP);
}

public int refCount() {
return refCounted.refCount();
}

@Override
public void incRef() {
refCounted.incRef();
Expand Down Expand Up @@ -98,19 +94,19 @@ public void close() {

@Override
public byte get(int index) {
assert refCount() > 0;
assert hasReferences();
return delegate.get(index);
}

@Override
public int getInt(int index) {
assert refCount() > 0;
assert hasReferences();
return delegate.getInt(index);
}

@Override
public int indexOf(byte marker, int from) {
assert refCount() > 0;
assert hasReferences();
return delegate.indexOf(marker, from);
}

Expand All @@ -121,7 +117,7 @@ public int length() {

@Override
public BytesReference slice(int from, int length) {
assert refCount() > 0;
assert hasReferences();
return delegate.slice(from, length);
}

Expand All @@ -132,7 +128,7 @@ public long ramBytesUsed() {

@Override
public StreamInput streamInput() throws IOException {
assert refCount() > 0;
assert hasReferences();
return new BytesReferenceStreamInput(this) {
@Override
public ReleasableBytesReference readReleasableBytesReference() throws IOException {
Expand All @@ -148,37 +144,37 @@ public ReleasableBytesReference readReleasableBytesReference() throws IOExceptio

@Override
public void writeTo(OutputStream os) throws IOException {
assert refCount() > 0;
assert hasReferences();
delegate.writeTo(os);
}

@Override
public String utf8ToString() {
assert refCount() > 0;
assert hasReferences();
return delegate.utf8ToString();
}

@Override
public BytesRef toBytesRef() {
assert refCount() > 0;
assert hasReferences();
return delegate.toBytesRef();
}

@Override
public BytesRefIterator iterator() {
assert refCount() > 0;
assert hasReferences();
return delegate.iterator();
}

@Override
public int compareTo(BytesReference o) {
assert refCount() > 0;
assert hasReferences();
return delegate.compareTo(o);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
assert refCount() > 0;
assert hasReferences();
return delegate.toXContent(builder, params);
}

Expand All @@ -189,31 +185,31 @@ public boolean isFragment() {

@Override
public boolean equals(Object obj) {
assert refCount() > 0;
assert hasReferences();
return delegate.equals(obj);
}

@Override
public int hashCode() {
assert refCount() > 0;
assert hasReferences();
return delegate.hashCode();
}

@Override
public boolean hasArray() {
assert refCount() > 0;
assert hasReferences();
return delegate.hasArray();
}

@Override
public byte[] array() {
assert refCount() > 0;
assert hasReferences();
return delegate.array();
}

@Override
public int arrayOffset() {
assert refCount() > 0;
assert hasReferences();
return delegate.arrayOffset();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ public void testInboundAggregation() throws IOException {
assertThat(aggregated.getHeader().getRequestId(), equalTo(requestId));
assertThat(aggregated.getHeader().getVersion(), equalTo(Version.CURRENT));
for (ReleasableBytesReference reference : references) {
assertEquals(1, reference.refCount());
assertTrue(reference.hasReferences());
}
aggregated.close();
for (ReleasableBytesReference reference : references) {
assertEquals(0, reference.refCount());
assertFalse(reference.hasReferences());
}
}

Expand All @@ -111,7 +111,7 @@ public void testInboundUnknownAction() throws IOException {
final ReleasableBytesReference content = ReleasableBytesReference.wrap(bytes);
aggregator.aggregate(content);
content.close();
assertEquals(0, content.refCount());
assertFalse(content.hasReferences());

// Signal EOS
InboundMessage aggregated = aggregator.finishAggregation();
Expand Down Expand Up @@ -139,7 +139,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated1 = aggregator.finishAggregation();

assertEquals(0, content1.refCount());
assertFalse(content1.hasReferences());
assertThat(aggregated1, notNullValue());
assertTrue(aggregated1.isShortCircuit());
assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class));
Expand All @@ -158,7 +158,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated2 = aggregator.finishAggregation();

assertEquals(1, content2.refCount());
assertTrue(content2.hasReferences());
assertThat(aggregated2, notNullValue());
assertFalse(aggregated2.isShortCircuit());

Expand All @@ -177,7 +177,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated3 = aggregator.finishAggregation();

assertEquals(1, content3.refCount());
assertTrue(content3.hasReferences());
assertThat(aggregated3, notNullValue());
assertFalse(aggregated3.isShortCircuit());
}
Expand Down Expand Up @@ -211,7 +211,7 @@ public void testCloseWillCloseContent() {
aggregator.close();

for (ReleasableBytesReference reference : references) {
assertEquals(0, reference.refCount());
assertFalse(reference.hasReferences());
}
}

Expand Down Expand Up @@ -244,10 +244,10 @@ public void testFinishAggregationWillFinishHeader() throws IOException {
assertFalse(header.needsToReadVariableHeader());
assertEquals(actionName, header.getActionName());
if (unknownAction) {
assertEquals(0, content.refCount());
assertFalse(content.hasReferences());
assertTrue(aggregated.isShortCircuit());
} else {
assertEquals(1, content.refCount());
assertTrue(content.hasReferences());
assertFalse(aggregated.isShortCircuit());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void testDecode() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -108,7 +108,10 @@ public void testDecode() throws IOException {

assertEquals(messageBytes, content);
// Ref count is incremented since the bytes are forwarded as a fragment
assertEquals(2, releasable2.refCount());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could reasonably keep this coverage by releasing releasable2 and asserting that it still hasReferences().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ brought this back

assertTrue(releasable2.hasReferences());
releasable2.decRef();
assertTrue(releasable2.hasReferences());
assertTrue(releasable2.decRef());
assertEquals(InboundDecoder.END_CONTENT, endMarker);
}

Expand Down Expand Up @@ -141,7 +144,7 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(partialHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -198,7 +201,7 @@ public void testDecodeHandshakeCompatibility() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -247,7 +250,7 @@ public void testCompressedDecode() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -279,7 +282,7 @@ public void testCompressedDecode() throws IOException {
assertThat(content, instanceOf(ReleasableBytesReference.class));
((ReleasableBytesReference) content).close();
// Ref count is not incremented since the bytes are immediately consumed on decompression
assertEquals(1, releasable2.refCount());
assertTrue(releasable2.hasReferences());
assertEquals(InboundDecoder.END_CONTENT, endMarker);
}

Expand Down Expand Up @@ -311,7 +314,7 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -339,16 +342,19 @@ public void testVersionIncompatibilityDecodeException() throws IOException {
Compression.Scheme.DEFLATE
);

final ReleasableBytesReference releasable1;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);

InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add));
// No bytes are retained
assertEquals(1, releasable1.refCount());
try (ReleasableBytesReference r = ReleasableBytesReference.wrap(bytes)) {
releasable1 = r;
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add));
}
}
// No bytes are retained
assertFalse(releasable1.hasReferences());
}

public void testEnsureVersionCompatibility() throws IOException {
Expand Down
Loading