Skip to content

Commit

Permalink
Add Yamux specific unit tests (#298)
Browse files Browse the repository at this point in the history
* Add Yamux specific unit tests
* Minor Yamux refactorings
  • Loading branch information
StefanBratanov authored Aug 15, 2023
1 parent bf15e47 commit 7b853f6
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 72 deletions.
4 changes: 2 additions & 2 deletions libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package io.libp2p.etc.util.netty.mux
import io.netty.channel.ChannelId

data class MuxId(val parentId: ChannelId, val id: Long, val initiator: Boolean) : ChannelId {
override fun asShortText() = "$parentId/$id/$initiator"
override fun asLongText() = asShortText()
override fun asShortText() = "${parentId.asShortText()}/$id/$initiator"
override fun asLongText() = "${parentId.asLongText()}/$id/$initiator"
override fun compareTo(other: ChannelId?) = asShortText().compareTo(other?.asShortText() ?: "")
override fun toString() = asLongText()
}
14 changes: 8 additions & 6 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import io.netty.buffer.Unpooled

/**
* Contains the fields that comprise a yamux frame.
* @param streamId the ID of the stream.
* @param flag the flag value for this frame.
* @param id the ID of the stream.
* @param flags the flags value for this frame.
* @param length the length field for this frame.
* @param data the data segment.
*/
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val lenData: Long, val data: ByteBuf? = null) :
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) :
DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) {

override fun toString(): String {
if (data == null)
return "YamuxFrame(id=$id, type=$type, flag=$flags)"
return "YamuxFrame(id=$id, type=$type, flag=$flags, data=${String(data.toByteArray())})"
if (data == null) {
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)"
}
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})"
}
}
30 changes: 21 additions & 9 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class YamuxFrameCodec(
out.writeByte(msg.type)
out.writeShort(msg.flags)
out.writeInt(msg.id.id.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.lenData.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt())
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
}

Expand All @@ -42,32 +42,44 @@ class YamuxFrameCodec(
*/
override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList<Any>) {
while (msg.isReadable) {
if (msg.readableBytes() < 12)
if (msg.readableBytes() < 12) {
return
}
val readerIndex = msg.readerIndex()
msg.readByte(); // version always 0
val type = msg.readUnsignedByte()
val flags = msg.readUnsignedShort()
val streamId = msg.readUnsignedInt()
val lenData = msg.readUnsignedInt()
val length = msg.readUnsignedInt()
if (type.toInt() != YamuxType.DATA) {
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData)
val yamuxFrame = YamuxFrame(
MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2) == 1).not()),
type.toInt(),
flags,
length
)
out.add(yamuxFrame)
continue
}
if (lenData > maxFrameDataLength) {
if (length > maxFrameDataLength) {
msg.skipBytes(msg.readableBytes())
throw ProtocolViolationException("Yamux frame is too large: $lenData")
throw ProtocolViolationException("Yamux frame is too large: $length")
}
if (msg.readableBytes() < lenData) {
if (msg.readableBytes() < length) {
// not enough data to read the frame content
// will wait for more ...
msg.readerIndex(readerIndex)
return
}
val data = msg.readSlice(lenData.toInt())
val data = msg.readSlice(length.toInt())
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData, data)
val yamuxFrame = YamuxFrame(
MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2) == 1).not()),
type.toInt(),
flags,
length,
data
)
out.add(yamuxFrame)
}
}
Expand Down
59 changes: 37 additions & 22 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import org.slf4j.LoggerFactory
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024

private val log = LoggerFactory.getLogger(YamuxHandler::class.java)

open class YamuxHandler(
override val multistreamProtocol: MultistreamProtocol,
override val maxFrameDataLength: Int,
Expand All @@ -39,7 +42,7 @@ open class YamuxHandler(

fun flush(sendWindow: AtomicInteger, id: MuxId): Int {
var written = 0
while (! buffered.isEmpty()) {
while (!buffered.isEmpty()) {
val buf = buffered.first()
val readableBytes = buf.readableBytes()
if (readableBytes + written < sendWindow.get()) {
Expand All @@ -65,38 +68,48 @@ open class YamuxHandler(
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
YamuxType.PING -> handlePing(msg)
YamuxType.GO_AWAY -> onRemoteClose(msg.id)
YamuxType.GO_AWAY -> handleGoAway(msg)
}
}

fun handlePing(msg: YamuxFrame) {
private fun handlePing(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> ctx.writeAndFlush(YamuxFrame(MuxId(msg.id.parentId, 0, msg.id.initiator), YamuxType.PING, YamuxFlags.ACK, msg.lenData))
YamuxFlags.SYN -> ctx.writeAndFlush(
YamuxFrame(
MuxId(msg.id.parentId, 0, msg.id.initiator),
YamuxType.PING,
YamuxFlags.ACK,
msg.length
)
)

YamuxFlags.ACK -> {}
}
}

fun handleFlags(msg: YamuxFrame) {
private fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> {
// ACK the new stream
onRemoteYamuxOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}

YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
}

fun handleDataRead(msg: YamuxFrame) {
private fun handleDataRead(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
val size = msg.lenData
val size = msg.length
handleFlags(msg)
if (size.toInt() == 0)
if (size.toInt() == 0) {
return
val recWindow = receiveWindows.get(msg.id)
}
val recWindow = receiveWindows[msg.id]
if (recWindow == null) {
releaseMessage(msg.data!!)
throw Libp2pException("No receive window for " + msg.id)
Expand All @@ -111,36 +124,38 @@ open class YamuxHandler(
childRead(msg.id, msg.data!!)
}

fun handleWindowUpdate(msg: YamuxFrame) {
private fun handleWindowUpdate(msg: YamuxFrame) {
handleFlags(msg)
val size = msg.lenData.toInt()
if (size == 0)
return
val sendWindow = sendWindows.get(msg.id)
if (sendWindow == null) {
val size = msg.length.toInt()
if (size == 0) {
return
}
val sendWindow = sendWindows[msg.id] ?: return
sendWindow.addAndGet(size)
val buffer = sendBuffers.get(msg.id)
val buffer = sendBuffers[msg.id]
if (buffer != null) {
val writtenBytes = buffer.flush(sendWindow, msg.id)
totalBufferedWrites.addAndGet(-writtenBytes)
}
}

private fun handleGoAway(msg: YamuxFrame) {
log.debug("Session will be terminated. Go Away message with with error code ${msg.length} has been received.")
onRemoteClose(msg.id)
}

override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()

val sendWindow = sendWindows.get(child.id)
if (sendWindow == null) {
throw Libp2pException("No send window for " + child.id)
}
val sendWindow = sendWindows[child.id] ?: throw Libp2pException("No send window for " + child.id)

if (sendWindow.get() <= 0) {
// wait until the window is increased to send more data
val buffer = sendBuffers.getOrPut(child.id, { SendBuffer(ctx) })
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(ctx) }
buffer.add(data)
if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES)
if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES) {
throw Libp2pException("Overflowed send buffer for connection")
}
return
}
sendBlocks(ctx, data, sendWindow, child.id)
Expand Down
2 changes: 1 addition & 1 deletion libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.libp2p.mux.yamux

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
* Contains all the permissible values for types in the <code>yamux</code> protocol.
*/
object YamuxType {
const val DATA = 0
Expand Down
30 changes: 15 additions & 15 deletions libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler
import io.libp2p.etc.types.fromHex
import io.libp2p.etc.types.getX
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.etc.util.netty.mux.RemoteWriteClosed
import io.libp2p.etc.util.netty.nettyInitializer
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
import io.libp2p.mux.MuxHandlerAbstractTest.TestEventHandler
import io.libp2p.tools.TestChannel
import io.libp2p.tools.readAllBytesAndRelease
import io.netty.buffer.ByteBuf
Expand All @@ -20,10 +22,7 @@ import io.netty.handler.logging.LoggingHandler
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.data.Index
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import java.util.concurrent.CompletableFuture
Expand Down Expand Up @@ -95,10 +94,11 @@ abstract class MuxHandlerAbstractTest {
enum class Flag { Open, Data, Close, Reset }
}

fun Long.toMuxId() = MuxId(parentChannelId, this, true)

abstract fun writeFrame(frame: AbstractTestMuxFrame)
abstract fun readFrame(): AbstractTestMuxFrame?
fun readFrameOrThrow() = readFrame() ?: throw AssertionError("No outbound frames")

fun openStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Open))
fun writeStream(id: Long, msg: String) = writeFrame(AbstractTestMuxFrame(id, Data, msg))
fun closeStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Close))
Expand Down Expand Up @@ -478,66 +478,66 @@ abstract class MuxHandlerAbstractTest {
override fun handlerAdded(ctx: ChannelHandlerContext) {
assertFalse(isHandlerAdded)
isHandlerAdded = true
println("MultiplexHandlerTest.handlerAdded")
println("MuxHandlerAbstractTest.handlerAdded")
this.ctx = ctx
}

override fun channelRegistered(ctx: ChannelHandlerContext?) {
assertTrue(isHandlerAdded)
assertFalse(isRegistered)
isRegistered = true
println("MultiplexHandlerTest.channelRegistered")
println("MuxHandlerAbstractTest.channelRegistered")
}

override fun channelActive(ctx: ChannelHandlerContext) {
assertTrue(isRegistered)
assertFalse(isActivated)
isActivated = true
println("MultiplexHandlerTest.channelActive")
println("MuxHandlerAbstractTest.channelActive")
activeEventHandlers.forEach { it.handle(this) }
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
assertTrue(isActivated)
println("MultiplexHandlerTest.channelRead")
println("MuxHandlerAbstractTest.channelRead")
msg as ByteBuf
inboundMessages += msg.readAllBytesAndRelease().toHex()
}

override fun channelReadComplete(ctx: ChannelHandlerContext?) {
readCompleteEventCount++
println("MultiplexHandlerTest.channelReadComplete")
println("MuxHandlerAbstractTest.channelReadComplete")
}

override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
userEvents += evt
println("MultiplexHandlerTest.userEventTriggered: $evt")
println("MuxHandlerAbstractTest.userEventTriggered: $evt")
}

override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
exceptions += cause
println("MultiplexHandlerTest.exceptionCaught")
println("MuxHandlerAbstractTest.exceptionCaught")
}

override fun channelInactive(ctx: ChannelHandlerContext) {
assertTrue(isActivated)
assertFalse(isInactivated)
isInactivated = true
println("MultiplexHandlerTest.channelInactive")
println("MuxHandlerAbstractTest.channelInactive")
}

override fun channelUnregistered(ctx: ChannelHandlerContext?) {
assertTrue(isInactivated)
assertFalse(isUnregistered)
isUnregistered = true
println("MultiplexHandlerTest.channelUnregistered")
println("MuxHandlerAbstractTest.channelUnregistered")
}

override fun handlerRemoved(ctx: ChannelHandlerContext?) {
assertTrue(isUnregistered)
assertFalse(isHandlerRemoved)
isHandlerRemoved = true
println("MultiplexHandlerTest.handlerRemoved")
println("MuxHandlerAbstractTest.handlerRemoved")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocolV1
import io.libp2p.etc.types.fromHex
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.libp2p.mux.MuxHandlerAbstractTest
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
Expand All @@ -28,6 +27,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
}

override fun writeFrame(frame: AbstractTestMuxFrame) {
val muxId = frame.streamId.toMuxId()
val mplexFlag = when (frame.flag) {
Open -> MplexFlag.Type.OPEN
Data -> MplexFlag.Type.DATA
Expand All @@ -39,7 +39,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
else -> frame.data.fromHex().toByteBuf(allocateBuf())
}
val mplexFrame =
MplexFrame(MuxId(parentChannelId, frame.streamId, true), MplexFlag.getByType(mplexFlag, true), data)
MplexFrame(muxId, MplexFlag.getByType(mplexFlag, true), data)
ech.writeInbound(mplexFrame)
}

Expand All @@ -51,10 +51,9 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
MplexFlag.Type.DATA -> Data
MplexFlag.Type.CLOSE -> Close
MplexFlag.Type.RESET -> Reset
else -> throw AssertionError("Unknown mplex flag: ${mplexFrame.flag}")
}
val sData = maybeMplexFrame.data.readAllBytesAndRelease().toHex()
AbstractTestMuxFrame(mplexFrame.id.id, flag, sData)
val data = maybeMplexFrame.data.readAllBytesAndRelease().toHex()
AbstractTestMuxFrame(mplexFrame.id.id, flag, data)
}
}
}
Loading

0 comments on commit 7b853f6

Please sign in to comment.