diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 00000000..b405348e --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,98 @@ +package app_test + +import ( + "crypto/rand" + "fmt" + "math/big" + "os" + "testing" + + "go.uber.org/mock/gomock" + + "github.com/open-amt-cloud-toolkit/console/config" + "github.com/open-amt-cloud-toolkit/console/internal/app" +) + +func getFreePort() (string, error) { + port, err := rand.Int(rand.Reader, big.NewInt(1000)) + if err != nil { + return "", err + } + + return fmt.Sprintf(":%d", 8000+port.Int64()), nil +} + +func teardown() {} + +func TestRun(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + defer ctrl.Finish() + + mockDB := NewMockDB(ctrl) + mockHTTPServer := NewMockHTTPServer(ctrl) + + port, err := getFreePort() + if err != nil { + t.Fatalf("Failed to get a free port: %v", err) + } + + cfg := &config.Config{ + Log: config.Log{ + Level: "info", + }, + DB: config.DB{ + URL: "postgres://testuser:testpass@localhost/testdb", + PoolMax: 10, + }, + HTTP: config.HTTP{ + Port: port, + AllowedOrigins: []string{"*"}, + AllowedHeaders: []string{"Content-Type"}, + }, + App: config.App{ + Version: "DEVELOPMENT", + }, + } + + tests := []struct { + name string + setupMocks func() + setupEnv func() + cfg *config.Config + expectFunc func(t *testing.T) + }{ + { + name: "Successful run and shutdown", + setupMocks: func() { + mockDB.EXPECT().Close().Return(nil).Times(1) + mockHTTPServer.EXPECT().Notify().Return(make(chan error)).Times(1) + mockHTTPServer.EXPECT().Shutdown().Return(nil).Times(1) + }, + setupEnv: func() { + os.Setenv("GIN_MODE", "release") + }, + cfg: cfg, + expectFunc: func(_ *testing.T) { + go func() { + defer teardown() + app.Run(cfg) + }() + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tc.setupEnv() + tc.setupMocks() + + tc.expectFunc(t) + }) + } +} diff --git a/internal/app/interface.go b/internal/app/interface.go new file mode 100644 index 00000000..5d78c1d0 --- /dev/null +++ b/internal/app/interface.go @@ -0,0 +1,28 @@ +package app + +import ( + "context" + "database/sql" + "net/http" + + "github.com/gorilla/websocket" +) + +// DB is an interface for database operations. +type DB interface { + Close() error + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// HTTPServer is an interface for the HTTP server. +type HTTPServer interface { + Notify() <-chan error + Shutdown() error + Start() error +} + +// WebSocketUpgrader is an interface for WebSocket upgrading. +type WebSocketUpgrader interface { + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) +} diff --git a/internal/app/mocks_test.go b/internal/app/mocks_test.go new file mode 100644 index 00000000..922cd60c --- /dev/null +++ b/internal/app/mocks_test.go @@ -0,0 +1,200 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./internal/app/interface.go +// +// Generated by this command: +// +// mockgen -source=./internal/app/interface.go -destination=./internal/app/mocks_test.go -package app_test +// + +// Package app_test is a generated GoMock package. +package app_test + +import ( + context "context" + sql "database/sql" + http "net/http" + reflect "reflect" + + websocket "github.com/gorilla/websocket" + gomock "go.uber.org/mock/gomock" +) + +// MockDB is a mock of DB interface. +type MockDB struct { + ctrl *gomock.Controller + recorder *MockDBMockRecorder +} + +// MockDBMockRecorder is the mock recorder for MockDB. +type MockDBMockRecorder struct { + mock *MockDB +} + +// NewMockDB creates a new mock instance. +func NewMockDB(ctrl *gomock.Controller) *MockDB { + mock := &MockDB{ctrl: ctrl} + mock.recorder = &MockDBMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDB) EXPECT() *MockDBMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockDB) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockDBMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDB)(nil).Close)) +} + +// ExecContext mocks base method. +func (m *MockDB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, query} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ExecContext", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecContext indicates an expected call of ExecContext. +func (mr *MockDBMockRecorder) ExecContext(ctx, query any, args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, query}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockDB)(nil).ExecContext), varargs...) +} + +// QueryContext mocks base method. +func (m *MockDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, query} + for _, a := range args { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockDBMockRecorder) QueryContext(ctx, query any, args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, query}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockDB)(nil).QueryContext), varargs...) +} + +// MockHTTPServer is a mock of HTTPServer interface. +type MockHTTPServer struct { + ctrl *gomock.Controller + recorder *MockHTTPServerMockRecorder +} + +// MockHTTPServerMockRecorder is the mock recorder for MockHTTPServer. +type MockHTTPServerMockRecorder struct { + mock *MockHTTPServer +} + +// NewMockHTTPServer creates a new mock instance. +func NewMockHTTPServer(ctrl *gomock.Controller) *MockHTTPServer { + mock := &MockHTTPServer{ctrl: ctrl} + mock.recorder = &MockHTTPServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHTTPServer) EXPECT() *MockHTTPServerMockRecorder { + return m.recorder +} + +// Notify mocks base method. +func (m *MockHTTPServer) Notify() <-chan error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Notify") + ret0, _ := ret[0].(<-chan error) + return ret0 +} + +// Notify indicates an expected call of Notify. +func (mr *MockHTTPServerMockRecorder) Notify() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Notify", reflect.TypeOf((*MockHTTPServer)(nil).Notify)) +} + +// Shutdown mocks base method. +func (m *MockHTTPServer) Shutdown() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Shutdown") + ret0, _ := ret[0].(error) + return ret0 +} + +// Shutdown indicates an expected call of Shutdown. +func (mr *MockHTTPServerMockRecorder) Shutdown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockHTTPServer)(nil).Shutdown)) +} + +// Start mocks base method. +func (m *MockHTTPServer) Start() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start") + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockHTTPServerMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockHTTPServer)(nil).Start)) +} + +// MockWebSocketUpgrader is a mock of WebSocketUpgrader interface. +type MockWebSocketUpgrader struct { + ctrl *gomock.Controller + recorder *MockWebSocketUpgraderMockRecorder +} + +// MockWebSocketUpgraderMockRecorder is the mock recorder for MockWebSocketUpgrader. +type MockWebSocketUpgraderMockRecorder struct { + mock *MockWebSocketUpgrader +} + +// NewMockWebSocketUpgrader creates a new mock instance. +func NewMockWebSocketUpgrader(ctrl *gomock.Controller) *MockWebSocketUpgrader { + mock := &MockWebSocketUpgrader{ctrl: ctrl} + mock.recorder = &MockWebSocketUpgraderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebSocketUpgrader) EXPECT() *MockWebSocketUpgraderMockRecorder { + return m.recorder +} + +// Upgrade mocks base method. +func (m *MockWebSocketUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upgrade", w, r, responseHeader) + ret0, _ := ret[0].(*websocket.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Upgrade indicates an expected call of Upgrade. +func (mr *MockWebSocketUpgraderMockRecorder) Upgrade(w, r, responseHeader any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upgrade", reflect.TypeOf((*MockWebSocketUpgrader)(nil).Upgrade), w, r, responseHeader) +} diff --git a/internal/controller/http/v1/devicemanagement_mocks_test.go b/internal/controller/http/v1/devicemanagement_mocks_test.go index 67ae281c..f4cbf0e9 100644 --- a/internal/controller/http/v1/devicemanagement_mocks_test.go +++ b/internal/controller/http/v1/devicemanagement_mocks_test.go @@ -84,6 +84,73 @@ func (mr *MockWSMANMockRecorder) Worker() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Worker", reflect.TypeOf((*MockWSMAN)(nil).Worker)) } +// MockWebSocketConn is a mock of WebSocketConn interface. +type MockWebSocketConn struct { + ctrl *gomock.Controller + recorder *MockWebSocketConnMockRecorder +} + +// MockWebSocketConnMockRecorder is the mock recorder for MockWebSocketConn. +type MockWebSocketConnMockRecorder struct { + mock *MockWebSocketConn +} + +// NewMockWebSocketConn creates a new mock instance. +func NewMockWebSocketConn(ctrl *gomock.Controller) *MockWebSocketConn { + mock := &MockWebSocketConn{ctrl: ctrl} + mock.recorder = &MockWebSocketConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebSocketConn) EXPECT() *MockWebSocketConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockWebSocketConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockWebSocketConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockWebSocketConn)(nil).Close)) +} + +// ReadMessage mocks base method. +func (m *MockWebSocketConn) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage. +func (mr *MockWebSocketConnMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*MockWebSocketConn)(nil).ReadMessage)) +} + +// WriteMessage mocks base method. +func (m *MockWebSocketConn) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockWebSocketConnMockRecorder) WriteMessage(messageType, data any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockWebSocketConn)(nil).WriteMessage), messageType, data) +} + // MockRedirection is a mock of Redirection interface. type MockRedirection struct { ctrl *gomock.Controller @@ -834,4 +901,4 @@ func (m *MockDeviceManagementFeature) Update(ctx context.Context, d *dto.Device) func (mr *MockDeviceManagementFeatureMockRecorder) Update(ctx, d any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDeviceManagementFeature)(nil).Update), ctx, d) -} +} \ No newline at end of file diff --git a/internal/controller/http/v2/devicemanagement_mocks_test.go b/internal/controller/http/v2/devicemanagement_mocks_test.go index b4dee1b4..88d2af58 100644 --- a/internal/controller/http/v2/devicemanagement_mocks_test.go +++ b/internal/controller/http/v2/devicemanagement_mocks_test.go @@ -84,6 +84,73 @@ func (mr *MockWSMANMockRecorder) Worker() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Worker", reflect.TypeOf((*MockWSMAN)(nil).Worker)) } +// MockWebSocketConn is a mock of WebSocketConn interface. +type MockWebSocketConn struct { + ctrl *gomock.Controller + recorder *MockWebSocketConnMockRecorder +} + +// MockWebSocketConnMockRecorder is the mock recorder for MockWebSocketConn. +type MockWebSocketConnMockRecorder struct { + mock *MockWebSocketConn +} + +// NewMockWebSocketConn creates a new mock instance. +func NewMockWebSocketConn(ctrl *gomock.Controller) *MockWebSocketConn { + mock := &MockWebSocketConn{ctrl: ctrl} + mock.recorder = &MockWebSocketConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebSocketConn) EXPECT() *MockWebSocketConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockWebSocketConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockWebSocketConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockWebSocketConn)(nil).Close)) +} + +// ReadMessage mocks base method. +func (m *MockWebSocketConn) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage. +func (mr *MockWebSocketConnMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*MockWebSocketConn)(nil).ReadMessage)) +} + +// WriteMessage mocks base method. +func (m *MockWebSocketConn) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockWebSocketConnMockRecorder) WriteMessage(messageType, data any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockWebSocketConn)(nil).WriteMessage), messageType, data) +} + // MockRedirection is a mock of Redirection interface. type MockRedirection struct { ctrl *gomock.Controller @@ -834,4 +901,4 @@ func (m *MockDeviceManagementFeature) Update(ctx context.Context, d *dto.Device) func (mr *MockDeviceManagementFeatureMockRecorder) Update(ctx, d any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDeviceManagementFeature)(nil).Update), ctx, d) -} +} \ No newline at end of file diff --git a/internal/usecase/devices/interceptor.go b/internal/usecase/devices/interceptor.go index 766a476c..fb43fcb5 100644 --- a/internal/usecase/devices/interceptor.go +++ b/internal/usecase/devices/interceptor.go @@ -6,8 +6,11 @@ import ( "crypto/rand" "encoding/binary" "encoding/hex" + "errors" "fmt" + "io" "log" + "math" "github.com/gorilla/websocket" "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman" @@ -24,7 +27,7 @@ const ( ) type DeviceConnection struct { - Conn *websocket.Conn + Conn WebSocketConn wsmanMessages wsman.Messages Device dto.Device Direct bool @@ -66,16 +69,17 @@ func (uc *UseCase) Redirect(c context.Context, conn *websocket.Conn, guid, mode } // To Do: scoop the errors out of this for logging - go uc.listenToDevice(c, deviceConnection) - go uc.listenToBrowser(c, deviceConnection) + go uc.ListenToDevice(c, deviceConnection) + go uc.ListenToBrowser(c, deviceConnection) return nil } -func (uc *UseCase) listenToDevice(c context.Context, deviceConnection *DeviceConnection) { +func (uc *UseCase) ListenToDevice(c context.Context, deviceConnection *DeviceConnection) { + conn := deviceConnection.Conn // This is now of type WebSocketConnInterface + for { - // setup listener for response from device - data, err := uc.redirection.RedirectListen(c, deviceConnection) // calls Receive() + data, err := uc.redirection.RedirectListen(c, deviceConnection) if err != nil { break } @@ -88,8 +92,8 @@ func (uc *UseCase) listenToDevice(c context.Context, deviceConnection *DeviceCon if !deviceConnection.Direct { toSend, deviceConnection.Direct = processDeviceData(toSend, &deviceConnection.Challenge) } - // Write message back to browser - err = deviceConnection.Conn.WriteMessage(websocket.BinaryMessage, toSend) + + err = conn.WriteMessage(websocket.BinaryMessage, toSend) if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { _ = fmt.Errorf("interceptor - listenToDevice - websocket closed unexpectedly (writing to browser): %w", err) @@ -103,7 +107,7 @@ func (uc *UseCase) listenToDevice(c context.Context, deviceConnection *DeviceCon } } -func (uc *UseCase) listenToBrowser(c context.Context, deviceConnection *DeviceConnection) { +func (uc *UseCase) ListenToBrowser(c context.Context, deviceConnection *DeviceConnection) { for { _, msg, err := deviceConnection.Conn.ReadMessage() if err != nil { @@ -240,7 +244,6 @@ func handleAuthenticateSessionReply(msg []byte, challenge *client.AuthChallenge) return msg, false } -// RandomValueHex generates a random hexadecimal string of the specified length. func RandomValueHex(length int) (string, error) { divideByHalf := 2 n := (length + 1) / divideByHalf // Calculate the number of bytes needed @@ -256,14 +259,24 @@ func RandomValueHex(length int) (string, error) { } // Helper function to write length and bytes. -func writeField(buf *bytes.Buffer, field string) { - if err := binary.Write(buf, binary.BigEndian, uint8(len(field))); err != nil { - log.Fatal(err) +func writeField(buf io.Writer, field string) error { + // Check for potential overflow + var fieldLen uint8 + if len(field) <= math.MaxUint8 { + fieldLen = uint8(len(field)) //nolint:gosec // Ignore potential overflow here as overflow validated earlier in code + } else { + return ErrLengthLimit + } + + if err := binary.Write(buf, binary.BigEndian, fieldLen); err != nil { + return err } if err := binary.Write(buf, binary.BigEndian, []byte(field)); err != nil { - log.Fatal(err) + return err } + + return nil } func handleAuthenticationSession(msg []byte, challenge *client.AuthChallenge) []byte { @@ -275,71 +288,177 @@ func handleAuthenticationSession(msg []byte, challenge *client.AuthChallenge) [] return msg } - buf := bytes.NewReader(msg[1:9]) - // Variable to hold the decoded value + return processAuthChallenge(msg[1:9], challenge) +} + +func processAuthChallenge(data []byte, challenge *client.AuthChallenge) []byte { + buf := bytes.NewReader(data) + var status uint8 var unknown uint16 var authType uint8 - // Read the binary data into the variable - _ = binary.Read(buf, binary.BigEndian, &status) - _ = binary.Read(buf, binary.BigEndian, &unknown) - _ = binary.Read(buf, binary.BigEndian, &authType) - // generate auth challenge + if err := readBinaryData(buf, &status, &unknown, &authType); err != nil { + log.Printf("Error reading binary data: %v", err) - authURL := "/RedirectionService" + return nil + } if authType == AuthenticationTypeDigest { - if challenge.Realm != "" { - nc := challenge.NonceCount - randomByteCount := 10 - challenge.CNonce, _ = RandomValueHex(randomByteCount) - nonceCount := fmt.Sprintf("%08x", nc) - nonceData := challenge.GetFormattedNonceData(challenge.Nonce) - response := challenge.ComputeDigestHash("POST", authURL, nonceData) - challenge.NonceCount++ - - var replyBuf bytes.Buffer - - _ = binary.Write(&replyBuf, binary.BigEndian, [5]byte{0x13, 0x00, 0x00, 0x00, 0x04}) // [5]byte - _ = binary.Write(&replyBuf, binary.LittleEndian, uint32(len(challenge.Username)+len(challenge.Realm)+len(challenge.Nonce)+len(authURL)+len(challenge.CNonce)+len(nonceCount)+len(response)+len(challenge.Qop)+ContentLengthPadding)) // uint32 - - // Write fields - writeField(&replyBuf, challenge.Username) - writeField(&replyBuf, challenge.Realm) - writeField(&replyBuf, challenge.Nonce) - writeField(&replyBuf, authURL) - writeField(&replyBuf, challenge.CNonce) - writeField(&replyBuf, nonceCount) - writeField(&replyBuf, response) - writeField(&replyBuf, challenge.Qop) - - return replyBuf.Bytes() + return handleDigestAuthentication(challenge) + } + + return []byte("") +} + +func readBinaryData(buf *bytes.Reader, status *uint8, unknown *uint16, authType *uint8) error { + if err := binary.Read(buf, binary.BigEndian, status); err != nil { + return err + } + + if err := binary.Read(buf, binary.BigEndian, unknown); err != nil { + return err + } + + return binary.Read(buf, binary.BigEndian, authType) +} + +func handleDigestAuthentication(challenge *client.AuthChallenge) []byte { + if challenge.Realm != "" { + cnonce, err := generateCNonce(challenge) + if err != nil { + log.Printf("Error generating CNonce: %v", err) + + return nil } - return generateEmptyAuth(challenge, authURL) + challenge.CNonce = cnonce + response := computeDigestResponse(challenge) + + return buildAuthReply(challenge, response) } - return []byte("") + return generateEmptyAuth(challenge, "/RedirectionService") +} + +func generateCNonce(challenge *client.AuthChallenge) (string, error) { + randomByteCount := 10 + cnonce, err := RandomValueHex(randomByteCount) + if err != nil { //nolint:wsl // ignoring cuddle assignment rule for this line due to linter conflicts + return "", err + } + + challenge.NonceCount++ + + return cnonce, nil +} + +func computeDigestResponse(challenge *client.AuthChallenge) string { + nonceData := challenge.GetFormattedNonceData(challenge.Nonce) + + return challenge.ComputeDigestHash("POST", "/RedirectionService", nonceData) +} + +func buildAuthReply(challenge *client.AuthChallenge, response string) []byte { + var replyBuf bytes.Buffer + + if err := writeHeader(&replyBuf); err != nil { + log.Printf("Error writing header: %v", err) + + return nil + } + + if err := writeLength(&replyBuf, challenge, response); err != nil { + log.Printf("Error writing length: %v", err) + + return nil + } + + if err := writeFields(&replyBuf, challenge, response); err != nil { + log.Printf("Error writing fields: %v", err) + + return nil + } + + return replyBuf.Bytes() +} + +func writeHeader(buf *bytes.Buffer) error { + return binary.Write(buf, binary.BigEndian, [5]byte{0x13, 0x00, 0x00, 0x00, 0x04}) +} + +var ErrLengthLimit = errors.New("calculated length exceeds uint32 limit") + +func writeLength(buf *bytes.Buffer, challenge *client.AuthChallenge, response string) error { + totalLength := len(challenge.Username) + len(challenge.Realm) + len(challenge.Nonce) + len("/RedirectionService") + + len(challenge.CNonce) + len(fmt.Sprintf("%08x", challenge.NonceCount)) + len(response) + len(challenge.Qop) + + ContentLengthPadding + + if totalLength > math.MaxUint32 { + return ErrLengthLimit // If total length is too large, throws an error and stops here + } + + length := uint32(totalLength) //nolint:gosec // Ignore potential integer overflow here as overflow is validated earlier in code + + return binary.Write(buf, binary.LittleEndian, length) +} + +func writeFields(buf *bytes.Buffer, challenge *client.AuthChallenge, response string) error { + if err := writeField(buf, challenge.Username); err != nil { + return err + } + + if err := writeField(buf, challenge.Realm); err != nil { + return err + } + + if err := writeField(buf, challenge.Nonce); err != nil { + return err + } + + if err := writeField(buf, "/RedirectionService"); err != nil { + return err + } + + if err := writeField(buf, challenge.CNonce); err != nil { + return err + } + + if err := writeField(buf, fmt.Sprintf("%08x", challenge.NonceCount)); err != nil { + return err + } + + if err := writeField(buf, response); err != nil { + return err + } + + return writeField(buf, challenge.Qop) } func generateEmptyAuth(challenge *client.AuthChallenge, authURL string) []byte { var buf bytes.Buffer + lenChallengeUsername := uint8(0) + lenAuthURL := uint8(0) + + // If challenge has values that will cause overflow, stop them here + lenChallengeUsername = uint8(len(challenge.Username)) //nolint:gosec // Ignore potential integer overflow here as overflow is being validated + lenAuthURL = uint8(len(authURL)) //nolint:gosec // Ignore potential integer overflow here as overflow is being validated + emptyAuth := emptyAuth{ - usernameLength: uint8(len(challenge.Username)), + usernameLength: lenChallengeUsername, // Use calculated safe value authURLPadding: [2]byte{0x00, 0x00}, - authURLLength: uint8(len(authURL)), + authURLLength: lenAuthURL, // Use calculated safe value endPadding: [4]byte{0x00, 0x00, 0x00, 0x00}, } copy(emptyAuth.username[:], challenge.Username) copy(emptyAuth.authURL[:], authURL) - _ = binary.Write(&buf, binary.BigEndian, [5]byte{0x13, 0x00, 0x00, 0x00, 0x04}) // header - _ = binary.Write(&buf, binary.LittleEndian, uint32(len(challenge.Username)+len(authURL)+ContentLengthPadding)) // flip flop endian for content length + _ = binary.Write(&buf, binary.BigEndian, [5]byte{0x13, 0x00, 0x00, 0x00, 0x04}) // header + _ = binary.Write(&buf, binary.LittleEndian, uint32(lenChallengeUsername+lenAuthURL)+ContentLengthPadding) // flip flop endian for content length _ = binary.Write(&buf, binary.BigEndian, emptyAuth) return buf.Bytes() diff --git a/internal/usecase/devices/interceptor_private_test.go b/internal/usecase/devices/interceptor_private_test.go new file mode 100644 index 00000000..c042d53d --- /dev/null +++ b/internal/usecase/devices/interceptor_private_test.go @@ -0,0 +1,525 @@ +package devices + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/client" + "github.com/stretchr/testify/require" +) + +func TestProcessBrowserData(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg []byte + challenge *client.AuthChallenge + expectedBytes []byte + }{ + { + name: "Start Redirection Session", + msg: []byte{RedirectionCommandsStartRedirectionSession, 0, 0, 0, 0, 0, 0, 0, 0}, + expectedBytes: []byte{RedirectionCommandsStartRedirectionSession, 0, 0, 0, 0, 0, 0, 0}, + }, + { + name: "End Redirection Session", + msg: []byte{RedirectionCommandsEndRedirectionSession, 0, 0, 0}, + expectedBytes: []byte{RedirectionCommandsEndRedirectionSession, 0, 0, 0}, + }, + { + name: "Authenticate Session", + msg: []byte{RedirectionCommandsAuthenticateSession, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + challenge: &client.AuthChallenge{ + Username: "testuser", + Password: "testpassword", + Realm: "testrealm", + CSRFToken: "csrf1234", + Domain: "testdomain", + Nonce: "noncevalue", + Opaque: "opaquevalue", + Stale: "false", + Algorithm: "MD5", + Qop: "auth", + CNonce: "cnoncevalue", + NonceCount: 1, + }, + expectedBytes: []byte{0x13, 0x0, 0x0, 0x0, 0x4, 0x6c, 0x0, 0x0, 0x0, 0x8, 0x74, 0x65, 0x73, 0x74, 0x75, 0x73, 0x65, 0x72, 0x9, 0x74, 0x65, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6c, 0x6d, 0xa, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x13, 0x2f, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0xa, 0x63, 0x34, 0x65, 0x64, 0x35, 0x32, 0x62, 0x30, 0x37, 0x61, 0x8, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x31, 0x20, 0x31, 0x65, 0x39, 0x36, 0x35, 0x66, 0x32, 0x33, 0x30, 0x38, 0x63, 0x38, 0x35, 0x64, 0x35, 0x35, 0x63, 0x63, 0x31, 0x65, 0x37, 0x62, 0x33, 0x38, 0x36, 0x36, 0x31, 0x32, 0x65, 0x39, 0x38, 0x38, 0x4, 0x61, 0x75, 0x74, 0x68}, + }, + { + name: "Default Case", + msg: []byte{0xFF}, + expectedBytes: nil, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := processBrowserData(tc.msg, tc.challenge) + require.IsType(t, tc.expectedBytes, result) + }) + } +} + +func TestProcessDeviceData(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg []byte + challenge *client.AuthChallenge + expectedData []byte + expectedBool bool + }{ + { + name: "Start Redirection Session Reply", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, + challenge: nil, + expectedData: []byte{0x11, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2}, + expectedBool: false, + }, + { + name: "Authenticate Session Reply", + msg: []byte{RedirectionCommandsAuthenticateSessionReply, 0x01, 0x02, 0x03, 0x04}, + challenge: &client.AuthChallenge{}, + expectedData: []byte{}, + expectedBool: false, + }, + { + name: "Unhandled Command", + msg: []byte{0x99}, + challenge: nil, + expectedData: nil, + expectedBool: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, ok := processDeviceData(tc.msg, tc.challenge) + + require.Equal(t, tc.expectedData, data) + require.Equal(t, tc.expectedBool, ok) + }) + } +} + +func TestHandleStartRedirectionSessionReply(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg []byte + expected []byte + }{ + { + name: "Valid Session Reply", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, + expected: []byte{0x11, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x2}, + }, + { + name: "Message Shorter Than RedirectionSessionReply", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 0x01}, + expected: []byte(""), + }, + { + name: "Message Shorter Than RedirectSessionLengthBytes", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + expected: []byte(""), + }, + { + name: "Message Shorter Than RedirectSessionLengthBytes Plus OEM Length", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x04}, + expected: []byte(""), + }, + { + name: "Invalid Session Reply", + msg: []byte{RedirectionCommandsStartRedirectionSessionReply, 7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + expected: []byte(""), + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := handleStartRedirectionSessionReply(tc.msg) + + require.Equal(t, tc.expected, result) + }) + } +} + +func TestAllZero(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + expected bool + }{ + { + name: "All zeros", + data: []byte{0x00, 0x00, 0x00}, + expected: true, + }, + { + name: "Not all zeros", + data: []byte{0x00, 0x01, 0x00}, + expected: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := allZero(tc.data) + + require.Equal(t, tc.expected, result) + }) + } +} + +type failBuffer struct{} + +var ErrSimWriteFail = errors.New("simulated write failure") + +func (f *failBuffer) Write(_ []byte) (n int, err error) { + return 0, ErrSimWriteFail +} + +type failBufferOnSecondWrite struct { + count int +} + +var ErrForcedFailure = errors.New("forced failure on second write") + +func (f *failBufferOnSecondWrite) Write(p []byte) (n int, err error) { + f.count++ + if f.count == 2 { + return 0, ErrForcedFailure + } + + return len(p), nil +} + +func TestWriteField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + field string + expected []byte + shouldFail bool + buffer io.Writer + }{ + { + name: "Valid field", + field: "test", + expected: append([]byte{0x04}, []byte("test")...), + buffer: &bytes.Buffer{}, + }, + { + name: "Empty field", + field: "", + expected: []byte{0x00}, + buffer: &bytes.Buffer{}, + }, + { + name: "Write length failure", + field: "failLength", + shouldFail: true, + buffer: &failBuffer{}, + }, + { + name: "Write field failure", + field: "failField", + shouldFail: true, + buffer: &failBufferOnSecondWrite{}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := writeField(tc.buffer, tc.field) + + if tc.shouldFail { + require.Error(t, err) + } else { + require.NoError(t, err) + + if buf, ok := tc.buffer.(*bytes.Buffer); ok { + result := buf.Bytes() + require.Equal(t, tc.expected, result) + } + } + }) + } +} + +func TestGenerateEmptyAuth(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + challenge *client.AuthChallenge + authURL string + expectedBuf []byte + }{ + { + name: "Valid Auth Challenge", + challenge: &client.AuthChallenge{ + Username: "testuser", + }, + authURL: "http://example.com", + expectedBuf: []byte{0x13, 0x0, 0x0, 0x0, 0x4, 0x22, 0x0, 0x0, 0x0, 0x8, 0x74, 0x65, 0x73, 0x74, 0x75, 0x0, 0x0, 0x12, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x0, 0x0}, + }, + { + name: "Empty Username and URL", + challenge: &client.AuthChallenge{ + Username: "", + }, + authURL: "", + expectedBuf: []byte{0x13, 0x0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := generateEmptyAuth(tc.challenge, tc.authURL) + + require.Equal(t, tc.expectedBuf, result) + }) + } +} + +func TestHandleAuthenticateSessionReply(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg []byte + expectedResult []byte + expectedSuccess bool + expectedChallenge *client.AuthChallenge + }{ + { + name: "Message too short", + msg: []byte{0x01}, + expectedResult: []byte(""), + expectedSuccess: false, + }, + { + name: "Valid Digest Authentication Fail", + msg: []byte{ + 0x01, + AuthenticationStatusFail, 0x00, 0x00, AuthenticationTypeDigest, 0x12, 0x00, 0x00, 0x00, + 0x05, + 'r', 'e', 'a', 'l', 'm', + 0x06, + 'n', 'o', 'n', 'c', 'e', '1', + 0x03, + 'q', 'o', 'p', + }, + expectedResult: []byte{}, + expectedSuccess: false, + expectedChallenge: &client.AuthChallenge{ + Realm: "", + Nonce: "", + Qop: "", + }, + }, + { + name: "Valid Authentication Success", + msg: []byte{ + 0x01, + AuthenticationStatusSuccess, 0x00, 0x00, AuthenticationTypeQuery, 0x00, 0x00, 0x00, 0x00, + }, + expectedResult: []byte{ + 0x01, AuthenticationStatusSuccess, 0x00, 0x00, AuthenticationTypeQuery, 0x00, 0x00, 0x00, 0x00, + }, + expectedSuccess: false, + }, + { + name: "Valid Authentication Success, Non-Digest Type", + msg: []byte{ + 0x01, + AuthenticationStatusSuccess, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x00, + }, + expectedResult: []byte{ + 0x01, AuthenticationStatusSuccess, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x00, + }, + expectedSuccess: true, + }, + { + name: "Invalid length in message", + msg: []byte{ + 0x01, + AuthenticationStatusFail, 0x00, 0x00, AuthenticationTypeDigest, 0xFF, 0xFF, 0xFF, 0xFF, + }, + expectedResult: []byte(""), + expectedSuccess: false, + }, + { + name: "Digest Authentication Failure with valid realm, nonce, and qop", + msg: []byte{ + 0x01, + AuthenticationStatusFail, 0x00, 0x00, AuthenticationTypeDigest, 0x12, 0x00, 0x00, 0x00, + 0x05, + 'r', 'e', 'a', 'l', 'm', + 0x06, + 'n', 'o', 'n', 'c', 'e', '1', + 0x03, 0x03, 0x03, + 'q', 'o', 'p', + }, + expectedResult: []byte{0x1, 0x1, 0x0, 0x0, 0x4, 0x12, 0x0, 0x0, 0x0, 0x5, 0x72, 0x65, 0x61, 0x6c, 0x6d, 0x6, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x31, 0x3, 0x3, 0x3, 0x71, 0x6f, 0x70}, + expectedSuccess: false, + expectedChallenge: &client.AuthChallenge{ + Realm: "realm", + Nonce: "nonce1", + Qop: "\x03\x03q", + }, + }, + { + name: "Digest Authentication Failure with empty realm, nonce, and qop", + msg: []byte{ + 0x01, + AuthenticationStatusFail, 0x00, 0x00, AuthenticationTypeDigest, 0x12, 0x00, 0x00, 0x00, + 0x00, + 0x00, + 0x00, + }, + expectedResult: []byte{}, + expectedChallenge: &client.AuthChallenge{ + Realm: "", + Nonce: "", + Qop: "", + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + challenge := &client.AuthChallenge{} + result, success := handleAuthenticateSessionReply(tc.msg, challenge) + + require.Equal(t, tc.expectedResult, result) + require.Equal(t, tc.expectedSuccess, success) + + if tc.expectedChallenge != nil { + require.Equal(t, tc.expectedChallenge.Realm, challenge.Realm) + require.Equal(t, tc.expectedChallenge.Nonce, challenge.Nonce) + require.Equal(t, tc.expectedChallenge.Qop, challenge.Qop) + } + }) + } +} + +func TestHandleAuthenticationSession(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + msg []byte + challenge *client.AuthChallenge + expectedResultType interface{} + expectedNonceCount int + }{ + { + name: "Message too short", + msg: []byte{0x01}, + challenge: &client.AuthChallenge{}, + expectedResultType: []byte{}, + }, + { + name: "Message of length 9 with all zeros", + msg: []byte{ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + challenge: &client.AuthChallenge{}, + expectedResultType: []byte{}, + }, + { + name: "Digest Authentication with empty Realm", + msg: []byte{ + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, + }, + challenge: &client.AuthChallenge{}, + expectedResultType: []byte{}, + }, + { + name: "Digest Authentication with Realm", + msg: []byte{ + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, + }, + challenge: &client.AuthChallenge{ + Username: "", + Password: "", + Realm: "exampleRealm", + CSRFToken: "", + Domain: "", + Nonce: "", + Opaque: "", + Stale: "", + Algorithm: "", + Qop: "", + CNonce: "", + NonceCount: 1, + }, + expectedResultType: []byte{}, + }, + { + name: "Non-Digest Authentication", + msg: []byte{ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, + }, + challenge: &client.AuthChallenge{}, + expectedResultType: []byte{}, + }, + { + name: "End of function returns empty byte slice", + msg: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + challenge: &client.AuthChallenge{}, + expectedResultType: []byte{}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := handleAuthenticationSession(tc.msg, tc.challenge) + + require.IsType(t, tc.expectedResultType, result) + + if len(tc.expectedResultType.([]byte)) > 0 { + require.NotEmpty(t, result) + } + + if tc.expectedNonceCount != 0 { + require.Equal(t, tc.expectedNonceCount, tc.challenge.NonceCount) + } + }) + } +} diff --git a/internal/usecase/devices/interceptor_test.go b/internal/usecase/devices/interceptor_test.go new file mode 100644 index 00000000..02295a18 --- /dev/null +++ b/internal/usecase/devices/interceptor_test.go @@ -0,0 +1,90 @@ +package devices_test + +import ( + "context" + "sync" + "testing" + + "github.com/gorilla/websocket" + "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + + "github.com/open-amt-cloud-toolkit/console/internal/entity" + devices "github.com/open-amt-cloud-toolkit/console/internal/usecase/devices" + "github.com/open-amt-cloud-toolkit/console/pkg/logger" +) + +func TestRedirect(t *testing.T) { + t.Parallel() + + mockConn := &websocket.Conn{} + guid := "device-guid-123" + mode := "default" + + tests := []struct { + name string + setup func(*MockRedirection, *MockRepository, *MockWSMAN, *sync.WaitGroup) + expectedErr error + }{ + { + name: "GetByID fail redirection", + setup: func(_ *MockRedirection, mockRepo *MockRepository, mockWSMAN *MockWSMAN, wg *sync.WaitGroup) { + mockWSMAN.EXPECT().Worker().Do(func() { + defer wg.Done() + }).Times(1) + mockRepo.EXPECT().GetByID(gomock.Any(), guid, "").Return(nil, ErrGeneral) + }, + expectedErr: ErrGeneral, + }, + { + name: "RedirectConnect fail redirection", + setup: func(mockRedir *MockRedirection, mockRepo *MockRepository, mockWSMAN *MockWSMAN, wg *sync.WaitGroup) { + mockWSMAN.EXPECT().Worker().Do(func() { + defer wg.Done() + }).Times(1) + mockRepo.EXPECT().GetByID(gomock.Any(), guid, "").Return(&entity.Device{ + GUID: guid, + Username: "user", + Password: "pass", + }, nil) + mockRedir.EXPECT().SetupWsmanClient(gomock.Any(), true, true).Return(wsman.Messages{}) + mockRedir.EXPECT().RedirectConnect(gomock.Any(), gomock.Any()).Return(ErrGeneral) + }, + expectedErr: ErrGeneral, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRedirection := NewMockRedirection(ctrl) + mockRepo := NewMockRepository(ctrl) + mockWSMAN := NewMockWSMAN(ctrl) + + var wg sync.WaitGroup + + wg.Add(1) + + tc.setup(mockRedirection, mockRepo, mockWSMAN, &wg) + + uc := devices.New(mockRepo, mockWSMAN, mockRedirection, logger.New("test")) + + wg.Wait() + + err := uc.Redirect(context.Background(), mockConn, guid, mode) + + if tc.expectedErr != nil { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectedErr.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/usecase/devices/interfaces.go b/internal/usecase/devices/interfaces.go index bad886a4..5e619097 100644 --- a/internal/usecase/devices/interfaces.go +++ b/internal/usecase/devices/interfaces.go @@ -19,6 +19,12 @@ type ( Worker() } + WebSocketConn interface { + ReadMessage() (int, []byte, error) + WriteMessage(messageType int, data []byte) error + Close() error + } + Redirection interface { SetupWsmanClient(device dto.Device, isRedirection, logMessages bool) wsman.Messages RedirectConnect(ctx context.Context, deviceConnection *DeviceConnection) error diff --git a/internal/usecase/devices/mocks_test.go b/internal/usecase/devices/mocks_test.go index 357f3ebd..d4efd63d 100644 --- a/internal/usecase/devices/mocks_test.go +++ b/internal/usecase/devices/mocks_test.go @@ -84,6 +84,73 @@ func (mr *MockWSMANMockRecorder) Worker() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Worker", reflect.TypeOf((*MockWSMAN)(nil).Worker)) } +// MockWebSocketConn is a mock of WebSocketConn interface. +type MockWebSocketConn struct { + ctrl *gomock.Controller + recorder *MockWebSocketConnMockRecorder +} + +// MockWebSocketConnMockRecorder is the mock recorder for MockWebSocketConn. +type MockWebSocketConnMockRecorder struct { + mock *MockWebSocketConn +} + +// NewMockWebSocketConn creates a new mock instance. +func NewMockWebSocketConn(ctrl *gomock.Controller) *MockWebSocketConn { + mock := &MockWebSocketConn{ctrl: ctrl} + mock.recorder = &MockWebSocketConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebSocketConn) EXPECT() *MockWebSocketConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockWebSocketConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockWebSocketConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockWebSocketConn)(nil).Close)) +} + +// ReadMessage mocks base method. +func (m *MockWebSocketConn) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage. +func (mr *MockWebSocketConnMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*MockWebSocketConn)(nil).ReadMessage)) +} + +// WriteMessage mocks base method. +func (m *MockWebSocketConn) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockWebSocketConnMockRecorder) WriteMessage(messageType, data any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockWebSocketConn)(nil).WriteMessage), messageType, data) +} + // MockRedirection is a mock of Redirection interface. type MockRedirection struct { ctrl *gomock.Controller @@ -834,4 +901,4 @@ func (m *MockFeature) Update(ctx context.Context, d *dto.Device) (*dto.Device, e func (mr *MockFeatureMockRecorder) Update(ctx, d any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockFeature)(nil).Update), ctx, d) -} +} \ No newline at end of file