Skip to content

Commit

Permalink
Get Device Authorization Flow information from management (#308)
Browse files Browse the repository at this point in the history
We will configure the device authorization
flow information and a client will
retrieve it and initiate a
device authorization gran flow
  • Loading branch information
mlsmaycon authored May 8, 2022
1 parent fec3132 commit 7e5449f
Show file tree
Hide file tree
Showing 12 changed files with 792 additions and 116 deletions.
2 changes: 1 addition & 1 deletion iface/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

const (
DefaultMTU = 1280
DefaultMTU = 1280
DefaultWgPort = 51820
)

Expand Down
1 change: 1 addition & 0 deletions management/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ type Client interface {
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
}
161 changes: 125 additions & 36 deletions management/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,17 @@ import (
"google.golang.org/grpc/status"
)

var tested *GrpcClient
var serverAddr string
var mgmtMockServer *mock_server.ManagementServiceServerMock
var serverKey wgtypes.Key

const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"

func Test_Start(t *testing.T) {
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {

level, _ := log.ParseLevel("debug")
log.SetLevel(level)

testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
testDir := t.TempDir()
ctx := context.Background()

config := &mgmt.Config{}
_, err = util.ReadJson("../server/testdata/management.json", config)
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
Expand All @@ -52,15 +44,7 @@ func Test_Start(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, listener := startManagement(config, t)
serverAddr = listener.Addr().String()
tested, err = NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatal(err)
}
}

func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -89,20 +73,20 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
return s, lis
}

func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) {
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}

s := grpc.NewServer()

serverKey, err = wgtypes.GenerateKey()
serverKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}

mgmtMockServer = &mock_server.ManagementServiceServerMock{
mgmtMockServer := &mock_server.ManagementServiceServerMock{
GetServerKeyFunc: func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error) {
response := &proto.ServerKeyResponse{
Key: serverKey.PublicKey().String(),
Expand All @@ -119,27 +103,59 @@ func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
}()

return s, lis
return s, lis, mgmtMockServer, serverKey
}

func closeManagementSilently(s *grpc.Server, listener net.Listener) {
s.GracefulStop()
err := listener.Close()
if err != nil {
log.Warnf("error while closing management listener %v", err)
return
}
}

func TestClient_GetServerPublicKey(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)

key, err := tested.GetServerPublicKey()
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Error(err)
t.Fatal(err)
}

key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("expecting non nil server key got nil")
t.Error("got an empty management public key")
}
}

func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
key, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)

client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
_, err = tested.Login(*key)
_, err = client.Login(*key)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
Expand All @@ -149,12 +165,25 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
}

func TestClient_LoginRegistered(t *testing.T) {
key, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)

client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}

key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo()
resp, err := tested.Register(*key, ValidKey, "", info)
resp, err := client.Register(*key, ValidKey, "", info)
if err != nil {
t.Error(err)
}
Expand All @@ -165,13 +194,26 @@ func TestClient_LoginRegistered(t *testing.T) {
}

func TestClient_Sync(t *testing.T) {
serverKey, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)

client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}

serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}

info := system.GetInfo()
_, err = tested.Register(*serverKey, ValidKey, "", info)
_, err = client.Register(*serverKey, ValidKey, "", info)
if err != nil {
t.Error(err)
}
Expand All @@ -181,7 +223,7 @@ func TestClient_Sync(t *testing.T) {
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), serverAddr, remoteKey, false)
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
t.Fatal(err)
}
Expand All @@ -195,7 +237,7 @@ func TestClient_Sync(t *testing.T) {
ch := make(chan *mgmtProto.SyncResponse, 1)

go func() {
err = tested.Sync(func(msg *mgmtProto.SyncResponse) error {
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})
Expand Down Expand Up @@ -227,7 +269,8 @@ func TestClient_Sync(t *testing.T) {
}

func Test_SystemMetaDataFromClient(t *testing.T) {
_, lis := startMockManagement(t)
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()

testKey, err := wgtypes.GenerateKey()
if err != nil {
Expand Down Expand Up @@ -304,3 +347,49 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
assert.Equal(t, ValidKey, actualValidKey)
assert.Equal(t, expectedMeta, actualMeta)
}

func Test_GetDeviceAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()

testKey, err := wgtypes.GenerateKey()
if err != nil {
log.Fatal(err)
}

serverAddr := lis.Addr().String()
ctx := context.Background()

client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
log.Fatalf("error while creating testClient: %v", err)
}

expectedFlowInfo := &proto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
}

mgmtMockServer.GetDeviceAuthorizationFlowFunc =
func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {

encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}

return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}

flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")
}

assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
34 changes: 34 additions & 0 deletions management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,37 @@ func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken s
func (c *GrpcClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
return c.login(serverKey, &proto.LoginRequest{})
}

// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()

message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}

resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}

flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}

return flowInfoResp, nil
}
18 changes: 13 additions & 5 deletions management/client/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import (
)

type MockClient struct {
CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error)
CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
}

func (m *MockClient) Close() error {
Expand Down Expand Up @@ -48,3 +49,10 @@ func (m *MockClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
}
return m.LoginFunc(serverKey)
}

func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
}
Loading

0 comments on commit 7e5449f

Please sign in to comment.