@@ -138,26 +138,29 @@ def on_receive_pdu(self, origin, pdu, backfilled, state=None,
138
138
if state and auth_chain is not None :
139
139
# If we have any state or auth_chain given to us by the replication
140
140
# layer, then we should handle them (if we haven't before.)
141
+
142
+ event_infos = []
143
+
141
144
for e in itertools .chain (auth_chain , state ):
142
145
if e .event_id in seen_ids :
143
146
continue
144
-
145
147
e .internal_metadata .outlier = True
146
- try :
147
- auth_ids = [e_id for e_id , _ in e .auth_events ]
148
- auth = {
149
- (e .type , e .state_key ): e for e in auth_chain
150
- if e .event_id in auth_ids
151
- }
152
- yield self ._handle_new_event (
153
- origin , e , auth_events = auth
154
- )
155
- seen_ids .add (e .event_id )
156
- except :
157
- logger .exception (
158
- "Failed to handle state event %s" ,
159
- e .event_id ,
160
- )
148
+ auth_ids = [e_id for e_id , _ in e .auth_events ]
149
+ auth = {
150
+ (e .type , e .state_key ): e for e in auth_chain
151
+ if e .event_id in auth_ids
152
+ }
153
+ event_infos .append ({
154
+ "event" : e ,
155
+ "auth_events" : auth ,
156
+ })
157
+ seen_ids .add (e .event_id )
158
+
159
+ yield self ._handle_new_events (
160
+ origin ,
161
+ event_infos ,
162
+ outliers = True
163
+ )
161
164
162
165
try :
163
166
_ , event_stream_id , max_stream_id = yield self ._handle_new_event (
@@ -292,49 +295,44 @@ def backfill(self, dest, room_id, limit, extremities=[]):
292
295
).addErrback (unwrapFirstError )
293
296
auth_events .update ({a .event_id : a for a in results })
294
297
295
- yield defer .gatherResults (
296
- [
297
- self ._handle_new_event (
298
- dest , a ,
299
- auth_events = {
300
- (auth_events [a_id ].type , auth_events [a_id ].state_key ):
301
- auth_events [a_id ]
302
- for a_id , _ in a .auth_events
303
- },
304
- )
305
- for a in auth_events .values ()
306
- if a .event_id not in seen_events
307
- ],
308
- consumeErrors = True ,
309
- ).addErrback (unwrapFirstError )
310
-
311
- yield defer .gatherResults (
312
- [
313
- self ._handle_new_event (
314
- dest , event_map [e_id ],
315
- state = events_to_state [e_id ],
316
- backfilled = True ,
317
- auth_events = {
318
- (auth_events [a_id ].type , auth_events [a_id ].state_key ):
319
- auth_events [a_id ]
320
- for a_id , _ in event_map [e_id ].auth_events
321
- },
322
- )
323
- for e_id in events_to_state
324
- ],
325
- consumeErrors = True
326
- ).addErrback (unwrapFirstError )
298
+ ev_infos = []
299
+ for a in auth_events .values ():
300
+ if a .event_id in seen_events :
301
+ continue
302
+ ev_infos .append ({
303
+ "event" : a ,
304
+ "auth_events" : {
305
+ (auth_events [a_id ].type , auth_events [a_id ].state_key ):
306
+ auth_events [a_id ]
307
+ for a_id , _ in a .auth_events
308
+ }
309
+ })
310
+
311
+ for e_id in events_to_state :
312
+ ev_infos .append ({
313
+ "event" : event_map [e_id ],
314
+ "state" : events_to_state [e_id ],
315
+ "auth_events" : {
316
+ (auth_events [a_id ].type , auth_events [a_id ].state_key ):
317
+ auth_events [a_id ]
318
+ for a_id , _ in event_map [e_id ].auth_events
319
+ }
320
+ })
327
321
328
322
events .sort (key = lambda e : e .depth )
329
323
330
324
for event in events :
331
325
if event in events_to_state :
332
326
continue
333
327
334
- yield self ._handle_new_event (
335
- dest , event ,
336
- backfilled = True ,
337
- )
328
+ ev_infos .append ({
329
+ "event" : event ,
330
+ })
331
+
332
+ yield self ._handle_new_events (
333
+ dest , ev_infos ,
334
+ backfilled = True ,
335
+ )
338
336
339
337
defer .returnValue (events )
340
338
@@ -600,32 +598,22 @@ def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
600
598
# FIXME
601
599
pass
602
600
603
- yield self ._handle_auth_events (
604
- origin , [e for e in auth_chain if e .event_id != event .event_id ]
605
- )
606
-
607
- @defer .inlineCallbacks
608
- def handle_state (e ):
601
+ ev_infos = []
602
+ for e in itertools .chain (state , auth_chain ):
609
603
if e .event_id == event .event_id :
610
- return
604
+ continue
611
605
612
606
e .internal_metadata .outlier = True
613
- try :
614
- auth_ids = [e_id for e_id , _ in e .auth_events ]
615
- auth = {
607
+ auth_ids = [e_id for e_id , _ in e .auth_events ]
608
+ ev_infos .append ({
609
+ "event" : e ,
610
+ "auth_events" : {
616
611
(e .type , e .state_key ): e for e in auth_chain
617
612
if e .event_id in auth_ids
618
613
}
619
- yield self ._handle_new_event (
620
- origin , e , auth_events = auth
621
- )
622
- except :
623
- logger .exception (
624
- "Failed to handle state event %s" ,
625
- e .event_id ,
626
- )
614
+ })
627
615
628
- yield defer . DeferredList ([ handle_state ( e ) for e in state ] )
616
+ yield self . _handle_new_events ( origin , ev_infos , outliers = True )
629
617
630
618
auth_ids = [e_id for e_id , _ in event .auth_events ]
631
619
auth_events = {
@@ -940,11 +928,54 @@ def _on_user_joined(self, user, room_id):
940
928
def _handle_new_event (self , origin , event , state = None , backfilled = False ,
941
929
current_state = None , auth_events = None ):
942
930
943
- logger .debug (
944
- "_handle_new_event: %s, sigs: %s" ,
945
- event .event_id , event .signatures ,
931
+ outlier = event .internal_metadata .is_outlier ()
932
+
933
+ context = yield self ._prep_event (
934
+ origin , event ,
935
+ state = state ,
936
+ backfilled = backfilled ,
937
+ current_state = current_state ,
938
+ auth_events = auth_events ,
946
939
)
947
940
941
+ event_stream_id , max_stream_id = yield self .store .persist_event (
942
+ event ,
943
+ context = context ,
944
+ backfilled = backfilled ,
945
+ is_new_state = (not outlier and not backfilled ),
946
+ current_state = current_state ,
947
+ )
948
+
949
+ defer .returnValue ((context , event_stream_id , max_stream_id ))
950
+
951
+ @defer .inlineCallbacks
952
+ def _handle_new_events (self , origin , event_infos , backfilled = False ,
953
+ outliers = False ):
954
+ contexts = yield defer .gatherResults (
955
+ [
956
+ self ._prep_event (
957
+ origin ,
958
+ ev_info ["event" ],
959
+ state = ev_info .get ("state" ),
960
+ backfilled = backfilled ,
961
+ auth_events = ev_info .get ("auth_events" ),
962
+ )
963
+ for ev_info in event_infos
964
+ ]
965
+ )
966
+
967
+ yield self .store .persist_events (
968
+ [
969
+ (ev_info ["event" ], context )
970
+ for ev_info , context in itertools .izip (event_infos , contexts )
971
+ ],
972
+ backfilled = backfilled ,
973
+ is_new_state = (not outliers and not backfilled ),
974
+ )
975
+
976
+ @defer .inlineCallbacks
977
+ def _prep_event (self , origin , event , state = None , backfilled = False ,
978
+ current_state = None , auth_events = None ):
948
979
outlier = event .internal_metadata .is_outlier ()
949
980
950
981
context = yield self .state_handler .compute_event_context (
@@ -954,13 +985,6 @@ def _handle_new_event(self, origin, event, state=None, backfilled=False,
954
985
if not auth_events :
955
986
auth_events = context .current_state
956
987
957
- logger .debug (
958
- "_handle_new_event: %s, auth_events: %s" ,
959
- event .event_id , auth_events ,
960
- )
961
-
962
- is_new_state = not outlier
963
-
964
988
# This is a hack to fix some old rooms where the initial join event
965
989
# didn't reference the create event in its auth events.
966
990
if event .type == EventTypes .Member and not event .auth_events :
@@ -984,26 +1008,7 @@ def _handle_new_event(self, origin, event, state=None, backfilled=False,
984
1008
985
1009
context .rejected = RejectedReason .AUTH_ERROR
986
1010
987
- # FIXME: Don't store as rejected with AUTH_ERROR if we haven't
988
- # seen all the auth events.
989
- yield self .store .persist_event (
990
- event ,
991
- context = context ,
992
- backfilled = backfilled ,
993
- is_new_state = False ,
994
- current_state = current_state ,
995
- )
996
- raise
997
-
998
- event_stream_id , max_stream_id = yield self .store .persist_event (
999
- event ,
1000
- context = context ,
1001
- backfilled = backfilled ,
1002
- is_new_state = (is_new_state and not backfilled ),
1003
- current_state = current_state ,
1004
- )
1005
-
1006
- defer .returnValue ((context , event_stream_id , max_stream_id ))
1011
+ defer .returnValue (context )
1007
1012
1008
1013
@defer .inlineCallbacks
1009
1014
def on_query_auth (self , origin , event_id , remote_auth_chain , rejects ,
@@ -1066,14 +1071,24 @@ def on_get_missing_events(self, origin, room_id, earliest_events,
1066
1071
@log_function
1067
1072
def do_auth (self , origin , event , context , auth_events ):
1068
1073
# Check if we have all the auth events.
1069
- have_events = yield self .store .have_events (
1070
- [e_id for e_id , _ in event .auth_events ]
1071
- )
1072
-
1074
+ current_state = set (e .event_id for e in auth_events .values ())
1073
1075
event_auth_events = set (e_id for e_id , _ in event .auth_events )
1076
+
1077
+ if event_auth_events - current_state :
1078
+ have_events = yield self .store .have_events (
1079
+ event_auth_events - current_state
1080
+ )
1081
+ else :
1082
+ have_events = {}
1083
+
1084
+ have_events .update ({
1085
+ e .event_id : ""
1086
+ for e in auth_events .values ()
1087
+ })
1088
+
1074
1089
seen_events = set (have_events .keys ())
1075
1090
1076
- missing_auth = event_auth_events - seen_events
1091
+ missing_auth = event_auth_events - seen_events - current_state
1077
1092
1078
1093
if missing_auth :
1079
1094
logger .info ("Missing auth: %s" , missing_auth )
0 commit comments