-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pass the peer ID to SecureInbound in the SecureTransport and SecureMu…
…xer (#211) The peer ID may be empty. This will be the common case. In that case, connections from any peer are accepted.
- Loading branch information
1 parent
1d5963f
commit 52f593e
Showing
3 changed files
with
73 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,157 +1,128 @@ | ||
package insecure | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"github.com/libp2p/go-libp2p-core/peer" | ||
"github.com/libp2p/go-libp2p-core/sec" | ||
"io" | ||
"net" | ||
"testing" | ||
|
||
ci "github.com/libp2p/go-libp2p-core/crypto" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/libp2p/go-libp2p-core/crypto" | ||
"github.com/libp2p/go-libp2p-core/peer" | ||
"github.com/libp2p/go-libp2p-core/sec" | ||
) | ||
|
||
// Run a set of sessions through the session setup and verification. | ||
func TestConnections(t *testing.T) { | ||
clientTpt := newTestTransport(t, ci.RSA, 2048) | ||
serverTpt := newTestTransport(t, ci.Ed25519, 1024) | ||
clientTpt := newTestTransport(t, crypto.RSA, 2048) | ||
serverTpt := newTestTransport(t, crypto.Ed25519, 1024) | ||
|
||
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "") | ||
require.NoError(t, clientErr) | ||
require.NoError(t, serverErr) | ||
testIDs(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testKeys(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testReadWrite(t, clientConn, serverConn) | ||
} | ||
|
||
func TestPeerIdMatchInbound(t *testing.T) { | ||
clientTpt := newTestTransport(t, crypto.RSA, 2048) | ||
serverTpt := newTestTransport(t, crypto.Ed25519, 1024) | ||
|
||
testConnection(t, clientTpt, serverTpt) | ||
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer()) | ||
require.NoError(t, clientErr) | ||
require.NoError(t, serverErr) | ||
testIDs(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testKeys(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testReadWrite(t, clientConn, serverConn) | ||
} | ||
|
||
func TestPeerIDMismatchInbound(t *testing.T) { | ||
clientTpt := newTestTransport(t, crypto.RSA, 2048) | ||
serverTpt := newTestTransport(t, crypto.Ed25519, 1024) | ||
|
||
_, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer") | ||
require.Error(t, serverErr) | ||
require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID") | ||
} | ||
|
||
func TestPeerIDMismatchOutbound(t *testing.T) { | ||
clientTpt := newTestTransport(t, crypto.RSA, 2048) | ||
serverTpt := newTestTransport(t, crypto.Ed25519, 1024) | ||
|
||
_, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "") | ||
require.Error(t, clientErr) | ||
require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID") | ||
} | ||
|
||
func newTestTransport(t *testing.T, typ, bits int) *Transport { | ||
priv, pub, err := ci.GenerateKeyPair(typ, bits) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
priv, pub, err := crypto.GenerateKeyPair(typ, bits) | ||
require.NoError(t, err) | ||
id, err := peer.IDFromPublicKey(pub) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
require.NoError(t, err) | ||
return NewWithIdentity(id, priv) | ||
} | ||
|
||
// Create a new pair of connected TCP sockets. | ||
func newConnPair(t *testing.T) (net.Conn, net.Conn) { | ||
lstnr, err := net.Listen("tcp", "localhost:0") | ||
if err != nil { | ||
t.Fatalf("Failed to listen: %v", err) | ||
return nil, nil | ||
} | ||
require.NoError(t, err, "failed to listen") | ||
|
||
var clientErr error | ||
var client net.Conn | ||
addr := lstnr.Addr() | ||
done := make(chan struct{}) | ||
|
||
go func() { | ||
defer close(done) | ||
addr := lstnr.Addr() | ||
client, clientErr = net.Dial(addr.Network(), addr.String()) | ||
}() | ||
|
||
server, err := lstnr.Accept() | ||
<-done | ||
require.NoError(t, err, "failed to accept") | ||
|
||
<-done | ||
lstnr.Close() | ||
|
||
if err != nil { | ||
t.Fatalf("Failed to accept: %v", err) | ||
} | ||
|
||
if clientErr != nil { | ||
t.Fatalf("Failed to connect: %v", clientErr) | ||
} | ||
|
||
require.NoError(t, clientErr, "failed to connect") | ||
return client, server | ||
} | ||
|
||
// Create a new pair of connected sessions based off of the provided | ||
// session generators. | ||
func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) { | ||
func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) { | ||
client, server := newConnPair(t) | ||
|
||
// Connect the client and server sessions | ||
done := make(chan struct{}) | ||
|
||
var clientConn sec.SecureConn | ||
var clientErr error | ||
go func() { | ||
defer close(done) | ||
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer()) | ||
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID) | ||
}() | ||
|
||
serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server) | ||
serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID) | ||
<-done | ||
|
||
if serverErr != nil { | ||
t.Fatal(serverErr) | ||
} | ||
|
||
if clientErr != nil { | ||
t.Fatal(clientErr) | ||
} | ||
|
||
return clientConn, serverConn | ||
return | ||
} | ||
|
||
// Check the peer IDs | ||
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { | ||
if clientConn.LocalPeer() != clientTpt.LocalPeer() { | ||
t.Fatal("Client Local Peer ID mismatch.") | ||
} | ||
|
||
if clientConn.RemotePeer() != serverTpt.LocalPeer() { | ||
t.Fatal("Client Remote Peer ID mismatch.") | ||
} | ||
|
||
if clientConn.LocalPeer() != serverConn.RemotePeer() { | ||
t.Fatal("Server Local Peer ID mismatch.") | ||
} | ||
require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.") | ||
require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.") | ||
require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.") | ||
} | ||
|
||
// Check the keys | ||
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { | ||
sk := serverConn.LocalPrivateKey() | ||
pk := sk.GetPublic() | ||
|
||
if !sk.Equals(serverTpt.LocalPrivateKey()) { | ||
t.Error("Private key Mismatch.") | ||
} | ||
|
||
if !pk.Equals(clientConn.RemotePublicKey()) { | ||
t.Error("Public key mismatch.") | ||
} | ||
require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch") | ||
require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch") | ||
} | ||
|
||
// Check sending and receiving messages | ||
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) { | ||
before := []byte("hello world") | ||
_, err := clientConn.Write(before) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
require.NoError(t, err) | ||
|
||
after := make([]byte, len(before)) | ||
_, err = io.ReadFull(serverConn, after) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if !bytes.Equal(before, after) { | ||
t.Errorf("Message mismatch. %v != %v", before, after) | ||
} | ||
} | ||
|
||
// Setup a new session with a pair of locally connected sockets | ||
func testConnection(t *testing.T, clientTpt, serverTpt *Transport) { | ||
clientConn, serverConn := connect(t, clientTpt, serverTpt) | ||
|
||
testIDs(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testKeys(t, clientTpt, serverTpt, clientConn, serverConn) | ||
testReadWrite(t, clientConn, serverConn) | ||
|
||
clientConn.Close() | ||
serverConn.Close() | ||
require.NoError(t, err) | ||
require.Equal(t, before, after, "message mismatch") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters