1313#
1414
1515import logging
16- from typing import TYPE_CHECKING , Dict , Optional , Tuple
16+ from typing import TYPE_CHECKING , Optional
1717
1818import attr
1919
20- from synapse .api .errors import SlidingSyncUnknownPosition
2120from synapse .logging .opentracing import trace
21+ from synapse .storage .databases .main import DataStore
2222from synapse .types import SlidingSyncStreamToken
2323from synapse .types .handlers .sliding_sync import (
2424 MutablePerConnectionState ,
@@ -61,22 +61,9 @@ class SlidingSyncConnectionStore:
6161 to mapping of room ID to `HaveSentRoom`.
6262 """
6363
64- # `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
65- _connections : Dict [Tuple [str , str ], Dict [int , PerConnectionState ]] = attr .Factory (
66- dict
67- )
64+ store : "DataStore"
6865
69- async def is_valid_token (
70- self , sync_config : SlidingSyncConfig , connection_token : int
71- ) -> bool :
72- """Return whether the connection token is valid/recognized"""
73- if connection_token == 0 :
74- return True
75-
76- conn_key = self ._get_connection_key (sync_config )
77- return connection_token in self ._connections .get (conn_key , {})
78-
79- async def get_per_connection_state (
66+ async def get_and_clear_connection_positions (
8067 self ,
8168 sync_config : SlidingSyncConfig ,
8269 from_token : Optional [SlidingSyncStreamToken ],
@@ -86,23 +73,21 @@ async def get_per_connection_state(
8673 Raises:
8774 SlidingSyncUnknownPosition if the connection_token is unknown
8875 """
89- if from_token is None :
76+ # If this is our first request, there is no previous connection state to fetch out of the database
77+ if from_token is None or from_token .connection_position == 0 :
9078 return PerConnectionState ()
9179
92- connection_position = from_token .connection_position
93- if connection_position == 0 :
94- # Initial sync (request without a `from_token`) starts at `0` so
95- # there is no existing per-connection state
96- return PerConnectionState ()
97-
98- conn_key = self ._get_connection_key (sync_config )
99- sync_statuses = self ._connections .get (conn_key , {})
100- connection_state = sync_statuses .get (connection_position )
80+ conn_id = sync_config .conn_id or ""
10181
102- if connection_state is None :
103- raise SlidingSyncUnknownPosition ()
82+ device_id = sync_config . requester . device_id
83+ assert device_id is not None
10484
105- return connection_state
85+ return await self .store .get_and_clear_connection_positions (
86+ sync_config .user .to_string (),
87+ device_id ,
88+ conn_id ,
89+ from_token .connection_position ,
90+ )
10691
10792 @trace
10893 async def record_new_state (
@@ -116,85 +101,28 @@ async def record_new_state(
116101 If there are no changes to the state this may return the same token as
117102 the existing per-connection state.
118103 """
119- prev_connection_token = 0
120- if from_token is not None :
121- prev_connection_token = from_token .connection_position
122-
123104 if not new_connection_state .has_updates ():
124- return prev_connection_token
125-
126- conn_key = self ._get_connection_key (sync_config )
127- sync_statuses = self ._connections .setdefault (conn_key , {})
128-
129- # Generate a new token, removing any existing entries in that token
130- # (which can happen if requests get resent).
131- new_store_token = prev_connection_token + 1
132- sync_statuses .pop (new_store_token , None )
133-
134- # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
135- # don't grow forever.
136- sync_statuses [new_store_token ] = new_connection_state .copy ()
137-
138- return new_store_token
105+ if from_token is not None :
106+ return from_token .connection_position
107+ else :
108+ return 0
109+
110+ # A from token with a zero connection position means there was no
111+ # previously stored connection state, so we treat a zero the same as
112+ # there being no previous position.
113+ previous_connection_position = None
114+ if from_token is not None and from_token .connection_position != 0 :
115+ previous_connection_position = from_token .connection_position
139116
140- @trace
141- async def mark_token_seen (
142- self ,
143- sync_config : SlidingSyncConfig ,
144- from_token : Optional [SlidingSyncStreamToken ],
145- ) -> None :
146- """We have received a request with the given token, so we can clear out
147- any other tokens associated with the connection.
148-
149- If there is no from token then we have started afresh, and so we delete
150- all tokens associated with the device.
151- """
152- # Clear out any tokens for the connection that doesn't match the one
153- # from the request.
154-
155- conn_key = self ._get_connection_key (sync_config )
156- sync_statuses = self ._connections .pop (conn_key , {})
157- if from_token is None :
158- return
159-
160- sync_statuses = {
161- connection_token : room_statuses
162- for connection_token , room_statuses in sync_statuses .items ()
163- if connection_token == from_token .connection_position
164- }
165- if sync_statuses :
166- self ._connections [conn_key ] = sync_statuses
167-
168- @staticmethod
169- def _get_connection_key (sync_config : SlidingSyncConfig ) -> Tuple [str , str ]:
170- """Return a unique identifier for this connection.
171-
172- The first part is simply the user ID.
173-
174- The second part is generally a combination of device ID and conn_id.
175- However, both these two are optional (e.g. puppet access tokens don't
176- have device IDs), so this handles those edge cases.
177-
178- We use this over the raw `conn_id` to avoid clashes between different
179- clients that use the same `conn_id`. Imagine a user uses a web client
180- that uses `conn_id: main_sync_loop` and an Android client that also has
181- a `conn_id: main_sync_loop`.
182- """
183-
184- user_id = sync_config .user .to_string ()
185-
186- # Only one sliding sync connection is allowed per given conn_id (empty
187- # or not).
188117 conn_id = sync_config .conn_id or ""
189118
190- if sync_config .requester .device_id :
191- return (user_id , f"D/{ sync_config .requester .device_id } /{ conn_id } " )
192-
193- if sync_config .requester .access_token_id :
194- # If we don't have a device, then the access token ID should be a
195- # stable ID.
196- return (user_id , f"A/{ sync_config .requester .access_token_id } /{ conn_id } " )
119+ device_id = sync_config .requester .device_id
120+ assert device_id is not None
197121
198- # If we have neither then its likely an AS or some weird token. Either
199- # way we can just fail here.
200- raise Exception ("Cannot use sliding sync with access token type" )
122+ return await self .store .persist_per_connection_state (
123+ sync_config .user .to_string (),
124+ device_id ,
125+ conn_id ,
126+ previous_connection_position ,
127+ new_connection_state ,
128+ )
0 commit comments