@@ -112,10 +112,11 @@ def __init__(self, hs: "HomeServer"):
112112 # with FederationHandlerRegistry.
113113 hs .get_directory_handler ()
114114
115- self ._federation_ratelimiter = hs .get_federation_ratelimiter ()
116-
117115 self ._server_linearizer = Linearizer ("fed_server" )
118- self ._transaction_linearizer = Linearizer ("fed_txn_handler" )
116+
117+ # origins that we are currently processing a transaction from.
118+ # a dict from origin to txn id.
119+ self ._active_transactions = {} # type: Dict[str, str]
119120
120121 # We cache results for transaction with the same ID
121122 self ._transaction_resp_cache = ResponseCache (
@@ -169,6 +170,33 @@ async def on_incoming_transaction(
169170
170171 logger .debug ("[%s] Got transaction" , transaction_id )
171172
173+ # Reject malformed transactions early: reject if too many PDUs/EDUs
174+ if len (transaction .pdus ) > 50 or ( # type: ignore
175+ hasattr (transaction , "edus" ) and len (transaction .edus ) > 100 # type: ignore
176+ ):
177+ logger .info ("Transaction PDU or EDU count too large. Returning 400" )
178+ return 400 , {}
179+
180+ # we only process one transaction from each origin at a time. We need to do
181+ # this check here, rather than in _on_incoming_transaction_inner so that we
182+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
183+ # arrives again later, we can process it).
184+ current_transaction = self ._active_transactions .get (origin )
185+ if current_transaction and current_transaction != transaction_id :
186+ logger .warning (
187+ "Received another txn %s from %s while still processing %s" ,
188+ transaction_id ,
189+ origin ,
190+ current_transaction ,
191+ )
192+ return 429 , {
193+ "errcode" : Codes .UNKNOWN ,
194+ "error" : "Too many concurrent transactions" ,
195+ }
196+
197+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
198+ # in _on_incoming_transaction_inner.
199+
172200 # We wrap in a ResponseCache so that we de-duplicate retried
173201 # transactions.
174202 return await self ._transaction_resp_cache .wrap (
@@ -182,26 +210,18 @@ async def on_incoming_transaction(
182210 async def _on_incoming_transaction_inner (
183211 self , origin : str , transaction : Transaction , request_time : int
184212 ) -> Tuple [int , Dict [str , Any ]]:
185- # Use a linearizer to ensure that transactions from a remote are
186- # processed in order.
187- with await self ._transaction_linearizer .queue (origin ):
188- # We rate limit here *after* we've queued up the incoming requests,
189- # so that we don't fill up the ratelimiter with blocked requests.
190- #
191- # This is important as the ratelimiter allows N concurrent requests
192- # at a time, and only starts ratelimiting if there are more requests
193- # than that being processed at a time. If we queued up requests in
194- # the linearizer/response cache *after* the ratelimiting then those
195- # queued up requests would count as part of the allowed limit of N
196- # concurrent requests.
197- with self ._federation_ratelimiter .ratelimit (origin ) as d :
198- await d
199-
200- result = await self ._handle_incoming_transaction (
201- origin , transaction , request_time
202- )
213+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
214+ # add an entry to _active_transactions.
215+ assert origin not in self ._active_transactions
216+ self ._active_transactions [origin ] = transaction .transaction_id # type: ignore
203217
204- return result
218+ try :
219+ result = await self ._handle_incoming_transaction (
220+ origin , transaction , request_time
221+ )
222+ return result
223+ finally :
224+ del self ._active_transactions [origin ]
205225
206226 async def _handle_incoming_transaction (
207227 self , origin : str , transaction : Transaction , request_time : int
@@ -227,19 +247,6 @@ async def _handle_incoming_transaction(
227247
228248 logger .debug ("[%s] Transaction is new" , transaction .transaction_id ) # type: ignore
229249
230- # Reject if PDU count > 50 or EDU count > 100
231- if len (transaction .pdus ) > 50 or ( # type: ignore
232- hasattr (transaction , "edus" ) and len (transaction .edus ) > 100 # type: ignore
233- ):
234-
235- logger .info ("Transaction PDU or EDU count too large. Returning 400" )
236-
237- response = {}
238- await self .transaction_actions .set_response (
239- origin , transaction , 400 , response
240- )
241- return 400 , response
242-
243250 # We process PDUs and EDUs in parallel. This is important as we don't
244251 # want to block things like to device messages from reaching clients
245252 # behind the potentially expensive handling of PDUs.
0 commit comments