1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515import logging
16+ from itertools import chain
1617from typing import (
1718 TYPE_CHECKING ,
1819 AbstractSet ,
@@ -1131,12 +1132,33 @@ async def _get_joined_hosts(
11311132 else :
11321133 # The cache doesn't match the state group or prev state group,
11331134 # so we calculate the result from first principles.
1135+ #
1136+ # We need to fetch all hosts joined to the room according to `state` by
1137+ # inspecting all join memberships in `state`. However, if the `state` is
1138+ # relatively recent then many of its events are likely to be held in
1139+ # the current state of the room, which is easily available and likely
1140+ # cached.
1141+ #
1142+ # We therefore compute the set of `state` events not in the
1143+ # current state and only fetch those.
1144+ current_memberships = (
1145+ await self ._get_approximate_current_memberships_in_room (room_id )
1146+ )
1147+ unknown_state_events = {}
1148+ joined_users_in_current_state = []
1149+
1150+ for (type , state_key ), event_id in state .items ():
1151+ if event_id not in current_memberships :
1152+ unknown_state_events [type , state_key ] = event_id
1153+ elif current_memberships [event_id ] == Membership .JOIN :
1154+ joined_users_in_current_state .append (state_key )
1155+
11341156 joined_user_ids = await self .get_joined_user_ids_from_state (
1135- room_id , state
1157+ room_id , unknown_state_events
11361158 )
11371159
11381160 cache .hosts_to_joined_users = {}
1139- for user_id in joined_user_ids :
1161+ for user_id in chain ( joined_user_ids , joined_users_in_current_state ) :
11401162 host = intern_string (get_domain_from_id (user_id ))
11411163 cache .hosts_to_joined_users .setdefault (host , set ()).add (user_id )
11421164
@@ -1147,6 +1169,27 @@ async def _get_joined_hosts(
11471169
11481170 return frozenset (cache .hosts_to_joined_users )
11491171
1172+ # TODO: this _might_ turn out to need caching, let's see
async def _get_approximate_current_memberships_in_room(
    self, room_id: str
) -> Mapping[str, Optional[str]]:
    """Build a map from event id to membership, for all events in the current state.

    The event ids of non-membership events (e.g. `m.room.power_levels`) are present
    in the result, mapped to values of `None`.

    The result is approximate for partially-joined rooms. It is fully accurate
    for fully-joined rooms.

    Args:
        room_id: The room whose `current_state_events` rows are read.

    Returns:
        A dict with one entry per row selected for `room_id`: the key is the
        row's `event_id` and the value is the row's `membership` column.
    """
    rows = await self.db_pool.simple_select_list(
        "current_state_events",
        keyvalues={"room_id": room_id},
        retcols=("event_id", "membership"),
        # Label the query with this method's own name. The previous label,
        # "has_completed_background_updates", was copied from an unrelated
        # query and misattributed this one.
        desc="_get_approximate_current_memberships_in_room",
    )
    return {row["event_id"]: row["membership"] for row in rows}
1192+
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
    """Create an empty `_JoinedHostsCache` object.

    `room_id` is not read by the body. Presumably it only serves as the key
    under which the `@cached` decorator stores the returned object, so that
    repeated calls for one room share a cache instance -- TODO confirm
    against the decorator's implementation.

    Args:
        room_id: The room the cache object is requested for.

    Returns:
        A newly constructed `_JoinedHostsCache`.
    """
    fresh_cache = _JoinedHostsCache()
    return fresh_cache
0 commit comments