diff --git a/changelog.d/3-bug-fixes/rabbitmq-acks b/changelog.d/3-bug-fixes/rabbitmq-acks new file mode 100644 index 00000000000..0e968509335 --- /dev/null +++ b/changelog.d/3-bug-fixes/rabbitmq-acks @@ -0,0 +1 @@ +Cannon does not attempt to restore a rabbitmq channel after it disconnects. This fixes a potential issue where a client would be able to ack a message on the wrong channel. diff --git a/integration/test/Test/Events.hs b/integration/test/Test/Events.hs index 795ba3b7779..c15b4bf8803 100644 --- a/integration/test/Test/Events.hs +++ b/integration/test/Test/Events.hs @@ -7,15 +7,19 @@ import API.Galley import API.Gundeck import qualified Control.Concurrent.Timeout as Timeout import Control.Monad.Codensity +import Control.Monad.RWS (asks) import Control.Monad.Trans.Class import Control.Retry import Data.ByteString.Conversion (toByteString') import qualified Data.Text as Text import Data.Timeout +import Network.AMQP.Extended +import Network.RabbitMqAdmin import qualified Network.WebSockets as WS import Notifications import SetupHelpers import Testlib.Prelude hiding (assertNoEvent) +import Testlib.ResourcePool (acquireResources) import UnliftIO hiding (handle) testConsumeEventsOneWebSocket :: (HasCallStack) => App () @@ -38,10 +42,10 @@ testConsumeEventsOneWebSocket = do e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "data.event.payload.0.client.id" `shouldMatch` clientId e %. "data.delivery_tag" - assertNoEvent ws + assertNoEvent_ ws sendAck ws deliveryTag False - assertNoEvent ws + assertNoEvent_ ws handle <- randomHandle putHandle alice handle >>= assertSuccess @@ -80,7 +84,7 @@ testConsumeEventsForDifferentUsers = do e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "data.event.payload.0.client.id" `shouldMatch` clientId e %. "data.delivery_tag" - assertNoEvent ws + assertNoEvent_ ws sendAck ws deliveryTag False testConsumeEventsWhileHavingLegacyClients :: (HasCallStack) => App () @@ -137,7 +141,7 @@ testConsumeEventsAcks = do sendAck ws deliveryTag False runCodensity (createEventsWebSocket alice clientId) $ \ws -> do - assertNoEvent ws + assertNoEvent_ ws testConsumeEventsMultipleAcks :: (HasCallStack) => App () testConsumeEventsMultipleAcks = do @@ -161,7 +165,7 @@ testConsumeEventsMultipleAcks = do sendAck ws deliveryTag True runCodensity (createEventsWebSocket alice clientId) $ \ws -> do - assertNoEvent ws + assertNoEvent_ ws testConsumeEventsAckNewEventWithoutAckingOldOne :: (HasCallStack) => App () testConsumeEventsAckNewEventWithoutAckingOldOne = do @@ -195,7 +199,7 @@ testConsumeEventsAckNewEventWithoutAckingOldOne = do sendAck ws deliveryTagClientAdd False runCodensity (createEventsWebSocket alice clientId) $ \ws -> do - assertNoEvent ws + assertNoEvent_ ws testEventsDeadLettered :: (HasCallStack) => App () testEventsDeadLettered = do @@ -229,7 +233,7 @@ testEventsDeadLettered = do ackEvent ws e -- We've consumed the whole queue. - assertNoEvent ws + assertNoEvent_ ws testTransientEventsDoNotTriggerDeadLetters :: (HasCallStack) => App () testTransientEventsDoNotTriggerDeadLetters = do @@ -257,7 +261,7 @@ testTransientEventsDoNotTriggerDeadLetters = do sendTypingStatus alice selfConvId "started" >>= assertSuccess runCodensity (createEventsWebSocket alice clientId) $ \ws -> do - assertNoEvent ws + assertNoEvent_ ws testTransientEvents :: (HasCallStack) => App () testTransientEvents = do @@ -296,7 +300,7 @@ testTransientEvents = do e %. "data.event.payload.0.user.handle" `shouldMatch` handle ackEvent ws e - assertNoEvent ws + assertNoEvent_ ws testChannelLimit :: (HasCallStack) => App () testChannelLimit = withModifiedBackend @@ -318,16 +322,46 @@ testChannelLimit = withModifiedBackend lowerCodensity $ do for_ clients $ \c -> do ws <- createEventsWebSocket alice c - e <- Codensity $ \k -> assertEvent ws k - lift $ do + lift $ assertEvent ws $ \e -> do e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "data.event.payload.0.client.id" `shouldMatch` c - e %. "data.delivery_tag" -- the first client fails to connect because the server runs out of channels do ws <- createEventsWebSocket alice client0 - lift $ assertNoEvent ws + lift $ assertNoEvent_ ws + +testChannelKilled :: (HasCallStack) => App () +testChannelKilled = lowerCodensity $ do + pool <- lift $ asks (.resourcePool) + [backend] <- acquireResources 1 pool + domain <- startDynamicBackend backend mempty + alice <- lift $ randomUser domain def + [c1, c2] <- + lift + $ replicateM 2 + $ addClient alice def {acapabilities = Just ["consumable-notifications"]} + >>= getJSON 201 + >>= (%. "id") + >>= asString + + ws <- createEventsWebSocket alice c1 + lift $ do + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` c1 + ackEvent ws e + + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` c2 + + recoverAll + (constantDelay 500_000 <> limitRetries 10) + (const (killConnection backend)) + + noEvent <- assertNoEvent ws + noEvent `shouldMatch` WebSocketDied ---------------------------------------------------------------------- -- helpers @@ -428,15 +462,24 @@ assertEvent ws expectations = do addFailureContext ("event:\n" <> pretty) $ expectations e -assertNoEvent :: (HasCallStack) => EventWebSocket -> App () +data NoEvent = NoEvent | WebSocketDied + +instance ToJSON NoEvent where + toJSON NoEvent = toJSON "no-event" + toJSON WebSocketDied = toJSON "web-socket-died" + +assertNoEvent :: (HasCallStack) => EventWebSocket -> App NoEvent assertNoEvent ws = do timeout 1_000_000 (readChan ws.events) >>= \case - Nothing -> pure () - Just (Left _) -> pure () + Nothing -> pure NoEvent + Just (Left _) -> pure WebSocketDied Just (Right e) -> do eventJSON <- prettyJSON e assertFailure $ "Did not expect event: \n" <> eventJSON +assertNoEvent_ :: (HasCallStack) => EventWebSocket -> App () +assertNoEvent_ = void . assertNoEvent + consumeAllEvents :: EventWebSocket -> App () consumeAllEvents ws = do timeout 1_000_000 (readChan ws.events) >>= \case @@ -448,3 +491,25 @@ consumeAllEvents ws = do Just (Right e) -> do ackEvent ws e consumeAllEvents ws + +killConnection :: (HasCallStack) => BackendResource -> App () +killConnection backend = do + rc <- asks (.rabbitMQConfig) + let opts = + RabbitMqAdminOpts + { host = rc.host, + port = 0, + adminPort = fromIntegral rc.adminPort, + vHost = Text.pack backend.berVHost, + tls = Just $ RabbitMqTlsOpts Nothing True + } + servantClient <- liftIO $ mkRabbitMqAdminClientEnv opts + name <- do + connections <- liftIO $ listConnectionsByVHost servantClient opts.vHost + connection <- + assertOne + [ c | c <- connections, c.userProvidedName == Just (Text.pack "pool 0") + ] + pure connection.name + + void $ liftIO $ deleteConnection servantClient name diff --git a/libs/extended/src/Network/RabbitMqAdmin.hs b/libs/extended/src/Network/RabbitMqAdmin.hs index 77d65afc676..acc6bf8c920 100644 --- a/libs/extended/src/Network/RabbitMqAdmin.hs +++ b/libs/extended/src/Network/RabbitMqAdmin.hs @@ -1,7 +1,7 @@ -- | Perhaps this module should be a separate package and published to hackage. module Network.RabbitMqAdmin where -import Data.Aeson +import Data.Aeson as Aeson import Imports import Servant import Servant.Client @@ -33,6 +33,19 @@ data AdminAPI route = AdminAPI :> "queues" :> Capture "vhost" VHost :> Capture "queue" QueueName + :> DeleteNoContent, + listConnectionsByVHost :: + route + :- "api" + :> "vhosts" + :> Capture "vhost" Text + :> "connections" + :> Get '[JSON] [Connection], + deleteConnection :: + route + :- "api" + :> "connections" + :> Capture "name" Text :> DeleteNoContent } deriving (Generic) @@ -45,6 +58,9 @@ data AuthenticatedAPI route = AuthenticatedAPI } deriving (Generic) +jsonOptions :: Aeson.Options +jsonOptions = defaultOptions {fieldLabelModifier = camelTo2 '_'} + data Queue = Queue {name :: Text, vhost :: Text} deriving (Show, Eq, Generic) @@ -52,6 +68,18 @@ instance FromJSON Queue instance ToJSON Queue +data Connection = Connection + { userProvidedName :: Maybe Text, + name :: Text + } + deriving (Eq, Show, Generic) + +instance FromJSON Connection where + parseJSON = genericParseJSON jsonOptions + +instance ToJSON Connection where + toJSON = genericToJSON jsonOptions + adminClient :: BasicAuthData -> AdminAPI (AsClientT ClientM) adminClient ba = fromServant $ clientWithAuth.api ba where diff --git a/services/background-worker/test/Test/Wire/BackendNotificationPusherSpec.hs b/services/background-worker/test/Test/Wire/BackendNotificationPusherSpec.hs index 322ccddd148..7e63fb10f44 100644 --- a/services/background-worker/test/Test/Wire/BackendNotificationPusherSpec.hs +++ b/services/background-worker/test/Test/Wire/BackendNotificationPusherSpec.hs @@ -348,7 +348,9 @@ mockApi :: MockRabbitMqAdmin -> AdminAPI (AsServerT Servant.Handler) mockApi mockAdmin = AdminAPI { listQueuesByVHost = mockListQueuesByVHost mockAdmin, - deleteQueue = mockListDeleteQueue mockAdmin + deleteQueue = mockListDeleteQueue mockAdmin, + listConnectionsByVHost = mockListConnectionsByVHost mockAdmin, + deleteConnection = mockDeleteConnection mockAdmin } mockListQueuesByVHost :: MockRabbitMqAdmin -> Text -> Maybe Text -> Maybe Bool -> Servant.Handler [Queue] @@ -362,6 +364,12 @@ mockListDeleteQueue :: MockRabbitMqAdmin -> Text -> Text -> Servant.Handler NoCo mockListDeleteQueue _ _ _ = do pure NoContent +mockListConnectionsByVHost :: MockRabbitMqAdmin -> Text -> Servant.Handler [Connection] +mockListConnectionsByVHost _ _ = pure [] + +mockDeleteConnection :: MockRabbitMqAdmin -> Text -> Servant.Handler NoContent +mockDeleteConnection _ _ = pure NoContent + mockRabbitMqAdminApp :: MockRabbitMqAdmin -> Application mockRabbitMqAdminApp mockAdmin = genericServe (mockApi mockAdmin) diff --git a/services/cannon/src/Cannon/RabbitMq.hs b/services/cannon/src/Cannon/RabbitMq.hs index d5a228bd410..ea1ef92feb6 100644 --- a/services/cannon/src/Cannon/RabbitMq.hs +++ b/services/cannon/src/Cannon/RabbitMq.hs @@ -25,6 +25,7 @@ import Control.Retry import Data.ByteString.Conversion import Data.List.Extra import Data.Map qualified as Map +import Data.Text qualified as T import Data.Timeout import Imports hiding (threadDelay) import Network.AMQP qualified as Q @@ -59,7 +60,8 @@ data RabbitMqPool key = RabbitMqPool data RabbitMqPoolOptions = RabbitMqPoolOptions { maxConnections :: Int, maxChannels :: Int, - endpoint :: AmqpEndpoint + endpoint :: AmqpEndpoint, + retryEnabled :: Bool } createRabbitMqPool :: (Ord key) => RabbitMqPoolOptions -> Logger -> Codensity IO (RabbitMqPool key) @@ -176,6 +178,11 @@ createConnection pool = mask_ $ do openConnection :: RabbitMqPool key -> IO Q.Connection openConnection pool = do + -- This might not be the correct connection ID that will eventually be + -- assigned to this connection, since there are potential races with other + -- connections being opened at the same time. However, this is only used to + -- name the connection, and we only rely on names for tests, so it is fine. + connId <- readTVarIO pool.nextId (username, password) <- readCredsFromEnv recovering rabbitMqRetryPolicy @@ -199,7 +206,9 @@ openConnection pool = do ], Q.coVHost = pool.opts.endpoint.vHost, Q.coAuth = [Q.plain username password], - Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings + Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings, + -- the name is used by tests to identify pool connections + Q.coName = Just ("pool " <> T.pack (show connId)) } ) @@ -233,11 +242,11 @@ createChannel pool queue key = do (_, Just (Q.ConnectionClosedException {})) -> do Log.info pool.logger $ Log.msg (Log.val "RabbitMQ connection was closed unexpectedly") - pure True + pure pool.opts.retryEnabled _ -> do unless (fromException e == Just AsyncCancelled) $ logException pool.logger "RabbitMQ channel closed" e - pure True + pure pool.opts.retryEnabled putMVar closedVar retry let manageChannel = do @@ -258,7 +267,9 @@ createChannel pool queue key = do putMVar inner chan void $ liftIO $ Q.consumeMsgs chan queue Q.Ack $ \(message, envelope) -> do putMVar msgVar (Just (message, envelope)) - takeMVar closedVar + retry <- takeMVar closedVar + void $ takeMVar inner + pure retry when retry manageChannel diff --git a/services/cannon/src/Cannon/Types.hs b/services/cannon/src/Cannon/Types.hs index 81773ed32a1..146bd5519bd 100644 --- a/services/cannon/src/Cannon/Types.hs +++ b/services/cannon/src/Cannon/Types.hs @@ -114,7 +114,8 @@ mkEnv external o cs l d conns p g t endpoint = do RabbitMqPoolOptions { endpoint = endpoint, maxConnections = o ^. rabbitMqMaxConnections, - maxChannels = o ^. rabbitMqMaxChannels + maxChannels = o ^. rabbitMqMaxChannels, + retryEnabled = False } pool <- createRabbitMqPool poolOpts l let wsEnv =