Skip to content

Commit

Permalink
Limit number of streams per protocol per peer (#811)
Browse files Browse the repository at this point in the history
  • Loading branch information
Menduist authored Dec 1, 2022
1 parent 31ad4ae commit 64cbbe1
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/tutorial_2_customproto.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ proc new(T: typedesc[TestProto]): T =
# We must close the connections ourselves when we're done with it
await conn.close()

return T(codecs: @[TestCodec], handler: handle)
return T.new(codecs = @[TestCodec], handler = handle)

## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol.
## In our handle, we simply read a message from the connection and `echo` it.
Expand Down
6 changes: 3 additions & 3 deletions examples/tutorial_3_protobuf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,16 @@ type
metricGetter: MetricCallback

proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
let res = MetricProto(metricGetter: cb)
var res: MetricProto
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
let
metrics = await res.metricGetter()
asProtobuf = metrics.encode()
await conn.writeLp(asProtobuf.buffer)
await conn.close()

res.codecs = @["/metric-getter/1.0.0"]
res.handler = handle
res = MetricProto.new(@["/metric-getter/1.0.0"], handle)
res.metricGetter = cb
return res

proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorial_5_discovery.nim
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
await conn.close()
return T(codecs: @[DumbCodec], handler: handle)
return T.new(codecs = @[DumbCodec], handler = handle)

## ## Bootnodes
## The first time a p2p program is ran, he needs to know how to join
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorial_6_game.nim
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ proc new(T: typedesc[GameProto], g: Game): T =
# The handler of a protocol must wait for the stream to
# be finished before returning
await conn.join()
return T(codecs: @["/tron/1.0.0"], handler: handle)
return T.new(codecs = @["/tron/1.0.0"], handler = handle)

proc networking(g: Game) {.async.} =
# Create our switch, similar to the GossipSub example and
Expand Down
30 changes: 23 additions & 7 deletions libp2p/multistream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}

import std/[strutils, sequtils]
import std/[strutils, sequtils, tables]
import chronos, chronicles, stew/byteutils
import stream/connection,
protocols/protocol
Expand All @@ -21,7 +21,7 @@ logScope:
topics = "libp2p multistream"

const
MsgSize* = 64*1024
MsgSize* = 1024
Codec* = "/multistream/1.0.0"

MSCodec* = "\x13" & Codec & "\n"
Expand All @@ -33,17 +33,20 @@ type

MultiStreamError* = object of LPError

HandlerHolder* = object
HandlerHolder* = ref object
protos*: seq[string]
protocol*: LPProtocol
match*: Matcher
openedStreams: CountTable[PeerId]

MultistreamSelect* = ref object of RootObj
handlers*: seq[HandlerHolder]
codec*: string

proc new*(T: typedesc[MultistreamSelect]): T =
T(codec: MSCodec)
T(
codec: MSCodec,
)

template validateSuffix(str: string): untyped =
if str.endsWith("\n"):
Expand Down Expand Up @@ -169,9 +172,22 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy
for h in m.handlers:
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms
await conn.writeLp(ms & "\n")
conn.protocol = ms
await h.protocol.handler(conn, ms)

var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await conn.writeLp(ms & "\n")
conn.protocol = ms
await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return
debug "no handlers", conn, protocol = ms
await conn.write(Na)
Expand Down
26 changes: 25 additions & 1 deletion libp2p/protocols/protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ when (NimMajor, NimMinor) < (1, 4):
else:
{.push raises: [].}

import chronos
import chronos, stew/results
import ../stream/connection

export results

const
DefaultMaxIncomingStreams* = 10

type
LPProtoHandler* = proc (
conn: Connection,
Expand All @@ -26,11 +31,17 @@ type
codecs*: seq[string]
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
started*: bool
maxIncomingStreams: Opt[int]

method init*(p: LPProtocol) {.base, gcsafe.} = discard
method start*(p: LPProtocol) {.async, base.} = p.started = true
method stop*(p: LPProtocol) {.async, base.} = p.started = false

proc maxIncomingStreams*(p: LPProtocol): int =
p.maxIncomingStreams.get(DefaultMaxIncomingStreams)

proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
p.maxIncomingStreams = Opt.some(val)

func codec*(p: LPProtocol): string =
assert(p.codecs.len > 0, "Codecs sequence was empty!")
Expand All @@ -40,3 +51,16 @@ func `codec=`*(p: LPProtocol, codec: string) =
# always insert as first codec
# if we use this abstraction
p.codecs.insert(codec, 0)

proc new*(
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2
maxIncomingStreams: Opt[int] | int = Opt[int]()): T =
T(
codecs: codecs,
handler: handler,
maxIncomingStreams:
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
else: maxIncomingStreams
)
73 changes: 73 additions & 0 deletions tests/testmultistream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,79 @@ suite "Multistream select":

await handlerWait.wait(30.seconds)

asyncTest "e2e - streams limit":
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
let blocker = newFuture[void]()

# Start 5 streams which are blocked by `blocker`
# Try to start a new one, which should fail
# Unblock the 5 streams, check that we can open a new one
proc testHandler(conn: Connection,
proto: string):
Future[void] {.async, gcsafe.} =
await blocker
await conn.writeLp("Hello!")
await conn.close()

var protocol: LPProtocol = LPProtocol.new(
@["/test/proto/1.0.0"],
testHandler,
maxIncomingStreams = 5
)

protocol.handler = testHandler
let msListen = MultistreamSelect.new()
msListen.addHandler("/test/proto/1.0.0", protocol)

let transport1 = TcpTransport.new(upgrade = Upgrade())
await transport1.start(ma)

proc acceptedOne(c: Connection) {.async.} =
await msListen.handle(c)
await c.close()

proc acceptHandler() {.async, gcsafe.} =
while true:
let conn = await transport1.accept()
asyncSpawn acceptedOne(conn)

var handlerWait = acceptHandler()

let msDial = MultistreamSelect.new()
let transport2 = TcpTransport.new(upgrade = Upgrade())

proc connector {.async.} =
let conn = await transport2.dial(transport1.addrs[0])
check: (await msDial.select(conn, "/test/proto/1.0.0")) == true
check: string.fromBytes(await conn.readLp(1024)) == "Hello!"
await conn.close()

# Fill up the 5 allowed streams
var dialers: seq[Future[void]]
for _ in 0..<5:
dialers.add(connector())

# This one will fail during negotiation
expect(CatchableError):
try: waitFor(connector().wait(1.seconds))
except AsyncTimeoutError as exc:
check false
raise exc
# check that the dialers aren't finished
check: (await dialers[0].withTimeout(10.milliseconds)) == false

# unblock the dialers
blocker.complete()
await allFutures(dialers)

# now must work
waitFor(connector())

await transport2.stop()
await transport1.stop()

await handlerWait.cancelAndWait()

asyncTest "e2e - ls":
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]

Expand Down

0 comments on commit 64cbbe1

Please sign in to comment.