@@ -16,7 +16,8 @@ module Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.Inbound.V2.State
1616
1717import Control.Concurrent.Class.MonadSTM.Strict
1818import Control.Concurrent.Class.MonadSTM.TSem
19- import Control.Exception (assert )
19+ import Control.Exception (assert , throw )
20+ import Control.Monad (when )
2021import Control.Tracer (Tracer , traceWith )
2122import Data.Foldable qualified as Foldable
2223import Data.Map.Strict (Map , findWithDefault )
@@ -81,6 +82,7 @@ onRequestIdsImpl
8182 let
8283 -- We compute the ids to ack and new state of the FIFO based on the number of ids to ack given by the decision logic
8384 (idsToAck, dpsOutstandingFifo') =
85+ assert (StrictSeq. length dpsOutstandingFifo >= fromIntegral numIdsToAck) $
8486 StrictSeq. splitAt
8587 (fromIntegral numIdsToAck)
8688 dpsOutstandingFifo
@@ -143,6 +145,10 @@ onRequestObjectsImpl
143145 dgsPeerStates' =
144146 Map. adjust
145147 ( \ ps@ DecisionPeerState {dpsObjectsAvailableIds, dpsObjectsInflightIds} ->
148+ assert
149+ ( objectIds `Set.isSubsetOf` dpsObjectsAvailableIds
150+ && Set. null (objectIds `Set.intersection` dpsObjectsInflightIds)
151+ ) $
146152 ps
147153 { dpsObjectsAvailableIds = dpsObjectsAvailableIds \\ objectIds
148154 , dpsObjectsInflightIds = dpsObjectsInflightIds `Set.union` objectIds
@@ -169,15 +175,32 @@ onReceiveIds ::
169175 -- | received `objectId`s
170176 m ()
171177onReceiveIds odTracer decisionTracer globalStateVar peerAddr numIdsInitiallyRequested receivedIds = do
178+ peerState <- atomically $ ((Map. ! peerAddr) . dgsPeerStates) <$> readTVar globalStateVar
179+ checkProtocolErrors peerState numIdsInitiallyRequested receivedIds
172180 globalState' <- atomically $ do
173181 stateTVar
174182 globalStateVar
175183 ( \ globalState ->
176184 let globalState' = onReceiveIdsImpl peerAddr numIdsInitiallyRequested receivedIds globalState
177- in (globalState', globalState')
185+ in (globalState', globalState')
178186 )
179187 traceWith odTracer (TraceObjectDiffusionInboundReceivedIds (length receivedIds))
180188 traceWith decisionTracer (TraceDecisionLogicGlobalStateUpdated " onReceiveIds" globalState')
189+ where
190+ checkProtocolErrors ::
191+ DecisionPeerState objectId object ->
192+ NumObjectIdsReq ->
193+ [objectId ] ->
194+ m ()
195+ checkProtocolErrors DecisionPeerState {dpsObjectsAvailableIds, dpsObjectsInflightIds} nReq ids = do
196+ when (length ids > fromIntegral nReq) $ throw ProtocolErrorObjectIdsNotRequested
197+ let idSet = Set. fromList ids
198+ when (length ids /= Set. size idSet) $ throw ProtocolErrorObjectIdsDuplicate
199+ when
200+ -- TODO also check for IDs in pool
201+ ( (not $ Set. null $ idSet `Set.intersection` dpsObjectsAvailableIds)
202+ || (not $ Set. null $ idSet `Set.intersection` dpsObjectsInflightIds)
203+ ) $ throw ProtocolErrorObjectIdAlreadyKnown
181204
182205onReceiveIdsImpl ::
183206 forall peerAddr object objectId .
@@ -253,13 +276,15 @@ onReceiveObjects ::
253276 ObjectPoolWriter objectId object m ->
254277 ObjectPoolSem m ->
255278 peerAddr ->
279+ -- | requested objects
280+ Set objectId ->
256281 -- | received objects
257282 [object ] ->
258283 m ()
259- onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAddr objectsReceived = do
284+ onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAddr objectsRequestedIds objectsReceived = do
260285 let getId = opwObjectId objectPoolWriter
261286 let objectsReceivedMap = Map. fromList $ (\ obj -> (getId obj, obj)) <$> objectsReceived
262-
287+ checkProtocolErrors objectsRequestedIds objectsReceivedMap
263288 globalState' <- atomically $ do
264289 stateTVar
265290 globalStateVar
@@ -281,6 +306,15 @@ onReceiveObjects odTracer tracer globalStateVar objectPoolWriter poolSem peerAdd
281306 poolSem
282307 peerAddr
283308 objectsReceivedMap
309+ where
310+ checkProtocolErrors ::
311+ Set objectId ->
312+ Map objectId object ->
313+ m ()
314+ checkProtocolErrors requested received' = do
315+ let received = Map. keysSet received'
316+ when (not $ Set. null $ requested \\ received) $ throw ProtocolErrorObjectMissing
317+ when (not $ Set. null $ received \\ requested) $ throw ProtocolErrorObjectNotRequested
284318
285319onReceiveObjectsImpl ::
286320 forall peerAddr object objectId .
@@ -314,7 +348,7 @@ onReceiveObjectsImpl
314348 dgsPeerStates
315349
316350 -- subtract requested from in-flight
317- dpsObjectsInflightIds' =
351+ dpsObjectsInflightIds' = assert (objectsReceivedIds `Set.isSubsetOf` dpsObjectsInflightIds) $
318352 dpsObjectsInflightIds \\ objectsReceivedIds
319353
320354 dpsObjectsOwtPool' = dpsObjectsOwtPool <> objectsReceived
0 commit comments