From 697e3b53f19da74a133f5bb1b7529bdc59f037dd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Jul 2024 10:42:21 +0100 Subject: [PATCH] Add assertion to MultiWriterTokens --- synapse/types/__init__.py | 13 +++++++++++++ tests/test_types.py | 17 +++++------------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index a8af17e18c..eae6f0179b 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -458,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): represented by a default `stream` attribute and a map of instance name to stream position of any writers that are ahead of the default stream position. + + The values in `instance_map` must be greater than the `stream` attribute. """ stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True) @@ -472,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): kw_only=True, ) + def __attrs_post_init__(self): + # Enforce that all instances have a value greater than the min stream + # position. + for v in self.instance_map.values(): + if v < self.stream: + raise ValueError( + "'instance_map' includes a stream position before the main 'stream' attribute" + ) + @classmethod @abc.abstractmethod async def parse(cls, store: "DataStore", string: str) -> "Self": @@ -641,6 +652,8 @@ def __attrs_post_init__(self) -> None: "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." ) + super().__attrs_post_init__() + @classmethod async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: diff --git a/tests/test_types.py b/tests/test_types.py index 498eea40a9..3af05eb7f0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -161,16 +161,9 @@ def test_instance_map(self) -> None: parsed_token = self.get_success(RoomStreamToken.parse(store, string_token)) self.assertEqual(parsed_token, token) - @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres") - def test_instance_map_behind(self) -> None: - """Test for stream token with instance map, where instance map entries - are from before stream token.""" - store = self.hs.get_datastores().main + def test_instance_map_assertion(self) -> None: + """Test that we assert values in the instance map are greater than the + min stream position""" - token = RoomStreamToken(stream=5, instance_map=immutabledict({"foo": 4})) - - string_token = self.get_success(token.to_string(store)) - self.assertEqual(string_token, "s5") - - parsed_token = self.get_success(RoomStreamToken.parse(store, string_token)) - self.assertEqual(parsed_token, RoomStreamToken(stream=5)) + with self.assertRaises(ValueError): + RoomStreamToken(stream=5, instance_map=immutabledict({"foo": 4}))