diff --git a/p2p/net/mock/interface.go b/p2p/net/mock/interface.go index d89342b009..f8939db586 100644 --- a/p2p/net/mock/interface.go +++ b/p2p/net/mock/interface.go @@ -10,6 +10,7 @@ import ( "io" "time" + "github.com/libp2p/go-libp2p/core/connmgr" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -22,6 +23,7 @@ import ( type Mocknet interface { // GenPeer generates a peer and its network.Network in the Mocknet GenPeer() (host.Host, error) + GenPeerWithConnGater(connmgr.ConnectionGater) (host.Host, error) // AddPeer adds an existing peer. we need both a privkey and addr. // ID is derived from PrivKey diff --git a/p2p/net/mock/mock_net.go b/p2p/net/mock/mock_net.go index cde4052369..63302b0cff 100644 --- a/p2p/net/mock/mock_net.go +++ b/p2p/net/mock/mock_net.go @@ -8,6 +8,7 @@ import ( "sort" "sync" + "github.com/libp2p/go-libp2p/core/connmgr" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -64,6 +65,10 @@ func (mn *mocknet) Close() error { } func (mn *mocknet) GenPeer() (host.Host, error) { + return mn.GenPeerWithConnGater(nil) +} + +func (mn *mocknet) GenPeerWithConnGater(gater connmgr.ConnectionGater) (host.Host, error) { sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader) if err != nil { return nil, err @@ -83,7 +88,14 @@ func (mn *mocknet) GenPeer() (host.Host, error) { return nil, fmt.Errorf("failed to create test multiaddr: %s", err) } - h, err := mn.AddPeer(sk, a) + p, ps, err := mn.createPeerstore(sk, a) + if err != nil { + return nil, err + } + h, err := mn.AddPeerWithOptions(p, &peerOptions{ + ps: ps, + gater: gater, + }) if err != nil { return nil, err } @@ -92,25 +104,37 @@ func (mn *mocknet) GenPeer() (host.Host, error) { } func (mn *mocknet) AddPeer(k ic.PrivKey, a ma.Multiaddr) (host.Host, error) { - p, err := peer.IDFromPublicKey(k.GetPublic()) + p, ps, err := mn.createPeerstore(k, a) if err != nil { return nil, err } + return mn.AddPeerWithPeerstore(p, ps) +} + +func (mn *mocknet) createPeerstore(k ic.PrivKey, a ma.Multiaddr) (peer.ID, peerstore.Peerstore, error) { + p, err := peer.IDFromPublicKey(k.GetPublic()) + if err != nil { + return "", nil, err + } + ps, err := pstoremem.NewPeerstore() if err != nil { - return nil, err + return "", nil, err } ps.AddAddr(p, a, peerstore.PermanentAddrTTL) ps.AddPrivKey(p, k) ps.AddPubKey(p, k.GetPublic()) - - return mn.AddPeerWithPeerstore(p, ps) + return p, ps, nil } func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host.Host, error) { + return mn.AddPeerWithOptions(p, &peerOptions{ps: ps}) +} + +func (mn *mocknet) AddPeerWithOptions(p peer.ID, netOpts *peerOptions) (host.Host, error) { bus := eventbus.NewBus() - n, err := newPeernet(mn, p, ps, bus) + n, err := newPeernet(mn, p, netOpts, bus) if err != nil { return nil, err } diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index a46ee8ddc9..4fe18ab326 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -7,6 +7,7 @@ import ( "math/rand" "sync" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -14,6 +15,11 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +type peerOptions struct { + ps peerstore.Peerstore + gater connmgr.ConnectionGater +} + // peernet implements network.Network type peernet struct { mocknet *mocknet // parent @@ -28,6 +34,9 @@ type peernet struct { connsByPeer map[peer.ID]map[*conn]struct{} connsByLink map[*link]map[*conn]struct{} + // connection gater to check before dialing or accepting connections. May be nil to allow all. + gater connmgr.ConnectionGater + // implement network.Network streamHandler network.StreamHandler @@ -38,7 +47,7 @@ type peernet struct { } // newPeernet constructs a new peernet -func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*peernet, error) { +func newPeernet(m *mocknet, p peer.ID, opts *peerOptions, bus event.Bus) (*peernet, error) { emitter, err := bus.Emitter(&event.EvtPeerConnectednessChanged{}) if err != nil { return nil, err @@ -47,7 +56,8 @@ func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (* n := &peernet{ mocknet: m, peer: p, - ps: ps, + ps: opts.ps, + gater: opts.gater, emitter: emitter, connsByPeer: map[peer.ID]map[*conn]struct{}{}, @@ -124,6 +134,10 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) { } pn.RUnlock() + if pn.gater != nil && !pn.gater.InterceptPeerDial(p) { + log.Debugf("gater disallowed outbound connection to peer %s", p) + return nil, fmt.Errorf("%v connection gater disallowed connection to %v", pn.peer, p) + } log.Debugf("%s (newly) dialing %s", pn.peer, p) // ok, must create a new connection. we need a link @@ -139,18 +153,51 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) { log.Debugf("%s dialing %s openingConn", pn.peer, p) // create a new connection with link - c := pn.openConn(p, l.(*link)) - return c, nil + return pn.openConn(p, l.(*link)) } -func (pn *peernet) openConn(r peer.ID, l *link) *conn { +func (pn *peernet) openConn(r peer.ID, l *link) (*conn, error) { lc, rc := l.newConnPair(pn) - log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer()) addConnPair(pn, rc.net, lc, rc) + log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer()) + abort := func() { + _ = lc.Close() + _ = rc.Close() + } + if pn.gater != nil && !pn.gater.InterceptAddrDial(lc.remote, lc.remoteAddr) { + abort() + return nil, fmt.Errorf("%v rejected dial to %v on addr %v", lc.local, lc.remote, lc.remoteAddr) + } + if rc.net.gater != nil && !rc.net.gater.InterceptAccept(rc) { + abort() + return nil, fmt.Errorf("%v rejected connection from %v", rc.local, rc.remote) + } + if err := checkSecureAndUpgrade(network.DirOutbound, pn.gater, lc); err != nil { + abort() + return nil, err + } + if err := checkSecureAndUpgrade(network.DirInbound, rc.net.gater, rc); err != nil { + abort() + return nil, err + } go rc.net.remoteOpenedConn(rc) pn.addConn(lc) - return lc + return lc, nil +} + +func checkSecureAndUpgrade(dir network.Direction, gater connmgr.ConnectionGater, c *conn) error { + if gater == nil { + return nil + } + if !gater.InterceptSecured(dir, c.remote, c) { + return fmt.Errorf("%v rejected secure handshake with %v", c.local, c.remote) + } + allow, _ := gater.InterceptUpgraded(c) + if !allow { + return fmt.Errorf("%v rejected upgrade with %v", c.local, c.remote) + } + return nil } // addConnPair adds connection to both peernets at the same time diff --git a/p2p/net/mock/mock_test.go b/p2p/net/mock/mock_test.go index 2ea1bf18dd..483122c14f 100644 --- a/p2p/net/mock/mock_test.go +++ b/p2p/net/mock/mock_test.go @@ -13,9 +13,12 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/p2p/net/conngater" + manet "github.com/multiformats/go-multiaddr/net" "github.com/libp2p/go-libp2p-testing/ci" tetc "github.com/libp2p/go-libp2p-testing/etc" @@ -681,3 +684,68 @@ func TestEventBus(t *testing.T) { } } } + +func TestBlockByPeerID(t *testing.T) { + m, gater1, host1, _, host2 := WithConnectionGaters(t) + + err := gater1.BlockPeer(host2.ID()) + if err != nil { + t.Fatal(err) + } + + _, err = m.ConnectPeers(host1.ID(), host2.ID()) + if err == nil { + t.Fatal("Should have blocked connection to banned peer") + } + + _, err = m.ConnectPeers(host2.ID(), host1.ID()) + if err == nil { + t.Fatal("Should have blocked connection from banned peer") + } +} + +func TestBlockByIP(t *testing.T) { + m, gater1, host1, _, host2 := WithConnectionGaters(t) + + ip, err := manet.ToIP(host2.Addrs()[0]) + if err != nil { + t.Fatal(err) + } + err = gater1.BlockAddr(ip) + if err != nil { + t.Fatal(err) + } + + _, err = m.ConnectPeers(host1.ID(), host2.ID()) + if err == nil { + t.Fatal("Should have blocked connection to banned IP") + } + + _, err = m.ConnectPeers(host2.ID(), host1.ID()) + if err == nil { + t.Fatal("Should have blocked connection from banned IP") + } +} + +func WithConnectionGaters(t *testing.T) (Mocknet, *conngater.BasicConnectionGater, host.Host, *conngater.BasicConnectionGater, host.Host) { + m := New() + addPeer := func() (*conngater.BasicConnectionGater, host.Host) { + gater, err := conngater.NewBasicConnectionGater(nil) + if err != nil { + t.Fatal(err) + } + h, err := m.GenPeerWithConnGater(gater) + if err != nil { + t.Fatal(err) + } + return gater, h + } + gater1, host1 := addPeer() + gater2, host2 := addPeer() + + err := m.LinkAll() + if err != nil { + t.Fatal(err) + } + return m, gater1, host1, gater2, host2 +}