Skip to content

Commit

Permalink
pass the peer ID to SecureInbound in the SecureTransport and SecureMu…
Browse files Browse the repository at this point in the history
…xer (#211)

The peer ID may be empty. This will be the common case. In that case,
connections from any peer are accepted.
  • Loading branch information
marten-seemann authored Sep 8, 2021
1 parent 1d5963f commit 52f593e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 96 deletions.
6 changes: 5 additions & 1 deletion core/sec/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey {
//
// SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent
// with each other, or if a network error occurs during the ID exchange.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
conn := &Conn{
Conn: insecure,
local: t.id,
Expand All @@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
return nil, err
}

if t.key != nil && p != "" && p != conn.remote {
return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote)
}

return conn, nil
}

Expand Down
155 changes: 63 additions & 92 deletions core/sec/insecure/insecure_test.go
Original file line number Diff line number Diff line change
@@ -1,157 +1,128 @@
package insecure

import (
"bytes"
"context"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
"io"
"net"
"testing"

ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/stretchr/testify/require"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
)

// Run a set of sessions through the session setup and verification.
func TestConnections(t *testing.T) {
clientTpt := newTestTransport(t, ci.RSA, 2048)
serverTpt := newTestTransport(t, ci.Ed25519, 1024)
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "")
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}

func TestPeerIdMatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

testConnection(t, clientTpt, serverTpt)
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer())
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}

func TestPeerIDMismatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

_, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer")
require.Error(t, serverErr)
require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID")
}

func TestPeerIDMismatchOutbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

_, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "")
require.Error(t, clientErr)
require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID")
}

func newTestTransport(t *testing.T, typ, bits int) *Transport {
priv, pub, err := ci.GenerateKeyPair(typ, bits)
if err != nil {
t.Fatal(err)
}
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
require.NoError(t, err)
id, err := peer.IDFromPublicKey(pub)
if err != nil {
t.Fatal(err)
}

require.NoError(t, err)
return NewWithIdentity(id, priv)
}

// Create a new pair of connected TCP sockets.
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
lstnr, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
return nil, nil
}
require.NoError(t, err, "failed to listen")

var clientErr error
var client net.Conn
addr := lstnr.Addr()
done := make(chan struct{})

go func() {
defer close(done)
addr := lstnr.Addr()
client, clientErr = net.Dial(addr.Network(), addr.String())
}()

server, err := lstnr.Accept()
<-done
require.NoError(t, err, "failed to accept")

<-done
lstnr.Close()

if err != nil {
t.Fatalf("Failed to accept: %v", err)
}

if clientErr != nil {
t.Fatalf("Failed to connect: %v", clientErr)
}

require.NoError(t, clientErr, "failed to connect")
return client, server
}

// Create a new pair of connected sessions based off of the provided
// session generators.
func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) {
func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) {
client, server := newConnPair(t)

// Connect the client and server sessions
done := make(chan struct{})

var clientConn sec.SecureConn
var clientErr error
go func() {
defer close(done)
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer())
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID)
}()

serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server)
serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID)
<-done

if serverErr != nil {
t.Fatal(serverErr)
}

if clientErr != nil {
t.Fatal(clientErr)
}

return clientConn, serverConn
return
}

// Check the peer IDs
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
if clientConn.LocalPeer() != clientTpt.LocalPeer() {
t.Fatal("Client Local Peer ID mismatch.")
}

if clientConn.RemotePeer() != serverTpt.LocalPeer() {
t.Fatal("Client Remote Peer ID mismatch.")
}

if clientConn.LocalPeer() != serverConn.RemotePeer() {
t.Fatal("Server Local Peer ID mismatch.")
}
require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.")
require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.")
require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.")
}

// Check the keys
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
sk := serverConn.LocalPrivateKey()
pk := sk.GetPublic()

if !sk.Equals(serverTpt.LocalPrivateKey()) {
t.Error("Private key Mismatch.")
}

if !pk.Equals(clientConn.RemotePublicKey()) {
t.Error("Public key mismatch.")
}
require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch")
require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch")
}

// Check sending and receiving messages
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) {
before := []byte("hello world")
_, err := clientConn.Write(before)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

after := make([]byte, len(before))
_, err = io.ReadFull(serverConn, after)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(before, after) {
t.Errorf("Message mismatch. %v != %v", before, after)
}
}

// Setup a new session with a pair of locally connected sockets
func testConnection(t *testing.T, clientTpt, serverTpt *Transport) {
clientConn, serverConn := connect(t, clientTpt, serverTpt)

testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)

clientConn.Close()
serverConn.Close()
require.NoError(t, err)
require.Equal(t, before, after, "message mismatch")
}
8 changes: 5 additions & 3 deletions core/sec/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ type SecureConn interface {
// plain-text, native connections into authenticated, encrypted connections.
type SecureTransport interface {
// SecureInbound secures an inbound connection.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)

// SecureOutbound secures an outbound connection.
SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
Expand All @@ -29,9 +30,10 @@ type SecureTransport interface {
// and open outbound connections with simultaneous open.
type SecureMuxer interface {
// SecureInbound secures an inbound connection.
// The returned boolean indicates whether the connection should be trated as a server
// The returned boolean indicates whether the connection should be treated as a server
// connection; in the case of SecureInbound it should always be true.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error)

// SecureOutbound secures an outbound connection.
// The returned boolean indicates whether the connection should be treated as a server
Expand Down

0 comments on commit 52f593e

Please sign in to comment.