Skip to content

Commit

Permalink
Fix yamux handling of writes bigger than the window size (#295)
Browse files Browse the repository at this point in the history
* Fix yamux handling of writes bigger than the window size
* Fix inefficient window size handling
* Release delayed yamux send buffer once sent. 
* Add Unit test
---------
Co-authored-by: Anton Nashatyrev <anton.nashatyrev@gmail.com>
  • Loading branch information
ianopolous authored Aug 15, 2023
1 parent 3c88678 commit bf15e47
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 10 deletions.
19 changes: 13 additions & 6 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,26 @@ open class YamuxHandler(
private val buffered = ArrayDeque<ByteBuf>()

fun add(data: ByteBuf) {
buffered.add(data)
buffered.add(data.retain())
}

fun flush(sendWindow: AtomicInteger, id: MuxId): Int {
var written = 0
while (! buffered.isEmpty()) {
val buf = buffered.first()
if (buf.readableBytes() + written < sendWindow.get()) {
buffered.removeFirst()
val readableBytes = buf.readableBytes()
if (readableBytes + written < sendWindow.get()) {
sendBlocks(ctx, buf, sendWindow, id)
written += buf.readableBytes()
} else
written += readableBytes
buf.release()
buffered.removeFirst()
} else {
// partial write to fit within window
val toRead = sendWindow.get() - written
sendBlocks(ctx, buf.readSlice(toRead), sendWindow, id)
written += toRead
break
}
}
return written
}
Expand Down Expand Up @@ -96,7 +103,7 @@ open class YamuxHandler(
}
val newWindow = recWindow.addAndGet(-size.toInt())
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE / 2
val delta = INITIAL_WINDOW_SIZE - newWindow
recWindow.addAndGet(delta)
ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()))
ctx.flush()
Expand Down
9 changes: 6 additions & 3 deletions libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@ interface PingController {
fun ping(): CompletableFuture<Long>
}

class Ping : PingBinding(PingProtocol())
class Ping(pingSize: Int) : PingBinding(PingProtocol(pingSize)) {
constructor() : this(32)
}

open class PingBinding(ping: PingProtocol) :
StrictProtocolBinding<PingController>("/ipfs/ping/1.0.0", ping)

class PingTimeoutException : Libp2pException()

open class PingProtocol : ProtocolHandler<PingController>(Long.MAX_VALUE, Long.MAX_VALUE) {
open class PingProtocol(var pingSize: Int) : ProtocolHandler<PingController>(Long.MAX_VALUE, Long.MAX_VALUE) {
var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() }
var curTime: () -> Long = { System.currentTimeMillis() }
var random = Random()
var pingSize = 32
var pingTimeout = Duration.ofSeconds(10)

constructor() : this(32)

override fun onStartInitiator(stream: Stream): CompletableFuture<PingController> {
val handler = PingInitiator()
stream.pushHandler(handler)
Expand Down
63 changes: 63 additions & 0 deletions libp2p/src/test/java/io/libp2p/core/HostTestJava.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,69 @@ void ping() throws Exception {
System.out.println("Server stopped");
}

@Test
void largePing() throws Exception {
int pingSize = 200 * 1024;
String localListenAddress = "/ip4/127.0.0.1/tcp/40002";

Host clientHost = new HostBuilder()
.transport(TcpTransport::new)
.secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA"))
.muxer(StreamMuxerProtocol::getYamux)
.build();

Host serverHost = new HostBuilder()
.transport(TcpTransport::new)
.secureChannel(TlsSecureChannel::new)
.muxer(StreamMuxerProtocol::getYamux)
.protocol(new Ping(pingSize))
.listen(localListenAddress)
.build();

CompletableFuture<Void> clientStarted = clientHost.start();
CompletableFuture<Void> serverStarted = serverHost.start();
clientStarted.get(5, TimeUnit.SECONDS);
System.out.println("Client started");
serverStarted.get(5, TimeUnit.SECONDS);
System.out.println("Server started");

Assertions.assertEquals(0, clientHost.listenAddresses().size());
Assertions.assertEquals(1, serverHost.listenAddresses().size());
Assertions.assertEquals(
localListenAddress + "/p2p/" + serverHost.getPeerId(),
serverHost.listenAddresses().get(0).toString()
);

StreamPromise<PingController> ping =
clientHost.getNetwork().connect(
serverHost.getPeerId(),
new Multiaddr(localListenAddress)
).thenApply(
it -> it.muxerSession().createStream(new Ping(pingSize))
)
.join();

Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS);
System.out.println("Ping stream created");
PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS);
System.out.println("Ping controller created");

for (int i = 0; i < 10; i++) {
long latency = pingCtr.ping().join();//get(5, TimeUnit.SECONDS);
System.out.println("Ping is " + latency);
}
pingStream.close().get(5, TimeUnit.SECONDS);
System.out.println("Ping stream closed");

Assertions.assertThrows(ExecutionException.class, () ->
pingCtr.ping().get(5, TimeUnit.SECONDS));

clientHost.stop().get(5, TimeUnit.SECONDS);
System.out.println("Client stopped");
serverHost.stop().get(5, TimeUnit.SECONDS);
System.out.println("Server stopped");
}

@Test
void keyPairGeneration() {
Pair<PrivKey, PubKey> pair = KeyKt.generateKeyPair(KEY_TYPE.SECP256K1);
Expand Down
29 changes: 28 additions & 1 deletion libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import io.libp2p.mux.MuxHandlerAbstractTest
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
import io.libp2p.tools.readAllBytesAndRelease
import io.netty.channel.ChannelHandlerContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

class YamuxHandlerTest : MuxHandlerAbstractTest() {

Expand All @@ -27,8 +29,10 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
}
}

private fun Long.toMuxId() = MuxId(parentChannelId, this, true)

override fun writeFrame(frame: AbstractTestMuxFrame) {
val muxId = MuxId(parentChannelId, frame.streamId, true)
val muxId = frame.streamId.toMuxId()
val yamuxFrame = when (frame.flag) {
Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0)
Data -> YamuxFrame(
Expand Down Expand Up @@ -65,4 +69,27 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {

return readFrameQueue.removeFirstOrNull()
}

@Test
fun `data should be buffered and sent after window increased from zero`() {
val handler = openStreamByLocal()
val streamId = readFrameOrThrow().streamId

ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
-INITIAL_WINDOW_SIZE.toLong()
)
)

handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf()))

assertThat(readFrame()).isNull()

ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000))
val frame = readFrameOrThrow()
assertThat(frame.data).isEqualTo("1984")
}
}

0 comments on commit bf15e47

Please sign in to comment.