Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix yamux handling of writes bigger than the window size #295

Merged
merged 3 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,26 @@ open class YamuxHandler(
private val buffered = ArrayDeque<ByteBuf>()

fun add(data: ByteBuf) {
buffered.add(data)
buffered.add(data.retain())
Nashatyrev marked this conversation as resolved.
Show resolved Hide resolved
}

fun flush(sendWindow: AtomicInteger, id: MuxId): Int {
var written = 0
while (! buffered.isEmpty()) {
val buf = buffered.first()
if (buf.readableBytes() + written < sendWindow.get()) {
buffered.removeFirst()
val readableBytes = buf.readableBytes()
if (readableBytes + written < sendWindow.get()) {
sendBlocks(ctx, buf, sendWindow, id)
written += buf.readableBytes()
} else
written += readableBytes
buf.release()
buffered.removeFirst()
} else {
// partial write to fit within window
val toRead = sendWindow.get() - written
sendBlocks(ctx, buf.readSlice(toRead), sendWindow, id)
written += toRead
break
}
}
return written
}
Expand Down Expand Up @@ -96,7 +103,7 @@ open class YamuxHandler(
}
val newWindow = recWindow.addAndGet(-size.toInt())
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE / 2
val delta = INITIAL_WINDOW_SIZE - newWindow
recWindow.addAndGet(delta)
ctx.write(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()))
ctx.flush()
Expand Down
9 changes: 6 additions & 3 deletions libp2p/src/main/kotlin/io/libp2p/protocol/Ping.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@ interface PingController {
fun ping(): CompletableFuture<Long>
}

class Ping : PingBinding(PingProtocol())
class Ping(pingSize: Int) : PingBinding(PingProtocol(pingSize)) {
constructor() : this(32)
}

open class PingBinding(ping: PingProtocol) :
StrictProtocolBinding<PingController>("/ipfs/ping/1.0.0", ping)

class PingTimeoutException : Libp2pException()

open class PingProtocol : ProtocolHandler<PingController>(Long.MAX_VALUE, Long.MAX_VALUE) {
open class PingProtocol(var pingSize: Int) : ProtocolHandler<PingController>(Long.MAX_VALUE, Long.MAX_VALUE) {
var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() }
var curTime: () -> Long = { System.currentTimeMillis() }
var random = Random()
var pingSize = 32
var pingTimeout = Duration.ofSeconds(10)

constructor() : this(32)

override fun onStartInitiator(stream: Stream): CompletableFuture<PingController> {
val handler = PingInitiator()
stream.pushHandler(handler)
Expand Down
63 changes: 63 additions & 0 deletions libp2p/src/test/java/io/libp2p/core/HostTestJava.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,69 @@ void ping() throws Exception {
System.out.println("Server stopped");
}

@Test
void largePing() throws Exception {
int pingSize = 200 * 1024;
String localListenAddress = "/ip4/127.0.0.1/tcp/40002";

Host clientHost = new HostBuilder()
.transport(TcpTransport::new)
.secureChannel((k, m) -> new TlsSecureChannel(k, m, "ECDSA"))
.muxer(StreamMuxerProtocol::getYamux)
.build();

Host serverHost = new HostBuilder()
.transport(TcpTransport::new)
.secureChannel(TlsSecureChannel::new)
.muxer(StreamMuxerProtocol::getYamux)
.protocol(new Ping(pingSize))
.listen(localListenAddress)
.build();

CompletableFuture<Void> clientStarted = clientHost.start();
CompletableFuture<Void> serverStarted = serverHost.start();
clientStarted.get(5, TimeUnit.SECONDS);
System.out.println("Client started");
serverStarted.get(5, TimeUnit.SECONDS);
System.out.println("Server started");

Assertions.assertEquals(0, clientHost.listenAddresses().size());
Assertions.assertEquals(1, serverHost.listenAddresses().size());
Assertions.assertEquals(
localListenAddress + "/p2p/" + serverHost.getPeerId(),
serverHost.listenAddresses().get(0).toString()
);

StreamPromise<PingController> ping =
clientHost.getNetwork().connect(
serverHost.getPeerId(),
new Multiaddr(localListenAddress)
).thenApply(
it -> it.muxerSession().createStream(new Ping(pingSize))
)
.join();

Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS);
System.out.println("Ping stream created");
PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS);
System.out.println("Ping controller created");

for (int i = 0; i < 10; i++) {
long latency = pingCtr.ping().join();//get(5, TimeUnit.SECONDS);
System.out.println("Ping is " + latency);
}
pingStream.close().get(5, TimeUnit.SECONDS);
System.out.println("Ping stream closed");

Assertions.assertThrows(ExecutionException.class, () ->
pingCtr.ping().get(5, TimeUnit.SECONDS));

clientHost.stop().get(5, TimeUnit.SECONDS);
System.out.println("Client stopped");
serverHost.stop().get(5, TimeUnit.SECONDS);
System.out.println("Server stopped");
}

@Test
void keyPairGeneration() {
Pair<PrivKey, PubKey> pair = KeyKt.generateKeyPair(KEY_TYPE.SECP256K1);
Expand Down
29 changes: 28 additions & 1 deletion libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import io.libp2p.mux.MuxHandlerAbstractTest
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
import io.libp2p.tools.readAllBytesAndRelease
import io.netty.channel.ChannelHandlerContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test

class YamuxHandlerTest : MuxHandlerAbstractTest() {

Expand All @@ -27,8 +29,10 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
}
}

private fun Long.toMuxId() = MuxId(parentChannelId, this, true)

override fun writeFrame(frame: AbstractTestMuxFrame) {
val muxId = MuxId(parentChannelId, frame.streamId, true)
val muxId = frame.streamId.toMuxId()
val yamuxFrame = when (frame.flag) {
Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0)
Data -> YamuxFrame(
Expand Down Expand Up @@ -65,4 +69,27 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {

return readFrameQueue.removeFirstOrNull()
}

@Test
fun `data should be buffered and sent after window increased from zero`() {
val handler = openStreamByLocal()
val streamId = readFrameOrThrow().streamId

ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
-INITIAL_WINDOW_SIZE.toLong()
)
)

handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf()))

assertThat(readFrame()).isNull()

ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000))
val frame = readFrameOrThrow()
assertThat(frame.data).isEqualTo("1984")
}
}