From f365ea51bda9ef616934800dbf0a5872f4386a0d Mon Sep 17 00:00:00 2001 From: Ayush Thakur <100013900+ayusht2810@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:12:41 +0530 Subject: [PATCH] [MI-3591] Add transactions in the functions to avoid race conditions (#344) * [MI-3591] Add transactions in the functions to avoid race conditions * [MI-3591] Fix test cases and lint errors * [MI-3591] Update test cases * [MI-3591] Refactore test cases * [MI-3591] Add checks for nil in client * [MI-3591] Add more checks for nil in client * [MI-3591] Update log messages * [MI-3591] Fix lint * [MI-3591] Remove extra condition --- server/command.go | 28 +- server/command_test.go | 380 +++++++------------- server/handlers/handlers.go | 3 +- server/handlers/handlers_test.go | 7 +- server/message_hooks.go | 248 ++++++++++--- server/message_hooks_test.go | 510 +++++++++++++++++++++++++-- server/monitor/subscriptions.go | 24 +- server/monitor/subscriptions_test.go | 87 ++++- server/msteams/client.go | 56 +++ server/plugin_test.go | 3 +- server/store/mocks/Store.go | 121 +++++-- server/store/store.go | 165 +++++++-- server/store/store_test.go | 183 ++++++---- server/testutils/data.go | 36 ++ 14 files changed, 1377 insertions(+), 474 deletions(-) diff --git a/server/command.go b/server/command.go index 6ae6dcbaf..4bad80747 100644 --- a/server/command.go +++ b/server/command.go @@ -207,15 +207,35 @@ func (p *Plugin) executeLinkCommand(args *model.CommandArgs, parameters []string return p.cmdError(args.UserId, args.ChannelId, "Unable to create new link.") } - err = p.store.SaveChannelSubscription(storemodels.ChannelSubscription{ + tx, err := p.store.BeginTx() + if err != nil { + p.API.LogError("Unable to begin the database transaction", "error", err.Error()) + return p.cmdError(args.UserId, args.ChannelId, "Something went wrong") + } + + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogError("Unable to rollback database transaction", "error", err.Error()) + } + } + }() + + if txErr = p.store.SaveChannelSubscription(storemodels.ChannelSubscription{ SubscriptionID: channelsSubscription.ID, TeamID: channelLink.MSTeamsTeam, ChannelID: channelLink.MSTeamsChannel, ExpiresOn: channelsSubscription.ExpiresOn, Secret: p.getConfiguration().WebhookSecret, - }) - if err != nil { - return p.cmdError(args.UserId, args.ChannelId, "Unable to save the subscription in the monitoring system: "+err.Error()) + }, tx); txErr != nil { + p.API.LogWarn("Unable to save the subscription in the DB", "error", txErr.Error()) + return p.cmdError(args.UserId, args.ChannelId, "Error occurred while saving the subscription") + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogError("Unable to commit database transaction", "error", err.Error()) + return p.cmdError(args.UserId, args.ChannelId, "Something went wrong") } p.sendBotEphemeralPost(args.UserId, args.ChannelId, "The MS Teams channel is now linked to this Mattermost channel.") diff --git a/server/command_test.go b/server/command_test.go index fd81c3363..ce823bf25 100644 --- a/server/command_test.go +++ b/server/command_test.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "testing" "time" @@ -41,11 +42,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "The MS Teams channel is no longer linked to this Mattermost channel.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "The MS Teams channel is no longer linked to this Mattermost channel.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to delete the subscription on MS Teams", "subscriptionID", "testSubscriptionID", "error", "unable to delete the subscription").Return().Once() }, setupStore: func(s *mockStore.Store) { @@ -74,11 +71,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), "Mock-ChannelID", model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: "Mock-ChannelID", - Message: "This Mattermost channel is not linked to any MS Teams channel.", - }).Return(testutils.GetPost("Mock-ChannelID", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", "Mock-ChannelID", "This Mattermost channel is not linked to any MS Teams channel.")).Return(testutils.GetPost("Mock-ChannelID", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to get the link by channel ID", "error", "Error while getting link").Return().Once() }, setupStore: func(s *mockStore.Store) { @@ -98,11 +91,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), "Mock-ChannelID", model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: "Mock-ChannelID", - Message: "Unable to delete link.", - }).Return(testutils.GetPost("Mock-ChannelID", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", "Mock-ChannelID", "Unable to delete link.")).Return(testutils.GetPost("Mock-ChannelID", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to delete the link by channel ID", "error", "Error while deleting a link").Return().Once() }, setupStore: func(s *mockStore.Store) { @@ -116,10 +105,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { args: &model.CommandArgs{}, setupAPI: func(api *plugintest.API) { api.On("GetChannel", "").Return(nil, testutils.GetInternalServerAppError("Error while getting the current channel.")).Once() - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Unable to get the current channel information.", - }).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Unable to get the current channel information.")).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) {}, setupClient: func(c *mockClient.Client) {}, @@ -136,11 +122,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(false).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - ChannelId: testutils.GetChannelID(), - UserId: "bot-user-id", - Message: "Unable to unlink the channel, you have to be a channel admin to unlink it.", - }).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Unable to unlink the channel, you have to be a channel admin to unlink it.")).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) {}, setupClient: func(c *mockClient.Client) {}, @@ -156,11 +138,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Id: testutils.GetChannelID(), Type: model.ChannelTypeDirect, }, nil).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - ChannelId: testutils.GetChannelID(), - UserId: "bot-user-id", - Message: "Linking/unlinking a direct or group message is not allowed", - }).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Linking/unlinking a direct or group message is not allowed")).Return(testutils.GetPost(testutils.GetChannelID(), "bot-user-id", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) {}, setupClient: func(c *mockClient.Client) {}, @@ -177,11 +155,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "The MS Teams channel is no longer linked to this Mattermost channel.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "The MS Teams channel is no longer linked to this Mattermost channel.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to get the subscription by MS Teams channel ID", "error", "unable to get the subscription").Return().Once() }, setupStore: func(s *mockStore.Store) { @@ -205,11 +179,7 @@ func TestExecuteUnlinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "The MS Teams channel is no longer linked to this Mattermost channel.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "The MS Teams channel is no longer linked to this Mattermost channel.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to delete the subscription from the DB", "subscriptionID", "testSubscriptionID", "error", "unable to delete the subscription").Return().Once() }, setupStore: func(s *mockStore.Store) { @@ -254,11 +224,7 @@ func TestExecuteShowCommand(t *testing.T) { ChannelId: testutils.GetChannelID(), }, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "This channel is linked to the MS Teams Channel \"\" (with id: ) in the Team \"\" (with the id: ).", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "This channel is linked to the MS Teams Channel \"\" (with id: ) in the Team \"\" (with the id: ).")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{ @@ -274,10 +240,7 @@ func TestExecuteShowCommand(t *testing.T) { description: "Unable to get the link", args: &model.CommandArgs{}, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Link doesn't exist.", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Link doesn't exist.")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("GetLinkByChannelID", "").Return(nil, errors.New("Error while getting the link")).Times(1) @@ -290,11 +253,7 @@ func TestExecuteShowCommand(t *testing.T) { ChannelId: "Invalid-ChannelID", }, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - ChannelId: "Invalid-ChannelID", - Message: "Unable to get the MS Teams team information.", - }).Return(testutils.GetPost("Invalid-ChannelID", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "Invalid-ChannelID", "Unable to get the MS Teams team information.")).Return(testutils.GetPost("Invalid-ChannelID", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("GetLinkByChannelID", "Invalid-ChannelID").Return(&storemodels.ChannelLink{ @@ -337,17 +296,9 @@ func TestExecuteShowLinksCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: commandWaitingMessage, - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), commandWaitingMessage)).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "| Mattermost Team | Mattermost Channel | MS Teams Team | MS Teams Channel | \n| :------|:--------|:-------|:-----------|\n|Test MM team|Test MM channel|Test MS team|Test MS channel|", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "| Mattermost Team | Mattermost Channel | MS Teams Team | MS Teams Channel | \n| :------|:--------|:-------|:-----------|\n|Test MM team|Test MM channel|Test MS team|Test MS channel|")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("ListChannelLinksWithNames").Return(testutils.GetChannelLinks(1), nil).Times(1) @@ -365,11 +316,7 @@ func TestExecuteShowLinksCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(false).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Unable to execute the command, only system admins have access to execute this command.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Unable to execute the command, only system admins have access to execute this command.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) {}, setupClient: func(c *mockClient.Client) {}, @@ -383,11 +330,7 @@ func TestExecuteShowLinksCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() api.On("LogDebug", "Unable to get links from store", "Error", "error in getting links").Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Something went wrong.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Something went wrong.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("ListChannelLinksWithNames").Return(nil, errors.New("error in getting links")).Times(1) @@ -402,11 +345,7 @@ func TestExecuteShowLinksCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "No links present.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "No links present.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("ListChannelLinksWithNames").Return(nil, nil).Times(1) @@ -425,17 +364,9 @@ func TestExecuteShowLinksCommand(t *testing.T) { api.On("LogDebug", "Unable to get the MS Teams teams information", "Error", "error in getting teams info").Once() api.On("LogDebug", "Unable to get the MS Teams channel information for the team", "TeamID", testutils.GetTeamsTeamID(), "Error", "error in getting channels info").Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: commandWaitingMessage, - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), commandWaitingMessage)).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "| Mattermost Team | Mattermost Channel | MS Teams Team | MS Teams Channel | \n| :------|:--------|:-------|:-----------|\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\nThere were some errors while fetching information. Please check the server logs.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "| Mattermost Team | Mattermost Channel | MS Teams Team | MS Teams Channel | \n| :------|:--------|:-------|:-----------|\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\n|Test MM team|Test MM channel|||\nThere were some errors while fetching information. Please check the server logs.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("ListChannelLinksWithNames").Return(testutils.GetChannelLinks(4), nil).Times(1) @@ -475,10 +406,7 @@ func TestExecuteDisconnectCommand(t *testing.T) { UserId: testutils.GetUserID(), }, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - Message: "Your account has been disconnected.", - }).Return(testutils.GetPost("", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", "", "Your account has been disconnected.")).Return(testutils.GetPost("", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) api.On("LogDebug", "Unable to delete the last prompt timestamp for the user", "MMUserID", testutils.GetUserID(), "Error", "error in deleting prompt time") }, @@ -494,10 +422,7 @@ func TestExecuteDisconnectCommand(t *testing.T) { description: "User account is not connected", args: &model.CommandArgs{}, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Error: the account is not connected", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Error: the account is not connected")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "").Return("", errors.New("Unable to get team UserID")).Times(1) @@ -507,10 +432,7 @@ func TestExecuteDisconnectCommand(t *testing.T) { description: "User account is not connected as token is not found", args: &model.CommandArgs{}, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Error: the account is not connected", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Error: the account is not connected")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "").Return("", nil).Times(1) @@ -523,10 +445,7 @@ func TestExecuteDisconnectCommand(t *testing.T) { UserId: testutils.GetUserID(), }, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - Message: "Error: unable to disconnect your account, Error while disconnecting your account", - }).Return(testutils.GetPost("", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", "", "Error: unable to disconnect your account, Error while disconnecting your account")).Return(testutils.GetPost("", testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", testutils.GetUserID()).Return("", nil).Times(1) @@ -564,11 +483,7 @@ func TestExecuteDisconnectBotCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "The bot account has been disconnected.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "The bot account has been disconnected.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "bot-user-id").Return(testutils.GetUserID(), nil).Times(1) @@ -583,11 +498,7 @@ func TestExecuteDisconnectBotCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Error: unable to find the connected bot account", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Error: unable to find the connected bot account")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "bot-user-id").Return("", errors.New("Error: unable to find the connected bot account")).Times(1) @@ -601,11 +512,7 @@ func TestExecuteDisconnectBotCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Error: unable to disconnect the bot account, Error while disconnecting the bot account", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Error: unable to disconnect the bot account, Error while disconnecting the bot account")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "bot-user-id").Return(testutils.GetUserID(), nil).Times(1) @@ -617,10 +524,7 @@ func TestExecuteDisconnectBotCommand(t *testing.T) { args: &model.CommandArgs{}, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", "", model.PermissionManageSystem).Return(false).Times(1) - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Unable to disconnect the bot account, only system admins can disconnect the bot account.", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Unable to disconnect the bot account, only system admins can disconnect the bot account.")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) {}, }, @@ -665,24 +569,88 @@ func TestExecuteLinkCommand(t *testing.T) { }, }, nil).Times(2) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Please wait while your request is being processed.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "The MS Teams channel is now linked to this Mattermost channel.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), commandWaitingMessage)).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "The MS Teams channel is now linked to this Mattermost channel.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) s.On("GetLinkByChannelID", testutils.GetChannelID()).Return(nil, nil).Times(1) s.On("GetLinkByMSTeamsChannelID", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(nil, nil).Times(1) s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(&oauth2.Token{}, nil).Times(1) - s.On("StoreChannelLink", mock.Anything).Return(nil).Times(1) - s.On("SaveChannelSubscription", mock.Anything).Return(nil).Times(1) + s.On("StoreChannelLink", mock.AnythingOfType("*storemodels.ChannelLink")).Return(nil).Times(1) + s.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + s.On("SaveChannelSubscription", mock.AnythingOfType("storemodels.ChannelSubscription"), &sql.Tx{}).Return(nil).Times(1) + s.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) + }, + setupClient: func(c *mockClient.Client, uc *mockClient.Client) { + uc.On("GetChannelInTeam", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(&msteams.Channel{}, nil) + }, + }, + { + description: "Error in beginning the database transaction", + parameters: []string{testutils.GetTeamsUserID(), testutils.GetChannelID()}, + args: &model.CommandArgs{ + UserId: testutils.GetUserID(), + TeamId: testutils.GetTeamsUserID(), + ChannelId: testutils.GetChannelID(), + }, + setupAPI: func(api *plugintest.API) { + api.On("GetChannel", testutils.GetChannelID()).Return(&model.Channel{ + Type: model.ChannelTypeOpen, + }, nil).Times(1) + api.On("GetConfig").Return(&model.Config{ + ServiceSettings: model.ServiceSettings{ + SiteURL: model.NewString("/"), + }, + }, nil).Times(2) + api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), commandWaitingMessage)).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Something went wrong")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("LogError", "Unable to begin the database transaction", "error", "error in beginning the database transaction") + }, + setupStore: func(s *mockStore.Store) { + s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) + s.On("GetLinkByChannelID", testutils.GetChannelID()).Return(nil, nil).Times(1) + s.On("GetLinkByMSTeamsChannelID", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(nil, nil).Times(1) + s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(&oauth2.Token{}, nil).Times(1) + s.On("StoreChannelLink", mock.AnythingOfType("*storemodels.ChannelLink")).Return(nil).Times(1) + s.On("BeginTx").Return(nil, errors.New("error in beginning the database transaction")).Times(1) + }, + setupClient: func(c *mockClient.Client, uc *mockClient.Client) { + uc.On("GetChannelInTeam", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(&msteams.Channel{}, nil) + }, + }, + { + description: "Unable to commit the database transaction", + parameters: []string{testutils.GetTeamsUserID(), testutils.GetChannelID()}, + args: &model.CommandArgs{ + UserId: testutils.GetUserID(), + TeamId: testutils.GetTeamsUserID(), + ChannelId: testutils.GetChannelID(), + }, + setupAPI: func(api *plugintest.API) { + api.On("GetChannel", testutils.GetChannelID()).Return(&model.Channel{ + Type: model.ChannelTypeOpen, + }, nil).Times(1) + api.On("GetConfig").Return(&model.Config{ + ServiceSettings: model.ServiceSettings{ + SiteURL: model.NewString("/"), + }, + }, nil).Times(2) + api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), commandWaitingMessage)).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Something went wrong")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("LogError", "Unable to commit database transaction", "error", "error in committing transaction") + }, + setupStore: func(s *mockStore.Store) { + s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) + s.On("GetLinkByChannelID", testutils.GetChannelID()).Return(nil, nil).Times(1) + s.On("GetLinkByMSTeamsChannelID", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(nil, nil).Times(1) + s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(&oauth2.Token{}, nil).Times(1) + s.On("StoreChannelLink", mock.AnythingOfType("*storemodels.ChannelLink")).Return(nil).Times(1) + s.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + s.On("SaveChannelSubscription", mock.AnythingOfType("storemodels.ChannelSubscription"), &sql.Tx{}).Return(nil).Times(1) + s.On("CommitTx", &sql.Tx{}).Return(errors.New("error in committing transaction")).Times(1) }, setupClient: func(c *mockClient.Client, uc *mockClient.Client) { uc.On("GetChannelInTeam", testutils.GetTeamsUserID(), testutils.GetChannelID()).Return(&msteams.Channel{}, nil) @@ -701,11 +669,7 @@ func TestExecuteLinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "A link for this channel already exists. Please unlink the channel before you link again with another channel.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "A link for this channel already exists. Please unlink the channel before you link again with another channel.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) @@ -730,11 +694,7 @@ func TestExecuteLinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Invalid link command, please pass the MS Teams team id and channel id as parameters.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Invalid link command, please pass the MS Teams team id and channel id as parameters.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) {}, setupClient: func(c *mockClient.Client, uc *mockClient.Client) {}, @@ -748,10 +708,7 @@ func TestExecuteLinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", "", "", model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "This team is not enabled for MS Teams sync.", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "This team is not enabled for MS Teams sync.")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", "").Return(false).Times(1) @@ -766,10 +723,7 @@ func TestExecuteLinkCommand(t *testing.T) { }, setupAPI: func(api *plugintest.API) { api.On("GetChannel", "").Return(nil, testutils.GetInternalServerAppError("Error while getting the current channel.")).Times(1) - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - Message: "Unable to get the current channel information.", - }).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", "", "Unable to get the current channel information.")).Return(testutils.GetPost("", "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) @@ -788,11 +742,7 @@ func TestExecuteLinkCommand(t *testing.T) { Type: model.ChannelTypeOpen, }, nil).Times(1) api.On("HasPermissionToChannel", "", testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(false).Times(1) - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Unable to link the channel. You have to be a channel admin to link it.", - }).Return(testutils.GetPost(testutils.GetChannelID(), "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Unable to link the channel. You have to be a channel admin to link it.")).Return(testutils.GetPost(testutils.GetChannelID(), "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) @@ -810,11 +760,7 @@ func TestExecuteLinkCommand(t *testing.T) { api.On("GetChannel", testutils.GetChannelID()).Return(&model.Channel{ Type: model.ChannelTypeGroup, }, nil).Times(1) - api.On("SendEphemeralPost", "", &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "Linking/unlinking a direct or group message is not allowed", - }).Return(testutils.GetPost(testutils.GetChannelID(), "", time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", "", testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "Linking/unlinking a direct or group message is not allowed")).Return(testutils.GetPost(testutils.GetChannelID(), "", time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) @@ -839,11 +785,7 @@ func TestExecuteLinkCommand(t *testing.T) { }, }, nil).Times(1) api.On("HasPermissionToChannel", testutils.GetUserID(), testutils.GetChannelID(), model.PermissionManageChannelRoles).Return(true).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: "bot-user-id", - ChannelId: testutils.GetChannelID(), - Message: "MS Teams channel not found or you don't have the permissions to access it.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost("bot-user-id", testutils.GetChannelID(), "MS Teams channel not found or you don't have the permissions to access it.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Times(1) }, setupStore: func(s *mockStore.Store) { s.On("CheckEnabledTeamByTeamID", testutils.GetTeamsUserID()).Return(true).Times(1) @@ -879,11 +821,7 @@ func TestExecuteConnectCommand(t *testing.T) { { description: "User already connected", setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "You are already connected to MS Teams. Please disconnect your account first before connecting again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "You are already connected to MS Teams. Please disconnect your account first before connecting again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(&oauth2.Token{}, nil).Once() @@ -893,11 +831,7 @@ func TestExecuteConnectCommand(t *testing.T) { description: "Unable to store OAuth state", setupAPI: func(api *plugintest.API) { api.On("LogError", "Error in storing the OAuth state", "error", "error in storing oauth state") - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error trying to connect the account, please try again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error trying to connect the account, please try again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(nil, errors.New("token not found")).Once() @@ -908,11 +842,7 @@ func TestExecuteConnectCommand(t *testing.T) { description: "Unable to set in KV store", setupAPI: func(api *plugintest.API) { api.On("KVSet", "_code_verifier_"+testutils.GetUserID(), mock.Anything).Return(testutils.GetInternalServerAppError("unable to set in KV store")).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error trying to connect the account, please try again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error trying to connect the account, please try again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", testutils.GetUserID()).Return(nil, errors.New("token not found")).Once() @@ -962,11 +892,7 @@ func TestExecuteConnectBotCommand(t *testing.T) { description: "User don't have permission to execute the command", setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(false).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Unable to connect the bot account, only system admins can connect the bot account.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Unable to connect the bot account, only system admins can connect the bot account.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(_ *mockStore.Store) {}, }, @@ -974,11 +900,7 @@ func TestExecuteConnectBotCommand(t *testing.T) { description: "Bot user already connected", setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "The bot account is already connected to MS Teams. Please disconnect the bot account first before connecting again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "The bot account is already connected to MS Teams. Please disconnect the bot account first before connecting again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", p.userID).Return(&oauth2.Token{}, nil).Once() @@ -989,11 +911,7 @@ func TestExecuteConnectBotCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() api.On("LogError", "Error in storing the OAuth state", "error", "error in storing oauth state") - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error trying to connect the bot account, please try again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error trying to connect the bot account, please try again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", p.userID).Return(nil, errors.New("token not found")).Once() @@ -1005,11 +923,7 @@ func TestExecuteConnectBotCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Once() api.On("KVSet", "_code_verifier_"+p.userID, mock.Anything).Return(testutils.GetInternalServerAppError("unable to set in KV store")).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error trying to connect the bot account, please try again.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error trying to connect the bot account, please try again.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("GetTokenForMattermostUser", p.userID).Return(nil, errors.New("token not found")).Once() @@ -1185,11 +1099,7 @@ func TestExecutePromoteCommand(t *testing.T) { description: "No params", params: []string{}, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Invalid promote command, please pass the current username and promoted username as parameters.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Invalid promote command, please pass the current username and promoted username as parameters.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) {}, }, @@ -1197,11 +1107,7 @@ func TestExecutePromoteCommand(t *testing.T) { description: "Too many params", params: []string{"user1", "user2", "user3"}, setupAPI: func(api *plugintest.API) { - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Invalid promote command, please pass the current username and promoted username as parameters.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Invalid promote command, please pass the current username and promoted username as parameters.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) {}, }, @@ -1211,11 +1117,7 @@ func TestExecutePromoteCommand(t *testing.T) { params: []string{"valid-user", "valid-user"}, setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(false).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Unable to execute the command, only system admins have access to execute this command.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Unable to execute the command, only system admins have access to execute this command.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) {}, }, @@ -1225,11 +1127,7 @@ func TestExecutePromoteCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) api.On("GetUserByUsername", "not-existing-user").Return(nil, &model.AppError{}).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error: Unable to promote account not-existing-user, user not found", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error: Unable to promote account not-existing-user, user not found")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) {}, }, @@ -1239,11 +1137,7 @@ func TestExecutePromoteCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) api.On("GetUserByUsername", "existing-user").Return(&model.User{Id: "test", Username: "existing-user", RemoteId: model.NewString("test")}, nil).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error: Unable to promote account existing-user, it is not a known msteams user account", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error: Unable to promote account existing-user, it is not a known msteams user account")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("", errors.New("not-found")).Times(1) @@ -1255,11 +1149,7 @@ func TestExecutePromoteCommand(t *testing.T) { setupAPI: func(api *plugintest.API) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) api.On("GetUserByUsername", "existing-user").Return(&model.User{Id: "test", Username: "existing-user", RemoteId: nil}, nil).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error: Unable to promote account existing-user, it is already a regular account", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error: Unable to promote account existing-user, it is already a regular account")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("ms-test", nil).Times(1) @@ -1272,11 +1162,7 @@ func TestExecutePromoteCommand(t *testing.T) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) api.On("GetUserByUsername", "valid-user").Return(&model.User{Id: "test", Username: "valid-user", RemoteId: model.NewString("test")}, nil).Once() api.On("GetUserByUsername", "new-user").Return(&model.User{Id: "test2", Username: "new-user", RemoteId: nil}, nil).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error: the promoted username already exists, please use a different username.", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error: the promoted username already exists, please use a different username.")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("ms-test", nil).Times(1) @@ -1290,11 +1176,7 @@ func TestExecutePromoteCommand(t *testing.T) { api.On("GetUserByUsername", "valid-user").Return(&model.User{Id: "test", Username: "valid-user", RemoteId: model.NewString("test")}, nil).Once() api.On("GetUserByUsername", "new-user").Return(nil, &model.AppError{}).Once() api.On("UpdateUser", mock.Anything).Return(nil, &model.AppError{}).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Error: Unable to promote account valid-user", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Error: Unable to promote account valid-user")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("ms-test", nil).Times(1) @@ -1308,11 +1190,7 @@ func TestExecutePromoteCommand(t *testing.T) { api.On("GetUserByUsername", "valid-user").Return(&model.User{Id: "test", Username: "valid-user", RemoteId: model.NewString("test")}, nil).Once() api.On("GetUserByUsername", "new-user").Return(nil, &model.AppError{}).Once() api.On("UpdateUser", &model.User{Id: "test", Username: "new-user", RemoteId: nil}).Return(&model.User{Id: "test", Username: "new-user", RemoteId: nil}, nil).Once() - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Account valid-user has been promoted and updated the username to new-user", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Account valid-user has been promoted and updated the username to new-user")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("ms-test", nil).Times(1) @@ -1325,11 +1203,7 @@ func TestExecutePromoteCommand(t *testing.T) { api.On("HasPermissionTo", testutils.GetUserID(), model.PermissionManageSystem).Return(true).Times(1) api.On("GetUserByUsername", "valid-user").Return(&model.User{Id: "test", Username: "valid-user", RemoteId: model.NewString("test")}, nil).Times(2) api.On("UpdateUser", &model.User{Id: "test", Username: "valid-user", RemoteId: nil}).Return(&model.User{Id: "test", Username: "valid-user", RemoteId: nil}, nil).Times(1) - api.On("SendEphemeralPost", testutils.GetUserID(), &model.Post{ - UserId: p.userID, - ChannelId: testutils.GetChannelID(), - Message: "Account valid-user has been promoted and updated the username to valid-user", - }).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() + api.On("SendEphemeralPost", testutils.GetUserID(), testutils.GetEphemeralPost(p.userID, testutils.GetChannelID(), "Account valid-user has been promoted and updated the username to valid-user")).Return(testutils.GetPost(testutils.GetChannelID(), testutils.GetUserID(), time.Now().UnixMicro())).Once() }, setupStore: func(s *mockStore.Store) { s.On("MattermostToTeamsUserID", "test").Return("ms-test", nil).Times(1) diff --git a/server/handlers/handlers.go b/server/handlers/handlers.go index 1feacb6c8..49c6d0d95 100644 --- a/server/handlers/handlers.go +++ b/server/handlers/handlers.go @@ -264,8 +264,7 @@ func (ah *ActivityHandler) handleCreatedActivity(activityIds msteams.ActivityIds ah.updateLastReceivedChangeDate(msg.LastUpdateAt) if newPost != nil && newPost.Id != "" && msg.ID != "" { - err = ah.plugin.GetStore().LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: msg.ChatID + msg.ChannelID, MSTeamsID: msg.ID, MSTeamsLastUpdateAt: msg.LastUpdateAt}) - if err != nil { + if err := ah.plugin.GetStore().LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: fmt.Sprintf(msg.ChatID + msg.ChannelID), MSTeamsID: msg.ID, MSTeamsLastUpdateAt: msg.LastUpdateAt}, nil); err != nil { ah.plugin.GetAPI().LogWarn("Error updating the MSTeams/Mattermost post link metadata", "error", err) } } diff --git a/server/handlers/handlers_test.go b/server/handlers/handlers_test.go index c30e735e7..5c419fcf9 100644 --- a/server/handlers/handlers_test.go +++ b/server/handlers/handlers_test.go @@ -1,6 +1,7 @@ package handlers import ( + "database/sql" "errors" "fmt" "testing" @@ -374,7 +375,7 @@ func TestHandleCreatedActivity(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: testutils.GetMessageID(), MSTeamsChannel: testutils.GetMSTeamsChannelID(), - }).Return(errors.New("unable to update the post")).Times(1) + }, (*sql.Tx)(nil)).Return(errors.New("unable to update the post")).Times(1) }, }, { @@ -430,7 +431,7 @@ func TestHandleCreatedActivity(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: testutils.GetMessageID(), MSTeamsChannel: testutils.GetMSTeamsChannelID(), - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, }, { @@ -477,7 +478,7 @@ func TestHandleCreatedActivity(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: testutils.GetMessageID(), MSTeamsChannel: testutils.GetChannelID(), - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, }, } { diff --git a/server/message_hooks.go b/server/message_hooks.go index 2140d94e9..785c4ecc6 100644 --- a/server/message_hooks.go +++ b/server/message_hooks.go @@ -221,22 +221,45 @@ func (p *Plugin) SetChatReaction(teamsMessageID, srcUser, channelID, emojiName s } var teamsMessage *msteams.Message + tx, err := p.store.BeginTx() + if err != nil { + return err + } + + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMSTeamsPostID(tx, teamsMessageID); txErr != nil { + return txErr + } + if updateRequired { - teamsMessage, err = client.SetChatReaction(chatID, teamsMessageID, srcUserID, emoji.Parse(":"+emojiName+":")) - if err != nil { - p.API.LogError("Error creating post reaction", "error", err.Error()) - return err + teamsMessage, txErr = client.SetChatReaction(chatID, teamsMessageID, srcUserID, emoji.Parse(":"+emojiName+":")) + if txErr != nil { + p.API.LogError("Error creating post reaction", "error", txErr.Error()) + return txErr } } else { - teamsMessage, err = client.GetChatMessage(chatID, teamsMessageID) - if err != nil { - p.API.LogWarn("Error getting the msteams post metadata", "error", err.Error()) - return err + teamsMessage, txErr = client.GetChatMessage(chatID, teamsMessageID) + if txErr != nil { + p.API.LogWarn("Error getting the msteams post metadata", "error", txErr.Error()) + return txErr } } - if err = p.store.SetPostLastUpdateAtByMSTeamsID(teamsMessageID, teamsMessage.LastUpdateAt); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err.Error()) + if txErr = p.store.SetPostLastUpdateAtByMSTeamsID(teamsMessageID, teamsMessage.LastUpdateAt, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr.Error()) } return nil @@ -268,23 +291,46 @@ func (p *Plugin) SetReaction(teamID, channelID, userID string, post *model.Post, } var teamsMessage *msteams.Message + tx, err := p.store.BeginTx() + if err != nil { + return err + } + + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMMPostID(tx, postInfo.MattermostID); txErr != nil { + return txErr + } + if updateRequired { teamsUserID, _ := p.store.MattermostToTeamsUserID(userID) - teamsMessage, err = client.SetReaction(teamID, channelID, parentID, postInfo.MSTeamsID, teamsUserID, emoji.Parse(":"+emojiName+":")) - if err != nil { - p.API.LogError("Error setting reaction", "error", err.Error()) - return err + teamsMessage, txErr = client.SetReaction(teamID, channelID, parentID, postInfo.MSTeamsID, teamsUserID, emoji.Parse(":"+emojiName+":")) + if txErr != nil { + p.API.LogError("Error setting reaction", "error", txErr.Error()) + return txErr } } else { - teamsMessage, err = getUpdatedMessage(teamID, channelID, parentID, postInfo.MSTeamsID, client) - if err != nil { - p.API.LogWarn("Error getting the msteams post metadata", "error", err.Error()) - return err + teamsMessage, txErr = getUpdatedMessage(teamID, channelID, parentID, postInfo.MSTeamsID, client) + if txErr != nil { + p.API.LogWarn("Error getting the msteams post metadata", "error", txErr.Error()) + return txErr } } - if err = p.store.SetPostLastUpdateAtByMattermostID(postInfo.MattermostID, teamsMessage.LastUpdateAt); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err.Error()) + if txErr = p.store.SetPostLastUpdateAtByMattermostID(postInfo.MattermostID, teamsMessage.LastUpdateAt, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr.Error()) } return nil @@ -310,14 +356,37 @@ func (p *Plugin) UnsetChatReaction(teamsMessageID, srcUser, channelID string, em return err } - teamsMessage, err := client.UnsetChatReaction(chatID, teamsMessageID, srcUserID, emoji.Parse(":"+emojiName+":")) + tx, err := p.store.BeginTx() if err != nil { - p.API.LogError("Error in removing the chat reaction", "emojiName", emojiName, "error", err.Error()) return err } - if err = p.store.SetPostLastUpdateAtByMSTeamsID(teamsMessageID, teamsMessage.LastUpdateAt); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err.Error()) + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMSTeamsPostID(tx, teamsMessageID); txErr != nil { + return txErr + } + + teamsMessage, txErr := client.UnsetChatReaction(chatID, teamsMessageID, srcUserID, emoji.Parse(":"+emojiName+":")) + if txErr != nil { + p.API.LogError("Error in removing the chat reaction", "emojiName", emojiName, "error", txErr.Error()) + return txErr + } + + if txErr = p.store.SetPostLastUpdateAtByMSTeamsID(teamsMessageID, teamsMessage.LastUpdateAt, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr.Error()) } return nil @@ -349,14 +418,37 @@ func (p *Plugin) UnsetReaction(teamID, channelID, userID string, post *model.Pos } teamsUserID, _ := p.store.MattermostToTeamsUserID(userID) - teamsMessage, err := client.UnsetReaction(teamID, channelID, parentID, postInfo.MSTeamsID, teamsUserID, emoji.Parse(":"+emojiName+":")) + tx, err := p.store.BeginTx() if err != nil { - p.API.LogError("Error in removing the reaction", "emojiName", emojiName, "error", err.Error()) return err } - if err = p.store.SetPostLastUpdateAtByMattermostID(postInfo.MattermostID, teamsMessage.LastUpdateAt); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err.Error()) + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMMPostID(tx, postInfo.MattermostID); txErr != nil { + return txErr + } + + teamsMessage, txErr := client.UnsetReaction(teamID, channelID, parentID, postInfo.MSTeamsID, teamsUserID, emoji.Parse(":"+emojiName+":")) + if txErr != nil { + p.API.LogError("Error in removing the reaction", "emojiName", emojiName, "error", txErr.Error()) + return txErr + } + + if txErr = p.store.SetPostLastUpdateAtByMattermostID(postInfo.MattermostID, teamsMessage.LastUpdateAt, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr.Error()) } return nil @@ -447,9 +539,8 @@ func (p *Plugin) SendChat(srcUser string, usersIDs []string, post *model.Post) ( return "", err } - if post.Id != "" && newMessage != nil { - err := p.store.LinkPosts(storemodels.PostInfo{MattermostID: post.Id, MSTeamsChannel: chat.ID, MSTeamsID: newMessage.ID, MSTeamsLastUpdateAt: newMessage.LastUpdateAt}) - if err != nil { + if post.Id != "" { + if err := p.store.LinkPosts(storemodels.PostInfo{MattermostID: post.Id, MSTeamsChannel: chat.ID, MSTeamsID: newMessage.ID, MSTeamsLastUpdateAt: newMessage.LastUpdateAt}, nil); err != nil { p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err) } } @@ -531,9 +622,8 @@ func (p *Plugin) Send(teamID, channelID string, user *model.User, post *model.Po return "", err } - if post.Id != "" && newMessage != nil { - err := p.store.LinkPosts(storemodels.PostInfo{MattermostID: post.Id, MSTeamsChannel: channelID, MSTeamsID: newMessage.ID, MSTeamsLastUpdateAt: newMessage.LastUpdateAt}) - if err != nil { + if post.Id != "" { + if err := p.store.LinkPosts(storemodels.PostInfo{MattermostID: post.Id, MSTeamsChannel: channelID, MSTeamsID: newMessage.ID, MSTeamsLastUpdateAt: newMessage.LastUpdateAt}, nil); err != nil { p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err) } } @@ -638,31 +728,54 @@ func (p *Plugin) Update(teamID, channelID string, user *model.User, newPost, old } var updatedMessage *msteams.Message + tx, err := p.store.BeginTx() + if err != nil { + return err + } + + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMMPostID(tx, newPost.Id); txErr != nil { + return txErr + } + if updateRequired { // TODO: Add the logic of processing the attachments and uploading new files to Teams // once Mattermost comes up with the feature of editing attachments md := markdown.New(markdown.XHTMLOutput(true), markdown.Typographer(false), markdown.LangPrefix("CodeMirror language-")) content := md.RenderToString([]byte(emoji.Parse(text))) content, mentions := p.getMentionsData(content, teamID, channelID, "", client) - updatedMessage, err = client.UpdateMessage(teamID, channelID, parentID, postInfo.MSTeamsID, content, mentions) - if err != nil { - p.API.LogWarn("Error updating the post on MS Teams", "error", err) + updatedMessage, txErr = client.UpdateMessage(teamID, channelID, parentID, postInfo.MSTeamsID, content, mentions) + if txErr != nil { + p.API.LogWarn("Error updating the post on MS Teams", "error", txErr) // If the error is regarding payment required for metered APIs, ignore it and continue because // the post is updated regardless - if !strings.Contains(err.Error(), "code: PaymentRequired") { - return err + if !strings.Contains(txErr.Error(), "code: PaymentRequired") { + return txErr } } } else { - updatedMessage, err = getUpdatedMessage(teamID, channelID, parentID, postInfo.MSTeamsID, client) - if err != nil { - p.API.LogWarn("Error in getting the message from MS Teams", "error", err) - return err + updatedMessage, txErr = getUpdatedMessage(teamID, channelID, parentID, postInfo.MSTeamsID, client) + if txErr != nil { + p.API.LogWarn("Error in getting the message from MS Teams", "error", txErr) + return txErr } } - if err = p.store.LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: channelID, MSTeamsID: postInfo.MSTeamsID, MSTeamsLastUpdateAt: updatedMessage.LastUpdateAt}); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err) + if txErr = p.store.LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: channelID, MSTeamsID: postInfo.MSTeamsID, MSTeamsLastUpdateAt: updatedMessage.LastUpdateAt}, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr) } return nil @@ -690,29 +803,52 @@ func (p *Plugin) UpdateChat(chatID string, user *model.User, newPost, oldPost *m } var updatedMessage *msteams.Message + tx, err := p.store.BeginTx() + if err != nil { + return err + } + + var txErr error + defer func() { + if txErr != nil { + if err := p.store.RollbackTx(tx); err != nil { + p.API.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := p.store.CommitTx(tx); err != nil { + p.API.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = p.store.LockPostByMMPostID(tx, newPost.Id); txErr != nil { + return txErr + } + if updateRequired { md := markdown.New(markdown.XHTMLOutput(true), markdown.Typographer(false), markdown.LangPrefix("CodeMirror language-")) content := md.RenderToString([]byte(emoji.Parse(text))) content, mentions := p.getMentionsData(content, "", "", chatID, client) - updatedMessage, err = client.UpdateChatMessage(chatID, postInfo.MSTeamsID, content, mentions) - if err != nil { - p.API.LogWarn("Error updating the post on MS Teams", "error", err) + updatedMessage, txErr = client.UpdateChatMessage(chatID, postInfo.MSTeamsID, content, mentions) + if txErr != nil { + p.API.LogWarn("Error updating the post on MS Teams", "error", txErr) // If the error is regarding payment required for metered APIs, ignore it and continue because // the post is updated regardless - if !strings.Contains(err.Error(), "code: PaymentRequired") { - return err + if !strings.Contains(txErr.Error(), "code: PaymentRequired") { + return txErr } } } else { - updatedMessage, err = client.GetChatMessage(chatID, postInfo.MSTeamsID) - if err != nil { - p.API.LogWarn("Error getting the updated message from MS Teams", "error", err) - return err + updatedMessage, txErr = client.GetChatMessage(chatID, postInfo.MSTeamsID) + if txErr != nil { + p.API.LogWarn("Error getting the updated message from MS Teams", "error", txErr) + return txErr } } - if err := p.store.LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: chatID, MSTeamsID: postInfo.MSTeamsID, MSTeamsLastUpdateAt: updatedMessage.LastUpdateAt}); err != nil { - p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", err) + if txErr = p.store.LinkPosts(storemodels.PostInfo{MattermostID: newPost.Id, MSTeamsChannel: chatID, MSTeamsID: postInfo.MSTeamsID, MSTeamsLastUpdateAt: updatedMessage.LastUpdateAt}, tx); txErr != nil { + p.API.LogWarn("Error updating the msteams/mattermost post link metadata", "error", txErr) } return nil diff --git a/server/message_hooks_test.go b/server/message_hooks_test.go index 2879f54f3..c8d8bbfb4 100644 --- a/server/message_hooks_test.go +++ b/server/message_hooks_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "database/sql" "errors" "math" "testing" @@ -94,6 +95,9 @@ func TestReactionHasBeenAdded(t *testing.T) { store.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{MattermostTeamID: "mm-team-id", MattermostChannelID: "mm-channel-id", MSTeamsTeam: "ms-teams-team-id", MSTeamsChannel: "ms-teams-channel-id"}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SetReaction", "ms-teams-team-id", "ms-teams-channel-id", "", "ms-teams-id", testutils.GetID(), mock.AnythingOfType("string")).Return(nil, errors.New("unable to set the reaction")).Times(1) @@ -112,7 +116,10 @@ func TestReactionHasBeenAdded(t *testing.T) { store.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{MattermostTeamID: "mm-team-id", MattermostChannelID: "mm-channel-id", MSTeamsTeam: "ms-teams-team-id", MSTeamsChannel: "ms-teams-channel-id"}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime()).Return(errors.New("unable to set post lastUpdateAt value")).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime(), &sql.Tx{}).Return(errors.New("unable to set post lastUpdateAt value")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SetReaction", "ms-teams-team-id", "ms-teams-channel-id", "", "ms-teams-id", testutils.GetID(), mock.AnythingOfType("string")).Return(mockMessage, nil).Times(1) @@ -131,7 +138,10 @@ func TestReactionHasBeenAdded(t *testing.T) { store.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{MattermostTeamID: "mm-team-id", MattermostChannelID: "mm-channel-id", MSTeamsTeam: "ms-teams-team-id", MSTeamsChannel: "ms-teams-channel-id"}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime()).Return(nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime(), &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SetReaction", "ms-teams-team-id", "ms-teams-channel-id", "", "ms-teams-id", testutils.GetID(), mock.AnythingOfType("string")).Return(mockMessage, nil).Times(1) @@ -237,6 +247,9 @@ func TestReactionHasBeenRemoved(t *testing.T) { }, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), mock.AnythingOfType("string")).Return(nil, errors.New("unable to unset the reaction")).Times(1) @@ -252,7 +265,6 @@ func TestReactionHasBeenRemoved(t *testing.T) { api.On("LogWarn", "Error updating the msteams/mattermost post link metadata", "error", "unable to set post lastUpdateAt value") }, SetupStore: func(store *storemocks.Store) { - demoTime, _ := time.Parse("Jan 2, 2006 at 3:04pm (MST)", "Jan 2, 2023 at 4:00pm (MST)") store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ MattermostID: testutils.GetID(), }, nil).Times(2) @@ -264,7 +276,10 @@ func TestReactionHasBeenRemoved(t *testing.T) { }, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), demoTime).Return(errors.New("unable to set post lastUpdateAt value")).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime(), &sql.Tx{}).Return(errors.New("unable to set post lastUpdateAt value")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), mock.AnythingOfType("string")).Return(mockMessage, nil).Times(1) @@ -291,7 +306,10 @@ func TestReactionHasBeenRemoved(t *testing.T) { }, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime()).Return(nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", testutils.GetID(), testutils.GetMockTime(), &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), mock.AnythingOfType("string")).Return(mockMessage, nil).Times(1) @@ -352,11 +370,14 @@ func TestMessageHasBeenUpdated(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMsgID", }, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsID: "mockMsgID", MSTeamsChannel: testutils.GetChatID(), - }).Return(nil).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -445,6 +466,7 @@ func TestMessageHasBeenUpdated(t *testing.T) { api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) }, SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) store.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{ MattermostTeamID: "mockMattermostTeam", MattermostChannelID: "mockMattermostChannel", @@ -455,12 +477,14 @@ func TestMessageHasBeenUpdated(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", }, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", MSTeamsChannel: "mockTeamsChannelID", - }).Return(nil).Times(1) - store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateMessage", "mockTeamsTeamID", "mockTeamsChannelID", "", "mockMessageID", "", []models.ChatMessageMentionable{}).Return(mockChannelMessage, nil).Times(1) @@ -476,6 +500,7 @@ func TestMessageHasBeenUpdated(t *testing.T) { api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) }, SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) store.On("GetLinkByChannelID", testutils.GetChannelID()).Return(&storemodels.ChannelLink{ MattermostTeamID: "mockMattermostTeamID", MattermostChannelID: "mockMattermostChannelID", @@ -485,12 +510,14 @@ func TestMessageHasBeenUpdated(t *testing.T) { store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ MattermostID: testutils.GetID(), }, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsID: "mockTeamsTeamID", MSTeamsChannel: "mockTeamsChannelID", - }).Return(nil).Times(1) - store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateMessage", "mockTeamsTeamID", "mockTeamsChannelID", "", "", "", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the post")).Times(1) @@ -566,6 +593,24 @@ func TestSetChatReaction(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedMessage: "unable to get the channel", }, + { + Name: "SetChatReaction: Unable to begin the database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(nil, errors.New("unable to begin database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + }, + ExpectedMessage: "unable to begin database transaction", + UpdateRequired: true, + }, { Name: "SetChatReaction: Unable to set the chat reaction", SetupAPI: func(api *plugintest.API) { @@ -577,6 +622,32 @@ func TestSetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + uclient.On("SetChatReaction", testutils.GetChatID(), "mockTeamsMessageID", testutils.GetID(), ":mockEmojiName:").Return(nil, errors.New("unable to set the chat reaction")).Times(1) + }, + ExpectedMessage: "unable to set the chat reaction", + UpdateRequired: true, + }, + { + Name: "SetChatReaction: Unable to set the chat reaction and unable to rollback transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error creating post reaction", "error", "unable to set the chat reaction") + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction") + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -595,7 +666,10 @@ func TestSetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) - store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime()).Return(nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -613,11 +687,37 @@ func TestSetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) - store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime()).Return(errors.New("unable to set post lastUpdateAt value")).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(errors.New("unable to set post lastUpdateAt value")).Once() + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + uclient.On("SetChatReaction", testutils.GetChatID(), "mockTeamsMessageID", testutils.GetID(), ":mockEmojiName:").Return(mockChatMessage, nil).Times(1) + }, + UpdateRequired: true, + }, + { + Name: "SetChatReaction: Unable to commit database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) uclient.On("SetChatReaction", testutils.GetChatID(), "mockTeamsMessageID", testutils.GetID(), ":mockEmojiName:").Return(mockChatMessage, nil).Times(1) + uclient.On("GetChatMessage", testutils.GetChatID(), "mockTeamsMessageID").Return(&msteams.Message{LastUpdateAt: testutils.GetMockTime()}, nil).Once() }, UpdateRequired: true, }, @@ -631,7 +731,10 @@ func TestSetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) - store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime()).Return(nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -699,6 +802,21 @@ func TestSetReaction(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedMessage: "not connected user", }, + { + Name: "SetReaction: Unable to begin database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error setting reaction", "error", "unable to begin database transaction").Return(nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(nil, errors.New("unable to begin database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, + ExpectedMessage: "unable to begin database transaction", + }, { Name: "SetReaction: Unable to set the reaction", SetupAPI: func(api *plugintest.API) { @@ -709,12 +827,54 @@ func TestSetReaction(t *testing.T) { store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(nil, errors.New("unable to set the reaction")).Times(1) }, ExpectedMessage: "unable to set the reaction", }, + { + Name: "SetReaction: Unable to set the reaction and unable to rollback database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error setting reaction", "error", "unable to set the reaction").Return(nil).Times(1) + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction") + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("SetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(nil, errors.New("unable to set the reaction")).Times(1) + }, + ExpectedMessage: "unable to set the reaction", + }, + { + Name: "SetReaction: Unable to commit database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("SetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(mockChannelMessage, nil).Times(1) + }, + }, { Name: "SetReaction: Valid", SetupAPI: func(api *plugintest.API) { @@ -724,7 +884,10 @@ func TestSetReaction(t *testing.T) { store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime()).Return(nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(mockChannelMessage, nil).Times(1) @@ -807,6 +970,24 @@ func TestUnsetChatReaction(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedMessage: "unable to get the channel", }, + { + Name: "UnsetChatReaction: Unable to begin database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error in removing the chat reaction", "emojiName", "mockEmojiName", "error", ": , unable to unset the chat reaction") + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(nil, errors.New("unable to begin the transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + }, + ExpectedMessage: "unable to begin the transaction", + }, { Name: "UnsetChatReaction: Unable to unset the chat reaction", SetupAPI: func(api *plugintest.API) { @@ -818,6 +999,31 @@ func TestUnsetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + uclient.On("UnsetChatReaction", testutils.GetChatID(), "mockTeamsMessageID", testutils.GetID(), ":mockEmojiName:").Return(nil, testutils.GetInternalServerAppError("unable to unset the chat reaction")).Times(1) + }, + ExpectedMessage: "unable to unset the chat reaction", + }, + { + Name: "UnsetChatReaction: Unable to unset the chat reaction and rollback database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error in removing the chat reaction", "emojiName", "mockEmojiName", "error", ": , unable to unset the chat reaction") + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction") + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -836,7 +1042,31 @@ func TestUnsetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) - store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime()).Return(errors.New("unable to set post lastUpdateAt value")).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(errors.New("unable to set post lastUpdateAt value")).Once() + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) + uclient.On("UnsetChatReaction", testutils.GetChatID(), "mockTeamsMessageID", testutils.GetID(), ":mockEmojiName:").Return(mockChatMessage, nil).Times(1) + }, + }, + { + Name: "UnsetChatReaction: Unable to commit database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("GetChannel", testutils.GetChannelID()).Return(testutils.GetChannel(model.ChannelTypeDirect), nil).Times(1) + api.On("GetChannelMembers", testutils.GetChannelID(), 0, math.MaxInt32).Return(testutils.GetChannelMembers(2), nil).Times(1) + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: model.NewString("/")}}, nil).Times(2) + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -853,7 +1083,10 @@ func TestUnsetChatReaction(t *testing.T) { SetupStore: func(store *storemocks.Store) { store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Times(3) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(2) - store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime()).Return(nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMSTeamsPostID", &sql.Tx{}, "mockTeamsMessageID").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMSTeamsID", "mockTeamsMessageID", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -919,6 +1152,20 @@ func TestUnsetReaction(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedMessage: "not connected user", }, + { + Name: "UnsetReaction: Unable to begin database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error in removing the reaction", "emojiName", "mockName", "error", ": , unable to unset the reaction") + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, errors.New("unable to begin database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, + ExpectedMessage: "unable to begin database transaction", + }, { Name: "UnsetReaction: Unable to unset the reaction", SetupAPI: func(api *plugintest.API) { @@ -928,12 +1175,53 @@ func TestUnsetReaction(t *testing.T) { store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(nil, testutils.GetInternalServerAppError("unable to unset the reaction")).Times(1) + }, + ExpectedMessage: "unable to unset the reaction", + }, + { + Name: "UnsetReaction: Unable to rollback database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogError", "Error in removing the reaction", "emojiName", "mockName", "error", ": , unable to unset the reaction").Return(nil).Times(1) + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(nil, testutils.GetInternalServerAppError("unable to unset the reaction")).Times(1) }, ExpectedMessage: "unable to unset the reaction", }, + { + Name: "UnsetReaction: Unable to commit database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction") + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(mockChannelMessage, nil).Times(1) + uclient.On("GetMessage", "mockTeamsTeamID", "mockTeamsChannelID", "").Return(&msteams.Message{LastUpdateAt: testutils.GetMockTime()}, nil).Times(1) + }, + }, { Name: "UnsetReaction: Valid", SetupAPI: func(a *plugintest.API) {}, @@ -941,7 +1229,10 @@ func TestUnsetReaction(t *testing.T) { store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{}, nil).Times(1) store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) store.On("MattermostToTeamsUserID", testutils.GetID()).Return(testutils.GetID(), nil).Once() - store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime()).Return(nil).Once() + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, "").Return(nil).Times(1) + store.On("SetPostLastUpdateAtByMattermostID", "", testutils.GetMockTime(), &sql.Tx{}).Return(nil).Once() + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UnsetReaction", "mockTeamsTeamID", "mockTeamsChannelID", "", "", testutils.GetID(), ":mockName:").Return(mockChannelMessage, nil).Times(1) @@ -1070,7 +1361,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(testutils.GetInternalServerAppError("unable to store the post")).Times(1) + }, (*sql.Tx)(nil)).Return(testutils.GetInternalServerAppError("unable to store the post")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -1103,7 +1394,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Once() @@ -1134,7 +1425,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -1160,7 +1451,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -1188,7 +1479,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Once() @@ -1225,7 +1516,7 @@ func TestSendChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChatID(), MSTeamsID: "mockMessageID", - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("CreateOrGetChatForUsers", mock.AnythingOfType("[]string")).Return(mockChat, nil).Times(1) @@ -1294,7 +1585,7 @@ func TestSend(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", MSTeamsChannel: testutils.GetChannelID(), - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SendMessageWithAttachments", testutils.GetID(), testutils.GetChannelID(), "", "

mockMessage??????????

\n", ([]*msteams.Attachment)(nil), []models.ChatMessageMentionable{}).Return(&msteams.Message{ @@ -1316,7 +1607,7 @@ func TestSend(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", MSTeamsChannel: testutils.GetChannelID(), - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("SendMessageWithAttachments", testutils.GetID(), testutils.GetChannelID(), "", "

mockMessage??????????

\n", ([]*msteams.Attachment)(nil), []models.ChatMessageMentionable{}).Return(&msteams.Message{ @@ -1360,7 +1651,7 @@ func TestSend(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", MSTeamsChannel: testutils.GetChannelID(), - }).Return(errors.New("unable to store posts")).Times(1) + }, (*sql.Tx)(nil)).Return(errors.New("unable to store posts")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UploadFile", testutils.GetID(), testutils.GetChannelID(), "mockFile.Name"+"_"+testutils.GetID()+".txt", 1, "mockMimeType", bytes.NewReader([]byte("mockData")), (*msteams.Chat)(nil)).Return(&msteams.Attachment{ @@ -1388,7 +1679,7 @@ func TestSend(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMessageID", MSTeamsChannel: testutils.GetChannelID(), - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UploadFile", testutils.GetID(), testutils.GetChannelID(), "mockFile.Name"+"_"+testutils.GetID()+".txt", 1, "mockMimeType", bytes.NewReader([]byte("mockData")), (*msteams.Chat)(nil)).Return(&msteams.Attachment{ @@ -1656,6 +1947,21 @@ func TestUpdate(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedError: "post not found", }, + { + Name: "Update: Unable to begin database transaction", + SetupAPI: func(api *plugintest.API) {}, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockMSTeamsID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, errors.New("unable to begin database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, + ExpectedError: "unable to begin database transaction", + UpdateRequired: true, + }, { Name: "Update: Unable to update the message", SetupAPI: func(api *plugintest.API) { @@ -1667,6 +1973,31 @@ func TestUpdate(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMSTeamsID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UpdateMessage", "mockTeamsTeamID", testutils.GetChannelID(), "", "mockMSTeamsID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the message")).Times(1) + }, + ExpectedError: "unable to update the message", + UpdateRequired: true, + }, + { + Name: "Update: Unable to update the message and rollback database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogWarn", "Error updating the post on MS Teams", "error", errors.New("unable to update the message")).Return(nil).Times(1) + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockMSTeamsID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateMessage", "mockTeamsTeamID", testutils.GetChannelID(), "", "mockMSTeamsID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the message")).Times(1) @@ -1685,11 +2016,14 @@ func TestUpdate(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMSTeamsID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChannelID(), MSTeamsID: "mockMSTeamsID", - }).Return(nil).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("GetMessage", "mockTeamsTeamID", testutils.GetChannelID(), "mockMSTeamsID").Return(mockChannelMessage, nil).Times(1) @@ -1706,11 +2040,14 @@ func TestUpdate(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMSTeamsID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChannelID(), MSTeamsID: "mockMSTeamsID", - }).Return(errors.New("unable to store the link posts")).Times(1) + }, &sql.Tx{}).Return(errors.New("unable to store the link posts")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateMessage", "mockTeamsTeamID", testutils.GetChannelID(), "", "mockMSTeamsID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChannelMessage, nil).Times(1) @@ -1718,9 +2055,9 @@ func TestUpdate(t *testing.T) { UpdateRequired: true, }, { - Name: "Update: Valid", + Name: "Update: Unable to commit database transaction", SetupAPI: func(api *plugintest.API) { - api.On("LogWarn", "Error updating the msteams/mattermost post link metadata", "error", mock.Anything) + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) }, SetupStore: func(store *storemocks.Store) { store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) @@ -1728,11 +2065,37 @@ func TestUpdate(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockMSTeamsID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: testutils.GetChannelID(), MSTeamsID: "mockMSTeamsID", - }).Return(nil).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UpdateMessage", "mockTeamsTeamID", testutils.GetChannelID(), "", "mockMSTeamsID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChannelMessage, nil).Times(1) + }, + UpdateRequired: true, + }, + { + Name: "Update: Valid", + SetupAPI: func(api *plugintest.API) {}, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockMSTeamsID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("LinkPosts", storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsChannel: testutils.GetChannelID(), + MSTeamsID: "mockMSTeamsID", + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateMessage", "mockTeamsTeamID", testutils.GetChannelID(), "", "mockMSTeamsID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChannelMessage, nil).Times(1) @@ -1808,6 +2171,23 @@ func TestUpdateChat(t *testing.T) { SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) {}, ExpectedError: "not connected user", }, + { + Name: "UpdateChat: Unable to begin database transaction", + SetupAPI: func(api *plugintest.API) {}, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockTeamsTeamID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, errors.New("unable to begin database transaction")).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the message")).Times(1) + }, + ExpectedError: "unable to begin database transaction", + UpdateRequired: true, + }, { Name: "UpdateChat: Unable to update the message", SetupAPI: func(api *plugintest.API) { @@ -1819,6 +2199,31 @@ func TestUpdateChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockTeamsTeamID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the message")).Times(1) + }, + ExpectedError: "unable to update the message", + UpdateRequired: true, + }, + { + Name: "UpdateChat: Unable to update the message and rollback database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogWarn", "Error updating the post on MS Teams", "error", errors.New("unable to update the message")).Return(nil).Times(1) + api.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockTeamsTeamID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(nil, errors.New("unable to update the message")).Times(1) @@ -1835,11 +2240,14 @@ func TestUpdateChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockTeamsTeamID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: "mockChatID", MSTeamsID: "mockTeamsTeamID", - }).Return(nil).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("GetChatMessage", "mockChatID", "mockTeamsTeamID").Return(mockChatMessage, nil).Times(1) @@ -1856,11 +2264,40 @@ func TestUpdateChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockTeamsTeamID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) + store.On("LinkPosts", storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsChannel: "mockChatID", + MSTeamsID: "mockTeamsTeamID", + }, &sql.Tx{}).Return(errors.New("unable to store the link posts")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { + uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChatMessage, nil).Times(1) + }, + UpdateRequired: true, + }, + { + Name: "UpdateChat: Unable to commit database transaction", + SetupAPI: func(api *plugintest.API) { + api.On("LogWarn", "Error updating the msteams/mattermost post link metadata", "error", mock.Anything).Return(nil).Times(1) + api.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) + }, + SetupStore: func(store *storemocks.Store) { + store.On("GetTokenForMattermostUser", testutils.GetID()).Return(&oauth2.Token{}, nil).Times(1) + store.On("GetPostInfoByMattermostID", testutils.GetID()).Return(&storemodels.PostInfo{ + MattermostID: testutils.GetID(), + MSTeamsID: "mockTeamsTeamID", + }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: "mockChatID", MSTeamsID: "mockTeamsTeamID", - }).Return(errors.New("unable to store the link posts")).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChatMessage, nil).Times(1) @@ -1878,11 +2315,14 @@ func TestUpdateChat(t *testing.T) { MattermostID: testutils.GetID(), MSTeamsID: "mockTeamsTeamID", }, nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("LockPostByMMPostID", &sql.Tx{}, testutils.GetID()).Return(nil).Times(1) store.On("LinkPosts", storemodels.PostInfo{ MattermostID: testutils.GetID(), MSTeamsChannel: "mockChatID", MSTeamsID: "mockTeamsTeamID", - }).Return(nil).Times(1) + }, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, SetupClient: func(client *clientmocks.Client, uclient *clientmocks.Client) { uclient.On("UpdateChatMessage", "mockChatID", "mockTeamsTeamID", "

mockMessage??????????

\n", []models.ChatMessageMentionable{}).Return(mockChatMessage, nil).Times(1) diff --git a/server/monitor/subscriptions.go b/server/monitor/subscriptions.go index b5a3aef3b..158319629 100644 --- a/server/monitor/subscriptions.go +++ b/server/monitor/subscriptions.go @@ -178,8 +178,28 @@ func (m *Monitor) recreateChannelSubscription(subscriptionID, teamID, channelID, } } - if err := m.store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: newSubscription.ID, TeamID: teamID, ChannelID: channelID, Secret: secret, ExpiresOn: newSubscription.ExpiresOn}); err != nil { - m.api.LogError("Unable to store new subscription in DB", "subscriptionID", newSubscription.ID, "error", err.Error()) + tx, err := m.store.BeginTx() + if err != nil { + m.api.LogWarn("Unable to begin database transaction", "error", err.Error()) + return + } + + var txErr error + defer func() { + if txErr != nil { + if err := m.store.RollbackTx(tx); err != nil { + m.api.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := m.store.CommitTx(tx); err != nil { + m.api.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + if txErr = m.store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: newSubscription.ID, TeamID: teamID, ChannelID: channelID, Secret: secret, ExpiresOn: newSubscription.ExpiresOn}, tx); txErr != nil { + m.api.LogError("Unable to store new subscription in DB", "subscriptionID", newSubscription.ID, "error", txErr.Error()) return } } diff --git a/server/monitor/subscriptions_test.go b/server/monitor/subscriptions_test.go index 68f70e90b..54a48ed55 100644 --- a/server/monitor/subscriptions_test.go +++ b/server/monitor/subscriptions_test.go @@ -1,6 +1,7 @@ package monitor import ( + "database/sql" "errors" "testing" "time" @@ -199,7 +200,9 @@ func TestMonitorCheckChannelSubscriptions(t *testing.T) { setupStore: func(store *mocksStore.Store) { store.On("ListChannelLinks").Return([]storemodels.ChannelLink{channelLink}, nil).Times(1) store.On("ListChannelSubscriptions").Return([]*storemodels.ChannelSubscription{}, nil).Times(1) - store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}).Return(nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, }, { @@ -213,7 +216,9 @@ func TestMonitorCheckChannelSubscriptions(t *testing.T) { setupStore: func(store *mocksStore.Store) { store.On("ListChannelLinks").Return([]storemodels.ChannelLink{channelLink}, nil).Times(1) store.On("ListChannelSubscriptions").Return([]*storemodels.ChannelSubscription{}, nil).Times(1) - store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}).Return(nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, }, { @@ -234,7 +239,9 @@ func TestMonitorCheckChannelSubscriptions(t *testing.T) { store.On("ListChannelLinks").Return([]storemodels.ChannelLink{channelLink}, nil).Times(1) store.On("ListChannelSubscriptions").Return([]*storemodels.ChannelSubscription{{SubscriptionID: "test", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: time.Now().Add(3 * time.Minute)}}, nil).Times(1) store.On("DeleteSubscription", "test").Return(nil).Times(1) - store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}).Return(nil).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(nil).Times(1) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, }, { @@ -456,7 +463,7 @@ func TestMonitorRecreateGlobalSubscription(t *testing.T) { } } -func TestMonitorRecreateChannelSubscription(t *testing.T) { +func TestRecreateChannelSubscription(t *testing.T) { newExpiresOn := time.Now().Add(100 * time.Minute) for _, testCase := range []struct { description string @@ -502,6 +509,26 @@ func TestMonitorRecreateChannelSubscription(t *testing.T) { }, setupStore: func(store *mocksStore.Store) {}, }, + { + description: "Unable to begin database transaction", + subscriptionID: "test-id", + teamID: "team-id", + channelID: "channel-id", + secret: "webhook-secret", + expectsError: true, + setupClient: func(client *mocksClient.Client) { + client.On("DeleteSubscription", "test-id").Return(nil).Times(1) + client.On("SubscribeToChannel", "team-id", "channel-id", "base-url", "webhook-secret").Return(&msteams.Subscription{ID: "new-id", ExpiresOn: newExpiresOn}, nil).Times(1) + }, + setupAPI: func(mockAPI *plugintest.API) { + mockAPI.On("LogDebug", "Unable to delete old channel subscription from DB", "subscriptionID", "test-id", "error", "error in deleting subscription from store").Return() + mockAPI.On("LogWarn", "Unable to begin database transaction", "error", "unable to begin database transaction").Return().Times(1) + }, + setupStore: func(store *mocksStore.Store) { + store.On("DeleteSubscription", "test-id").Return(errors.New("error in deleting subscription from store")) + store.On("BeginTx").Return(&sql.Tx{}, errors.New("unable to begin database transaction")).Times(1) + }, + }, { description: "Failed to save the channel subscription in the database", subscriptionID: "test-id", @@ -519,7 +546,53 @@ func TestMonitorRecreateChannelSubscription(t *testing.T) { }, setupStore: func(store *mocksStore.Store) { store.On("DeleteSubscription", "test-id").Return(errors.New("error in deleting subscription from store")) - store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}).Return(errors.New("failed to save the channel subscription in the database")).Times(1) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(errors.New("failed to save the channel subscription in the database")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(nil).Times(1) + }, + }, + { + description: "Failed to save the channel subscription in the database and rollback database transaction", + subscriptionID: "test-id", + teamID: "team-id", + channelID: "channel-id", + secret: "webhook-secret", + expectsError: true, + setupClient: func(client *mocksClient.Client) { + client.On("DeleteSubscription", "test-id").Return(nil).Times(1) + client.On("SubscribeToChannel", "team-id", "channel-id", "base-url", "webhook-secret").Return(&msteams.Subscription{ID: "new-id", ExpiresOn: newExpiresOn}, nil).Times(1) + }, + setupAPI: func(mockAPI *plugintest.API) { + mockAPI.On("LogDebug", "Unable to delete old channel subscription from DB", "subscriptionID", "test-id", "error", "error in deleting subscription from store").Return() + mockAPI.On("LogError", "Unable to store new subscription in DB", "subscriptionID", "new-id", "error", "failed to save the channel subscription in the database").Return().Times(1) + mockAPI.On("LogWarn", "Unable to rollback database transaction", "error", "unable to rollback database transaction").Return(nil).Times(1) + }, + setupStore: func(store *mocksStore.Store) { + store.On("DeleteSubscription", "test-id").Return(errors.New("error in deleting subscription from store")) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(errors.New("failed to save the channel subscription in the database")).Times(1) + store.On("RollbackTx", &sql.Tx{}).Return(errors.New("unable to rollback database transaction")).Times(1) + }, + }, + { + description: "Unable to commit database transaction", + subscriptionID: "test-id", + teamID: "team-id", + channelID: "channel-id", + secret: "webhook-secret", + expectsError: false, + setupClient: func(client *mocksClient.Client) { + client.On("DeleteSubscription", "test-id").Return(nil).Times(1) + client.On("SubscribeToChannel", "team-id", "channel-id", "base-url", "webhook-secret").Return(&msteams.Subscription{ID: "new-id", ExpiresOn: newExpiresOn}, nil).Times(1) + }, + setupAPI: func(mockAPI *plugintest.API) { + mockAPI.On("LogWarn", "Unable to commit database transaction", "error", "unable to commit database transaction").Return(nil).Times(1) + }, + setupStore: func(store *mocksStore.Store) { + store.On("DeleteSubscription", "test-id").Return(nil) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(nil) + store.On("CommitTx", &sql.Tx{}).Return(errors.New("unable to commit database transaction")).Times(1) }, }, { @@ -536,7 +609,9 @@ func TestMonitorRecreateChannelSubscription(t *testing.T) { setupAPI: func(mockAPI *plugintest.API) {}, setupStore: func(store *mocksStore.Store) { store.On("DeleteSubscription", "test-id").Return(nil) - store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}).Return(nil) + store.On("BeginTx").Return(&sql.Tx{}, nil).Times(1) + store.On("SaveChannelSubscription", storemodels.ChannelSubscription{SubscriptionID: "new-id", TeamID: "team-id", ChannelID: "channel-id", Secret: "webhook-secret", ExpiresOn: newExpiresOn}, &sql.Tx{}).Return(nil) + store.On("CommitTx", &sql.Tx{}).Return(nil).Times(1) }, }, } { diff --git a/server/msteams/client.go b/server/msteams/client.go index dc2061b83..40e7b1c61 100644 --- a/server/msteams/client.go +++ b/server/msteams/client.go @@ -737,6 +737,10 @@ func (tc *ClientImpl) UpdateMessage(teamID, channelID, parentID, msgID, message } } + if updateMessageRequest == nil { + return nil, errors.New("received nil updateMessageRequest from MS Graph") + } + var getMessageRequest *abstractions.RequestInformation if parentID != "" { getMessageRequest, err = tc.client.Teams().ByTeamId(teamID).Channels().ByChannelId(channelID).Messages().ByChatMessageId(parentID).Replies().ByChatMessageId1(msgID).ToGetRequestInformation(tc.ctx, nil) @@ -750,6 +754,10 @@ func (tc *ClientImpl) UpdateMessage(teamID, channelID, parentID, msgID, message } } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) updateMessageRequestItem, err := batchRequest.AddBatchRequestStep(*updateMessageRequest) if err != nil { @@ -797,11 +805,19 @@ func (tc *ClientImpl) UpdateChatMessage(chatID, msgID, message string, mentions return nil, NormalizeGraphAPIError(err) } + if updateMessageRequest == nil { + return nil, errors.New("received nil updateMessageRequest from MS Graph") + } + getMessageRequest, err := tc.client.Chats().ByChatId(chatID).Messages().ByChatMessageId(msgID).ToGetRequestInformation(tc.ctx, nil) if err != nil { return nil, NormalizeGraphAPIError(err) } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) updateMessageRequestItem, err := batchRequest.AddBatchRequestStep(*updateMessageRequest) if err != nil { @@ -1565,11 +1581,19 @@ func (tc *ClientImpl) SetChatReaction(chatID, messageID, userID, emoji string) ( return nil, NormalizeGraphAPIError(err) } + if setReactionRequest == nil { + return nil, errors.New("received nil setReactionRequest from MS Graph") + } + getMessageRequest, err := tc.client.Chats().ByChatId(chatID).Messages().ByChatMessageId(messageID).ToGetRequestInformation(tc.ctx, nil) if err != nil { return nil, NormalizeGraphAPIError(err) } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) setReactionRequestItem, err := batchRequest.AddBatchRequestStep(*setReactionRequest) if err != nil { @@ -1614,6 +1638,10 @@ func (tc *ClientImpl) SetReaction(teamID, channelID, parentID, messageID, userID } } + if setReactionRequest == nil { + return nil, errors.New("received nil setReactionRequest from MS Graph") + } + var getMessageRequest *abstractions.RequestInformation if parentID != "" { getMessageRequest, err = tc.client.Teams().ByTeamId(teamID).Channels().ByChannelId(channelID).Messages().ByChatMessageId(parentID).Replies().ByChatMessageId1(messageID).ToGetRequestInformation(tc.ctx, nil) @@ -1627,6 +1655,10 @@ func (tc *ClientImpl) SetReaction(teamID, channelID, parentID, messageID, userID } } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) setReactionRequestItem, err := batchRequest.AddBatchRequestStep(*setReactionRequest) if err != nil { @@ -1658,11 +1690,19 @@ func (tc *ClientImpl) UnsetChatReaction(chatID, messageID, userID, emoji string) return nil, err } + if unsetReactionRequest == nil { + return nil, errors.New("received nil unsetReactionRequest from MS Graph") + } + getMessageRequest, err := tc.client.Chats().ByChatId(chatID).Messages().ByChatMessageId(messageID).ToGetRequestInformation(tc.ctx, nil) if err != nil { return nil, NormalizeGraphAPIError(err) } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) unsetReactionRequestItem, err := batchRequest.AddBatchRequestStep(*unsetReactionRequest) if err != nil { @@ -1707,6 +1747,10 @@ func (tc *ClientImpl) UnsetReaction(teamID, channelID, parentID, messageID, user } } + if unsetReactionRequest == nil { + return nil, errors.New("received nil unsetReactionRequest from MS Graph") + } + var getMessageRequest *abstractions.RequestInformation if parentID != "" { getMessageRequest, err = tc.client.Teams().ByTeamId(teamID).Channels().ByChannelId(channelID).Messages().ByChatMessageId(parentID).Replies().ByChatMessageId1(messageID).ToGetRequestInformation(tc.ctx, nil) @@ -1720,6 +1764,10 @@ func (tc *ClientImpl) UnsetReaction(teamID, channelID, parentID, messageID, user } } + if getMessageRequest == nil { + return nil, errors.New("received nil getMessageRequest from MS Graph") + } + batchRequest := msgraphcore.NewBatchRequest(tc.client.GetAdapter()) unsetReactionRequestItem, err := batchRequest.AddBatchRequestStep(*unsetReactionRequest) if err != nil { @@ -1871,6 +1919,14 @@ func (tc *ClientImpl) SendBatchRequestAndGetMessage(batchRequest msgraphcore.Bat return nil, NormalizeGraphAPIError(err) } + if resp == nil { + return nil, errors.New("received nil response from MS Graph for the message") + } + + if resp.GetLastModifiedDateTime() == nil { + return nil, errors.New("received nil last modified date time from MS Graph for the message") + } + return &Message{LastUpdateAt: *resp.GetLastModifiedDateTime()}, nil } diff --git a/server/plugin_test.go b/server/plugin_test.go index 47cd21ca8..3e3f04135 100644 --- a/server/plugin_test.go +++ b/server/plugin_test.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "math" "net/http" "os" @@ -127,7 +128,7 @@ func TestMessageHasBeenPostedNewMessage(t *testing.T) { MSTeamsID: "new-message-id", MSTeamsChannel: "ms-channel-id", MSTeamsLastUpdateAt: now, - }).Return(nil).Times(1) + }, (*sql.Tx)(nil)).Return(nil).Times(1) clientMock := plugin.clientBuilderWithToken("", "", "", "", nil, nil) clientMock.(*mocks.Client).On("SendMessageWithAttachments", "ms-team-id", "ms-channel-id", "", "

message

\n", []*msteams.Attachment(nil), []models.ChatMessageMentionable{}).Return(&msteams.Message{ID: "new-message-id", LastUpdateAt: now}, nil) diff --git a/server/store/mocks/Store.go b/server/store/mocks/Store.go index 1285b1579..66551d98f 100644 --- a/server/store/mocks/Store.go +++ b/server/store/mocks/Store.go @@ -6,6 +6,8 @@ import ( mock "github.com/stretchr/testify/mock" oauth2 "golang.org/x/oauth2" + sql "database/sql" + store "github.com/mattermost/mattermost-plugin-msteams-sync/server/store" storemodels "github.com/mattermost/mattermost-plugin-msteams-sync/server/store/storemodels" @@ -18,6 +20,29 @@ type Store struct { mock.Mock } +// BeginTx provides a mock function with given fields: +func (_m *Store) BeginTx() (*sql.Tx, error) { + ret := _m.Called() + + var r0 *sql.Tx + if rf, ok := ret.Get(0).(func() *sql.Tx); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.Tx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CheckEnabledTeamByTeamID provides a mock function with given fields: teamID func (_m *Store) CheckEnabledTeamByTeamID(teamID string) bool { ret := _m.Called(teamID) @@ -32,6 +57,20 @@ func (_m *Store) CheckEnabledTeamByTeamID(teamID string) bool { return r0 } +// CommitTx provides a mock function with given fields: tx +func (_m *Store) CommitTx(tx *sql.Tx) error { + ret := _m.Called(tx) + + var r0 error + if rf, ok := ret.Get(0).(func(*sql.Tx) error); ok { + r0 = rf(tx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // CompareAndSetJobStatus provides a mock function with given fields: jobName, oldStatus, newStatus func (_m *Store) CompareAndSetJobStatus(jobName string, oldStatus bool, newStatus bool) (bool, error) { ret := _m.Called(jobName, oldStatus, newStatus) @@ -464,13 +503,13 @@ func (_m *Store) Init() error { return r0 } -// LinkPosts provides a mock function with given fields: postInfo -func (_m *Store) LinkPosts(postInfo storemodels.PostInfo) error { - ret := _m.Called(postInfo) +// LinkPosts provides a mock function with given fields: postInfo, tx +func (_m *Store) LinkPosts(postInfo storemodels.PostInfo, tx *sql.Tx) error { + ret := _m.Called(postInfo, tx) var r0 error - if rf, ok := ret.Get(0).(func(storemodels.PostInfo) error); ok { - r0 = rf(postInfo) + if rf, ok := ret.Get(0).(func(storemodels.PostInfo, *sql.Tx) error); ok { + r0 = rf(postInfo, tx) } else { r0 = ret.Error(0) } @@ -639,6 +678,34 @@ func (_m *Store) ListGlobalSubscriptionsToRefresh() ([]*storemodels.GlobalSubscr return r0, r1 } +// LockPostByMMPostID provides a mock function with given fields: tx, messageID +func (_m *Store) LockPostByMMPostID(tx *sql.Tx, messageID string) error { + ret := _m.Called(tx, messageID) + + var r0 error + if rf, ok := ret.Get(0).(func(*sql.Tx, string) error); ok { + r0 = rf(tx, messageID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LockPostByMSTeamsPostID provides a mock function with given fields: tx, messageID +func (_m *Store) LockPostByMSTeamsPostID(tx *sql.Tx, messageID string) error { + ret := _m.Called(tx, messageID) + + var r0 error + if rf, ok := ret.Get(0).(func(*sql.Tx, string) error); ok { + r0 = rf(tx, messageID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // MattermostToTeamsUserID provides a mock function with given fields: userID func (_m *Store) MattermostToTeamsUserID(userID string) (string, error) { ret := _m.Called(userID) @@ -674,13 +741,27 @@ func (_m *Store) RecoverPost(postID string) error { return r0 } -// SaveChannelSubscription provides a mock function with given fields: _a0 -func (_m *Store) SaveChannelSubscription(_a0 storemodels.ChannelSubscription) error { - ret := _m.Called(_a0) +// RollbackTx provides a mock function with given fields: tx +func (_m *Store) RollbackTx(tx *sql.Tx) error { + ret := _m.Called(tx) var r0 error - if rf, ok := ret.Get(0).(func(storemodels.ChannelSubscription) error); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(*sql.Tx) error); ok { + r0 = rf(tx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SaveChannelSubscription provides a mock function with given fields: _a0, _a1 +func (_m *Store) SaveChannelSubscription(_a0 storemodels.ChannelSubscription, _a1 *sql.Tx) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(storemodels.ChannelSubscription, *sql.Tx) error); ok { + r0 = rf(_a0, _a1) } else { r0 = ret.Error(0) } @@ -744,13 +825,13 @@ func (_m *Store) SetJobStatus(jobName string, status bool) error { return r0 } -// SetPostLastUpdateAtByMSTeamsID provides a mock function with given fields: postID, lastUpdateAt -func (_m *Store) SetPostLastUpdateAtByMSTeamsID(postID string, lastUpdateAt time.Time) error { - ret := _m.Called(postID, lastUpdateAt) +// SetPostLastUpdateAtByMSTeamsID provides a mock function with given fields: postID, lastUpdateAt, tx +func (_m *Store) SetPostLastUpdateAtByMSTeamsID(postID string, lastUpdateAt time.Time, tx *sql.Tx) error { + ret := _m.Called(postID, lastUpdateAt, tx) var r0 error - if rf, ok := ret.Get(0).(func(string, time.Time) error); ok { - r0 = rf(postID, lastUpdateAt) + if rf, ok := ret.Get(0).(func(string, time.Time, *sql.Tx) error); ok { + r0 = rf(postID, lastUpdateAt, tx) } else { r0 = ret.Error(0) } @@ -758,13 +839,13 @@ func (_m *Store) SetPostLastUpdateAtByMSTeamsID(postID string, lastUpdateAt time return r0 } -// SetPostLastUpdateAtByMattermostID provides a mock function with given fields: postID, lastUpdateAt -func (_m *Store) SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time) error { - ret := _m.Called(postID, lastUpdateAt) +// SetPostLastUpdateAtByMattermostID provides a mock function with given fields: postID, lastUpdateAt, tx +func (_m *Store) SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time, tx *sql.Tx) error { + ret := _m.Called(postID, lastUpdateAt, tx) var r0 error - if rf, ok := ret.Get(0).(func(string, time.Time) error); ok { - r0 = rf(postID, lastUpdateAt) + if rf, ok := ret.Get(0).(func(string, time.Time, *sql.Tx) error); ok { + r0 = rf(postID, lastUpdateAt, tx) } else { r0 = ret.Error(0) } diff --git a/server/store/store.go b/server/store/store.go index 77775897c..0a535e3e8 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -49,9 +49,9 @@ type Store interface { StoreChannelLink(link *storemodels.ChannelLink) error GetPostInfoByMSTeamsID(chatID string, postID string) (*storemodels.PostInfo, error) GetPostInfoByMattermostID(postID string) (*storemodels.PostInfo, error) - LinkPosts(postInfo storemodels.PostInfo) error - SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time) error - SetPostLastUpdateAtByMSTeamsID(postID string, lastUpdateAt time.Time) error + LinkPosts(postInfo storemodels.PostInfo, tx *sql.Tx) error + SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time, tx *sql.Tx) error + SetPostLastUpdateAtByMSTeamsID(postID string, lastUpdateAt time.Time, tx *sql.Tx) error GetTokenForMattermostUser(userID string) (*oauth2.Token, error) GetTokenForMSTeamsUser(userID string) (*oauth2.Token, error) SetUserInfo(userID string, msTeamsUserID string, token *oauth2.Token) error @@ -66,7 +66,7 @@ type Store interface { ListChannelSubscriptionsToRefresh() ([]*storemodels.ChannelSubscription, error) SaveGlobalSubscription(storemodels.GlobalSubscription) error SaveChatSubscription(storemodels.ChatSubscription) error - SaveChannelSubscription(storemodels.ChannelSubscription) error + SaveChannelSubscription(storemodels.ChannelSubscription, *sql.Tx) error UpdateSubscriptionExpiresOn(subscriptionID string, expiresOn time.Time) error DeleteSubscription(subscriptionID string) error GetChannelSubscription(subscriptionID string) (*storemodels.ChannelSubscription, error) @@ -84,6 +84,11 @@ type Store interface { CompareAndSetJobStatus(jobName string, oldStatus, newStatus bool) (bool, error) GetStats() (*Stats, error) GetConnectedUsers(page, perPage int) ([]*storemodels.ConnectedUser, error) + LockPostByMSTeamsPostID(tx *sql.Tx, messageID string) error + LockPostByMMPostID(tx *sql.Tx, messageID string) error + BeginTx() (*sql.Tx, error) + RollbackTx(tx *sql.Tx) error + CommitTx(tx *sql.Tx) error } type SQLStore struct { @@ -403,38 +408,77 @@ func (s *SQLStore) MattermostToTeamsUserID(userID string) (string, error) { } func (s *SQLStore) GetPostInfoByMSTeamsID(chatID string, postID string) (*storemodels.PostInfo, error) { - query := s.getQueryBuilder().Select("mmPostID, msTeamsLastUpdateAt").From("msteamssync_posts").Where(sq.Eq{"msTeamsPostID": postID, "msTeamsChannelID": chatID}) + tx, err := s.BeginTx() + if err != nil { + return nil, err + } + + var txErr error + defer func() { + if txErr != nil { + if err := s.RollbackTx(tx); err != nil { + s.api.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := s.CommitTx(tx); err != nil { + s.api.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + query := s.getQueryBuilder().Select("mmPostID, msTeamsLastUpdateAt").From("msteamssync_posts").Where(sq.Eq{"msTeamsPostID": postID, "msTeamsChannelID": chatID}).Suffix("FOR UPDATE").RunWith(tx) row := query.QueryRow() var lastUpdateAt int64 postInfo := storemodels.PostInfo{ MSTeamsID: postID, MSTeamsChannel: chatID, } - err := row.Scan(&postInfo.MattermostID, &lastUpdateAt) - if err != nil { - return nil, err + + if txErr = row.Scan(&postInfo.MattermostID, &lastUpdateAt); txErr != nil { + return nil, txErr } postInfo.MSTeamsLastUpdateAt = time.UnixMicro(lastUpdateAt) return &postInfo, nil } func (s *SQLStore) GetPostInfoByMattermostID(postID string) (*storemodels.PostInfo, error) { - query := s.getQueryBuilder().Select("msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").From("msteamssync_posts").Where(sq.Eq{"mmPostID": postID}) + tx, err := s.BeginTx() + if err != nil { + return nil, err + } + + var txErr error + defer func() { + if txErr != nil { + if err := s.RollbackTx(tx); err != nil { + s.api.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := s.CommitTx(tx); err != nil { + s.api.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + query := s.getQueryBuilder().Select("msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").From("msteamssync_posts").Where(sq.Eq{"mmPostID": postID}).Suffix("FOR UPDATE").RunWith(tx) row := query.QueryRow() var lastUpdateAt int64 postInfo := storemodels.PostInfo{ MattermostID: postID, } - err := row.Scan(&postInfo.MSTeamsID, &postInfo.MSTeamsChannel, &lastUpdateAt) - if err != nil { - return nil, err + + if txErr = row.Scan(&postInfo.MSTeamsID, &postInfo.MSTeamsChannel, &lastUpdateAt); txErr != nil { + return nil, txErr } + postInfo.MSTeamsLastUpdateAt = time.UnixMicro(lastUpdateAt) return &postInfo, nil } -func (s *SQLStore) SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time) error { - query := s.getQueryBuilder().Update("msteamssync_posts").Set("msTeamsLastUpdateAt", lastUpdateAt.UnixMicro()).Where(sq.Eq{"mmPostID": postID}) +func (s *SQLStore) SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt time.Time, tx *sql.Tx) error { + query := s.getQueryBuilder().Update("msteamssync_posts").Set("msTeamsLastUpdateAt", lastUpdateAt.UnixMicro()).Where(sq.Eq{"mmPostID": postID}).RunWith(tx) if _, err := query.Exec(); err != nil { return err } @@ -442,8 +486,8 @@ func (s *SQLStore) SetPostLastUpdateAtByMattermostID(postID string, lastUpdateAt return nil } -func (s *SQLStore) SetPostLastUpdateAtByMSTeamsID(msTeamsPostID string, lastUpdateAt time.Time) error { - query := s.getQueryBuilder().Update("msteamssync_posts").Set("msTeamsLastUpdateAt", lastUpdateAt.UnixMicro()).Where(sq.Eq{"msTeamsPostID": msTeamsPostID}) +func (s *SQLStore) SetPostLastUpdateAtByMSTeamsID(msTeamsPostID string, lastUpdateAt time.Time, tx *sql.Tx) error { + query := s.getQueryBuilder().Update("msteamssync_posts").Set("msTeamsLastUpdateAt", lastUpdateAt.UnixMicro()).Where(sq.Eq{"msTeamsPostID": msTeamsPostID}).RunWith(tx) if _, err := query.Exec(); err != nil { return err } @@ -451,24 +495,38 @@ func (s *SQLStore) SetPostLastUpdateAtByMSTeamsID(msTeamsPostID string, lastUpda return nil } -func (s *SQLStore) LinkPosts(postInfo storemodels.PostInfo) error { +func (s *SQLStore) LinkPosts(postInfo storemodels.PostInfo, tx *sql.Tx) error { if s.driverName == "postgres" { - if _, err := s.getQueryBuilder().Insert("msteamssync_posts").Columns("mmPostID, msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").Values( + query := s.getQueryBuilder().Insert("msteamssync_posts").Columns("mmPostID, msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").Values( postInfo.MattermostID, postInfo.MSTeamsID, postInfo.MSTeamsChannel, postInfo.MSTeamsLastUpdateAt.UnixMicro(), - ).Suffix("ON CONFLICT (mmPostID) DO UPDATE SET msTeamsPostID = EXCLUDED.msTeamsPostID, msTeamsChannelID = EXCLUDED.msTeamsChannelID, msTeamsLastUpdateAt = EXCLUDED.msTeamsLastUpdateAt").Exec(); err != nil { - return err + ).Suffix("ON CONFLICT (mmPostID) DO UPDATE SET msTeamsPostID = EXCLUDED.msTeamsPostID, msTeamsChannelID = EXCLUDED.msTeamsChannelID, msTeamsLastUpdateAt = EXCLUDED.msTeamsLastUpdateAt") + if tx != nil { + if _, err := query.RunWith(tx).Exec(); err != nil { + return err + } + } else { + if _, err := query.Exec(); err != nil { + return err + } } } else { - if _, err := s.getQueryBuilder().Replace("msteamssync_posts").Columns("mmPostID, msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").Values( + query := s.getQueryBuilder().Replace("msteamssync_posts").Columns("mmPostID, msTeamsPostID, msTeamsChannelID, msTeamsLastUpdateAt").Values( postInfo.MattermostID, postInfo.MSTeamsID, postInfo.MSTeamsChannel, postInfo.MSTeamsLastUpdateAt.UnixMicro(), - ).Exec(); err != nil { - return err + ) + if tx != nil { + if _, err := query.RunWith(tx).Exec(); err != nil { + return err + } + } else { + if _, err := query.Exec(); err != nil { + return err + } } } return nil @@ -706,12 +764,12 @@ func (s *SQLStore) SaveChatSubscription(subscription storemodels.ChatSubscriptio return nil } -func (s *SQLStore) SaveChannelSubscription(subscription storemodels.ChannelSubscription) error { - if _, err := s.getQueryBuilder().Delete("msteamssync_subscriptions").Where(sq.Eq{"msTeamsTeamID": subscription.TeamID, "msTeamsChannelID": subscription.ChannelID}).Exec(); err != nil { +func (s *SQLStore) SaveChannelSubscription(subscription storemodels.ChannelSubscription, tx *sql.Tx) error { + if _, err := s.getQueryBuilder().Delete("msteamssync_subscriptions").Where(sq.Eq{"msTeamsTeamID": subscription.TeamID, "msTeamsChannelID": subscription.ChannelID}).RunWith(tx).Exec(); err != nil { return err } - if _, err := s.getQueryBuilder().Insert("msteamssync_subscriptions").Columns("subscriptionID, msTeamsTeamID, msTeamsChannelID, type, secret, expiresOn").Values(subscription.SubscriptionID, subscription.TeamID, subscription.ChannelID, subscriptionTypeChannel, subscription.Secret, subscription.ExpiresOn.UnixMicro()).Exec(); err != nil { + if _, err := s.getQueryBuilder().Insert("msteamssync_subscriptions").Columns("subscriptionID, msTeamsTeamID, msTeamsChannelID, type, secret, expiresOn").Values(subscription.SubscriptionID, subscription.TeamID, subscription.ChannelID, subscriptionTypeChannel, subscription.Secret, subscription.ExpiresOn.UnixMicro()).RunWith(tx).Exec(); err != nil { return err } return nil @@ -734,11 +792,30 @@ func (s *SQLStore) DeleteSubscription(subscriptionID string) error { } func (s *SQLStore) GetChannelSubscription(subscriptionID string) (*storemodels.ChannelSubscription, error) { - row := s.getQueryBuilder().Select("subscriptionID, msTeamsChannelID, msTeamsTeamID, secret, expiresOn").From("msteamssync_subscriptions").Where(sq.Eq{"subscriptionID": subscriptionID, "type": subscriptionTypeChannel}).QueryRow() + tx, err := s.BeginTx() + if err != nil { + return nil, err + } + + var txErr error + defer func() { + if txErr != nil { + if err := s.RollbackTx(tx); err != nil { + s.api.LogWarn("Unable to rollback database transaction", "error", err.Error()) + } + return + } + + if err := s.CommitTx(tx); err != nil { + s.api.LogWarn("Unable to commit database transaction", "error", err.Error()) + } + }() + + row := s.getQueryBuilder().Select("subscriptionID, msTeamsChannelID, msTeamsTeamID, secret, expiresOn").From("msteamssync_subscriptions").Where(sq.Eq{"subscriptionID": subscriptionID, "type": subscriptionTypeChannel}).Suffix("FOR UPDATE").QueryRow() var subscription storemodels.ChannelSubscription var expiresOn int64 - if scanErr := row.Scan(&subscription.SubscriptionID, &subscription.ChannelID, &subscription.TeamID, &subscription.Secret, &expiresOn); scanErr != nil { - return nil, scanErr + if txErr = row.Scan(&subscription.SubscriptionID, &subscription.ChannelID, &subscription.TeamID, &subscription.Secret, &expiresOn); txErr != nil { + return nil, txErr } subscription.ExpiresOn = time.UnixMicro(expiresOn) return &subscription, nil @@ -984,6 +1061,36 @@ func (s *SQLStore) GetConnectedUsers(page, perPage int) ([]*storemodels.Connecte return connectedUsers, nil } +func (s *SQLStore) LockPostByMSTeamsPostID(tx *sql.Tx, messageID string) error { + query := s.getQueryBuilder().Select("*").From("msteamssync_posts").Where(sq.Eq{"msTeamsPostID": messageID}).Suffix("FOR UPDATE").RunWith(tx) + if _, err := query.Exec(); err != nil { + return err + } + + return nil +} + +func (s *SQLStore) LockPostByMMPostID(tx *sql.Tx, messageID string) error { + query := s.getQueryBuilder().Select("*").From("msteamssync_posts").Where(sq.Eq{"mmPostID": messageID}).Suffix("FOR UPDATE").RunWith(tx) + if _, err := query.Exec(); err != nil { + return err + } + + return nil +} + +func (s *SQLStore) BeginTx() (*sql.Tx, error) { + return s.db.Begin() +} + +func (s *SQLStore) RollbackTx(tx *sql.Tx) error { + return tx.Rollback() +} + +func (s *SQLStore) CommitTx(tx *sql.Tx) error { + return tx.Commit() +} + func hashKey(prefix, hashableKey string) string { if hashableKey == "" { return prefix diff --git a/server/store/store_test.go b/server/store/store_test.go index 78544f8e8..40bebebbe 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -511,7 +511,7 @@ func testLinkPostsAndGetPostInfoByMSTeamsID(t *testing.T, store *SQLStore, _ *pl MSTeamsLastUpdateAt: time.UnixMicro(int64(100)), } - storeErr := store.LinkPosts(mockPostInfo) + storeErr := store.LinkPosts(mockPostInfo, nil) assert.Nil(storeErr) resp, getErr := store.GetPostInfoByMSTeamsID("mockMSTeamsChannel-1", "mockMSTeamsID-1") @@ -537,7 +537,7 @@ func testLinkPostsAndGetPostInfoByMattermostID(t *testing.T, store *SQLStore, _ MSTeamsLastUpdateAt: time.UnixMicro(int64(100)), } - storeErr := store.LinkPosts(mockPostInfo) + storeErr := store.LinkPosts(mockPostInfo, nil) assert.Nil(storeErr) resp, getErr := store.GetPostInfoByMattermostID("mockMattermostID-2") @@ -688,7 +688,7 @@ func testListGlobalSubscriptionsToCheck(t *testing.T, store *SQLStore, _ *plugin }) t.Run("no-near-to-expire-subscriptions", func(t *testing.T) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test") }() @@ -698,7 +698,7 @@ func testListGlobalSubscriptionsToCheck(t *testing.T, store *SQLStore, _ *plugin }) t.Run("almost-expired", func(t *testing.T) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(2 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(2*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() @@ -709,7 +709,7 @@ func testListGlobalSubscriptionsToCheck(t *testing.T, store *SQLStore, _ *plugin }) t.Run("expired-subscription", func(t *testing.T) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(-100 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(-100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() @@ -728,7 +728,7 @@ func testListChatSubscriptionsToCheck(t *testing.T, store *SQLStore, _ *pluginte }) t.Run("no-near-to-expire-subscriptions", func(t *testing.T) { - err := store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test", UserID: "user-id", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err := store.SaveChatSubscription(testutils.GetChatSubscription("test", "user-id", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test") }() @@ -738,22 +738,22 @@ func testListChatSubscriptionsToCheck(t *testing.T, store *SQLStore, _ *pluginte }) t.Run("multiple-subscriptions-with-different-expiry-dates", func(t *testing.T) { - err := store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test1", UserID: "user-id-1", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err := store.SaveChatSubscription(testutils.GetChatSubscription("test1", "user-id-1", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-id-2", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-id-2", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-id-3", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-id-3", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test4", UserID: "user-id-4", Secret: "secret", ExpiresOn: time.Now().Add(2 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test4", "user-id-4", time.Now().Add(2*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test5", UserID: "user-id-5", Secret: "secret", ExpiresOn: time.Now().Add(2 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test5", "user-id-5", time.Now().Add(2*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test5") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test6", UserID: "user-id-6", Secret: "secret", ExpiresOn: time.Now().Add(-100 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test6", "user-id-6", time.Now().Add(-100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test6") }() @@ -781,8 +781,21 @@ func testListChannelSubscriptionsToRefresh(t *testing.T, store *SQLStore, _ *plu }) t.Run("no-near-to-expire-subscriptions", func(t *testing.T) { - err := store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test", TeamID: "team-id", ChannelID: "channel-id", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + subscription := testutils.GetChannelSubscription("test", "team-id", "channel-id", time.Now().Add(100*time.Minute)) + go func() { + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(subscription, tx) + require.NoError(t, err) + + err = store.CommitTx(tx) + require.NoError(t, err) + }() + + _, err := store.GetChannelSubscription("test") require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test") }() subscriptions, err := store.ListChannelSubscriptionsToRefresh() @@ -791,25 +804,31 @@ func testListChannelSubscriptionsToRefresh(t *testing.T, store *SQLStore, _ *plu }) t.Run("multiple-subscriptions-with-different-expiry-dates", func(t *testing.T) { - err := store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test1", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test1", "team-id", "channel-id-1", time.Now().Add(100*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test2", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test2", "team-id", "channel-id-2", time.Now().Add(100*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test3", TeamID: "team-id", ChannelID: "channel-id-3", Secret: "secret", ExpiresOn: time.Now().Add(100 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test3", "team-id", "channel-id-3", time.Now().Add(100*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-4", Secret: "secret", ExpiresOn: time.Now().Add(2 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-4", time.Now().Add(2*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test5", TeamID: "team-id", ChannelID: "channel-id-5", Secret: "secret", ExpiresOn: time.Now().Add(2 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test5", "team-id", "channel-id-5", time.Now().Add(2*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test5") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test6", TeamID: "team-id", ChannelID: "channel-id-6", Secret: "secret", ExpiresOn: time.Now().Add(-100 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test6", "team-id", "channel-id-6", time.Now().Add(-100*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test6") }() + err = store.CommitTx(tx) + require.NoError(t, err) + subscriptions, err := store.ListChannelSubscriptionsToRefresh() require.NoError(t, err) assert.Len(t, subscriptions, 3) @@ -827,10 +846,10 @@ func testListChannelSubscriptionsToRefresh(t *testing.T, store *SQLStore, _ *plu } func testSaveGlobalSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test2", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() @@ -841,17 +860,17 @@ func testSaveGlobalSubscription(t *testing.T, store *SQLStore, _ *plugintest.API } func testSaveChatSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test1", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveChatSubscription(testutils.GetChatSubscription("test1", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-2", time.Now().Add(100*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test4", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test4", "user-2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() @@ -863,20 +882,26 @@ func testSaveChatSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) } func testSaveChannelSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test1", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test1", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test2", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test2", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test3", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test3", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() + err = store.CommitTx(tx) + require.NoError(t, err) + subscriptions, err := store.ListChannelSubscriptionsToRefresh() require.NoError(t, err) assert.Len(t, subscriptions, 2) @@ -885,10 +910,16 @@ func testSaveChannelSubscription(t *testing.T, store *SQLStore, _ *plugintest.AP } func testUpdateSubscriptionExpiresOn(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test1", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test1", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() + err = store.CommitTx(tx) + require.NoError(t, err) + subscriptions, err := store.ListChannelSubscriptionsToRefresh() require.NoError(t, err) require.Len(t, subscriptions, 1) @@ -909,24 +940,30 @@ func testUpdateSubscriptionExpiresOn(t *testing.T, store *SQLStore, _ *plugintes } func testGetGlobalSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test5", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test5", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test5") }() + err = store.CommitTx(tx) + require.NoError(t, err) + t.Run("not-existing-subscription", func(t *testing.T) { _, err := store.GetGlobalSubscription("not-valid") require.Error(t, err) @@ -945,24 +982,30 @@ func testGetGlobalSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) } func testGetChatSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test5", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test5", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test5") }() + err = store.CommitTx(tx) + require.NoError(t, err) + t.Run("not-existing-subscription", func(t *testing.T) { _, err := store.GetChatSubscription("not-valid") require.Error(t, err) @@ -981,22 +1024,31 @@ func testGetChatSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { } func testGetChannelSubscription(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() + require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test5", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test5", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) require.NoError(t, err) + + err = store.CommitTx(tx) + require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test5") }() t.Run("not-existing-subscription", func(t *testing.T) { @@ -1017,22 +1069,31 @@ func testGetChannelSubscription(t *testing.T, store *SQLStore, _ *plugintest.API } func testGetSubscriptionType(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{SubscriptionID: "test1", Type: "allChats", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test2", UserID: "user-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test2", "user-1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test2") }() - err = store.SaveChatSubscription(storemodels.ChatSubscription{SubscriptionID: "test3", UserID: "user-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + err = store.SaveChatSubscription(testutils.GetChatSubscription("test3", "user-2", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test3") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test4", TeamID: "team-id", ChannelID: "channel-id-1", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + tx, err := store.BeginTx() require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test4", "team-id", "channel-id-1", time.Now().Add(1*time.Minute)), tx) + require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test4") }() - err = store.SaveChannelSubscription(storemodels.ChannelSubscription{SubscriptionID: "test5", TeamID: "team-id", ChannelID: "channel-id-2", Secret: "secret", ExpiresOn: time.Now().Add(1 * time.Minute)}) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test5", "team-id", "channel-id-2", time.Now().Add(1*time.Minute)), tx) + require.NoError(t, err) + + err = store.CommitTx(tx) require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test5") }() t.Run("not-valid-subscription", func(t *testing.T) { @@ -1057,14 +1118,15 @@ func testGetSubscriptionType(t *testing.T, store *SQLStore, _ *plugintest.API) { } func testListChannelSubscriptions(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveChannelSubscription(storemodels.ChannelSubscription{ - SubscriptionID: "test1", - TeamID: "team-id", - ChannelID: "channel-id", - Secret: "secret", - ExpiresOn: time.Now().Add(1 * time.Minute), - }) + tx, err := store.BeginTx() require.NoError(t, err) + + err = store.SaveChannelSubscription(testutils.GetChannelSubscription("test1", "team-id", "channel-id", time.Now().Add(1*time.Minute)), tx) + require.NoError(t, err) + + err = store.CommitTx(tx) + require.NoError(t, err) + defer func() { _ = store.DeleteSubscription("test1") }() subscriptions, err := store.ListChannelSubscriptions() @@ -1073,12 +1135,7 @@ func testListChannelSubscriptions(t *testing.T, store *SQLStore, _ *plugintest.A } func testListGlobalSubscriptions(t *testing.T, store *SQLStore, _ *plugintest.API) { - err := store.SaveGlobalSubscription(storemodels.GlobalSubscription{ - SubscriptionID: "test1", - Secret: "secret", - Type: "allChats", - ExpiresOn: time.Now().Add(1 * time.Minute), - }) + err := store.SaveGlobalSubscription(testutils.GetGlobalSubscription("test1", time.Now().Add(1*time.Minute))) require.NoError(t, err) defer func() { _ = store.DeleteSubscription("test1") }() diff --git a/server/testutils/data.go b/server/testutils/data.go index f6f00bb10..150200964 100644 --- a/server/testutils/data.go +++ b/server/testutils/data.go @@ -177,3 +177,39 @@ func GetMockTime() time.Time { mockTime, _ := time.Parse("Jan 2, 2006 at 3:04pm (MST)", "Jan 2, 2023 at 4:00pm (MST)") return mockTime } + +func GetEphemeralPost(userID, channelID, message string) *model.Post { + return &model.Post{ + UserId: userID, + ChannelId: channelID, + Message: message, + } +} + +func GetGlobalSubscription(subscriptionID string, expiresOn time.Time) storemodels.GlobalSubscription { + return storemodels.GlobalSubscription{ + SubscriptionID: subscriptionID, + Type: "allChats", + Secret: "secret", + ExpiresOn: expiresOn, + } +} + +func GetChannelSubscription(subscriptionID, teamID, channelID string, expiresOn time.Time) storemodels.ChannelSubscription { + return storemodels.ChannelSubscription{ + SubscriptionID: subscriptionID, + TeamID: teamID, + ChannelID: channelID, + Secret: "secret", + ExpiresOn: expiresOn, + } +} + +func GetChatSubscription(subscriptionID, userID string, expiresOn time.Time) storemodels.ChatSubscription { + return storemodels.ChatSubscription{ + SubscriptionID: subscriptionID, + UserID: userID, + Secret: "secret", + ExpiresOn: expiresOn, + } +}