Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get Device Authorization Flow information from management #308

Merged
merged 11 commits into from
May 8, 2022
Prev Previous commit
Next Next commit
test client GetDeviceAuthorizationFlow
Updated interface with new method
  • Loading branch information
mlsmaycon committed May 7, 2022
commit 8fb6fe73ec4f8c315d0957100bb312546d5af7ac
1 change: 1 addition & 0 deletions management/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ type Client interface {
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
}
144 changes: 107 additions & 37 deletions management/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,17 @@ import (
"google.golang.org/grpc/status"
)

var tested *GrpcClient
var serverAddr string
var mgmtMockServer *mock_server.ManagementServiceServerMock
var serverKey wgtypes.Key

const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"

func Test_Start(t *testing.T) {
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {

level, _ := log.ParseLevel("debug")
log.SetLevel(level)

testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
testDir := t.TempDir()
ctx := context.Background()

config := &mgmt.Config{}
_, err = util.ReadJson("../server/testdata/management.json", config)
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
Expand All @@ -52,15 +44,7 @@ func Test_Start(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, listener := startManagement(config, t)
serverAddr = listener.Addr().String()
tested, err = NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatal(err)
}
}

func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -89,20 +73,20 @@ func startManagement(config *mgmt.Config, t *testing.T) (*grpc.Server, net.Liste
return s, lis
}

func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) {
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}

s := grpc.NewServer()

serverKey, err = wgtypes.GenerateKey()
serverKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}

mgmtMockServer = &mock_server.ManagementServiceServerMock{
mgmtMockServer := &mock_server.ManagementServiceServerMock{
GetServerKeyFunc: func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error) {
response := &proto.ServerKeyResponse{
Key: serverKey.PublicKey().String(),
Expand All @@ -119,27 +103,46 @@ func startMockManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
}()

return s, lis
return s, lis, mgmtMockServer, serverKey
}

func TestClient_GetServerPublicKey(t *testing.T) {

key, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Error(err)
t.Fatal(err)
}
ctx := context.Background()
_, listener := startManagement(t)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}

key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("expecting non nil server key got nil")
t.Error("got an empty management public key")
}
}

func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
key, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
_, listener := startManagement(t)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
_, err = tested.Login(*key)
_, err = client.Login(*key)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
Expand All @@ -149,12 +152,23 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
}

func TestClient_LoginRegistered(t *testing.T) {
key, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
_, listener := startManagement(t)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}

key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo()
resp, err := tested.Register(*key, ValidKey, "", info)
resp, err := client.Register(*key, ValidKey, "", info)
if err != nil {
t.Error(err)
}
Expand All @@ -165,13 +179,24 @@ func TestClient_LoginRegistered(t *testing.T) {
}

func TestClient_Sync(t *testing.T) {
serverKey, err := tested.GetServerPublicKey()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
_, listener := startManagement(t)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}

serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}

info := system.GetInfo()
_, err = tested.Register(*serverKey, ValidKey, "", info)
_, err = client.Register(*serverKey, ValidKey, "", info)
if err != nil {
t.Error(err)
}
Expand All @@ -181,7 +206,7 @@ func TestClient_Sync(t *testing.T) {
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), serverAddr, remoteKey, false)
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
t.Fatal(err)
}
Expand All @@ -195,7 +220,7 @@ func TestClient_Sync(t *testing.T) {
ch := make(chan *mgmtProto.SyncResponse, 1)

go func() {
err = tested.Sync(func(msg *mgmtProto.SyncResponse) error {
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})
Expand Down Expand Up @@ -227,7 +252,7 @@ func TestClient_Sync(t *testing.T) {
}

func Test_SystemMetaDataFromClient(t *testing.T) {
_, lis := startMockManagement(t)
_, lis, mgmtMockServer, serverKey := startMockManagement(t)

testKey, err := wgtypes.GenerateKey()
if err != nil {
Expand Down Expand Up @@ -304,3 +329,48 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
assert.Equal(t, ValidKey, actualValidKey)
assert.Equal(t, expectedMeta, actualMeta)
}

func Test_GetDeviceAuthorizationFlow(t *testing.T) {
_, lis, mgmtMockServer, serverKey := startMockManagement(t)

testKey, err := wgtypes.GenerateKey()
if err != nil {
log.Fatal(err)
}

serverAddr := lis.Addr().String()
ctx := context.Background()

client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
log.Fatalf("error while creating testClient: %v", err)
}

expectedFlowInfo := &proto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
}

mgmtMockServer.GetDeviceAuthorizationFlowFunc =
func(ctx context.Context, req *proto.DeviceAuthorizationFlowRequest) (*proto.EncryptedMessage, error) {

encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}

return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}

flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")
}

assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
5 changes: 3 additions & 2 deletions management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,9 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
log.Errorf("failed to decrypt device authorization flow message: %s", err)
return nil, err
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}

return flowInfoResp, nil
Expand Down
16 changes: 12 additions & 4 deletions management/server/mock_server/management_server_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (
type ManagementServiceServerMock struct {
proto.UnimplementedManagementServiceServer

LoginFunc func(context.Context, *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
SyncFunc func(*proto.EncryptedMessage, proto.ManagementService_SyncServer)
GetServerKeyFunc func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error)
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
LoginFunc func(context.Context, *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
SyncFunc func(*proto.EncryptedMessage, proto.ManagementService_SyncServer)
GetServerKeyFunc func(context.Context, *proto.Empty) (*proto.ServerKeyResponse, error)
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.DeviceAuthorizationFlowRequest) (*proto.EncryptedMessage, error)
}

func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
Expand Down Expand Up @@ -44,3 +45,10 @@ func (m ManagementServiceServerMock) IsHealthy(ctx context.Context, empty *proto
}
return nil, status.Errorf(codes.Unimplemented, "method IsHealthy not implemented")
}

func (m ManagementServiceServerMock) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.DeviceAuthorizationFlowRequest) (*proto.EncryptedMessage, error) {
if m.GetDeviceAuthorizationFlowFunc != nil {
return m.GetDeviceAuthorizationFlowFunc(ctx, req)
}
return nil, status.Errorf(codes.Unimplemented, "method GetDeviceAuthorizationFlow not implemented")
}