@@ -950,54 +950,35 @@ async def on_make_knock_request(
950
950
951
951
return event
952
952
953
- async def get_state_for_pdu (self , room_id : str , event_id : str ) -> List [EventBase ]:
954
- """Returns the state at the event. i.e. not including said event."""
955
-
956
- event = await self .store .get_event (event_id , check_room_id = room_id )
957
-
958
- state_groups = await self .state_store .get_state_groups (room_id , [event_id ])
959
-
960
- if state_groups :
961
- _ , state = list (state_groups .items ()).pop ()
962
- results = {(e .type , e .state_key ): e for e in state }
963
-
964
- if event .is_state ():
965
- # Get previous state
966
- if "replaces_state" in event .unsigned :
967
- prev_id = event .unsigned ["replaces_state" ]
968
- if prev_id != event .event_id :
969
- prev_event = await self .store .get_event (prev_id )
970
- results [(event .type , event .state_key )] = prev_event
971
- else :
972
- del results [(event .type , event .state_key )]
973
-
974
- res = list (results .values ())
975
- return res
976
- else :
977
- return []
978
-
979
953
async def get_state_ids_for_pdu (self , room_id : str , event_id : str ) -> List [str ]:
980
954
"""Returns the state at the event. i.e. not including said event."""
981
955
event = await self .store .get_event (event_id , check_room_id = room_id )
956
+ if event .internal_metadata .outlier :
957
+ raise NotFoundError ("State not known at event %s" % (event_id ,))
982
958
983
959
state_groups = await self .state_store .get_state_groups_ids (room_id , [event_id ])
984
960
985
- if state_groups :
986
- _ , state = list (state_groups .items ()).pop ()
987
- results = state
961
+ # get_state_groups_ids should return exactly one result
962
+ assert len (state_groups ) == 1
988
963
989
- if event .is_state ():
990
- # Get previous state
991
- if "replaces_state" in event .unsigned :
992
- prev_id = event .unsigned ["replaces_state" ]
993
- if prev_id != event .event_id :
994
- results [(event .type , event .state_key )] = prev_id
995
- else :
996
- results .pop ((event .type , event .state_key ), None )
964
+ state_map = next (iter (state_groups .values ()))
997
965
998
- return list (results .values ())
999
- else :
1000
- return []
966
+ state_key = event .get_state_key ()
967
+ if state_key is not None :
968
+ # the event was not rejected (get_event raises a NotFoundError for rejected
969
+ # events) so the state at the event should include the event itself.
970
+ assert (
971
+ state_map .get ((event .type , state_key )) == event .event_id
972
+ ), "State at event did not include event itself"
973
+
974
+ # ... but we need the state *before* that event
975
+ if "replaces_state" in event .unsigned :
976
+ prev_id = event .unsigned ["replaces_state" ]
977
+ state_map [(event .type , state_key )] = prev_id
978
+ else :
979
+ del state_map [(event .type , state_key )]
980
+
981
+ return list (state_map .values ())
1001
982
1002
983
async def on_backfill_request (
1003
984
self , origin : str , room_id : str , pdu_list : List [str ], limit : int
0 commit comments