1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Dict , Iterable , List , Tuple
15+ from typing import Collection , Dict , List , Tuple
1616
1717from unpaddedbase64 import encode_base64
1818
19- from synapse .storage ._base import SQLBaseStore
20- from synapse .storage .types import Cursor
19+ from synapse .crypto .event_signing import compute_event_reference_hash
20+ from synapse .storage .databases .main .events_worker import (
21+ EventRedactBehaviour ,
22+ EventsWorkerStore ,
23+ )
2124from synapse .util .caches .descriptors import cached , cachedList
2225
2326
24- class SignatureWorkerStore (SQLBaseStore ):
27+ class SignatureWorkerStore (EventsWorkerStore ):
2528 @cached ()
2629 def get_event_reference_hash (self , event_id ):
2730 # This is a dummy function to allow get_event_reference_hashes
@@ -32,7 +35,7 @@ def get_event_reference_hash(self, event_id):
3235 cached_method_name = "get_event_reference_hash" , list_name = "event_ids" , num_args = 1
3336 )
3437 async def get_event_reference_hashes (
35- self , event_ids : Iterable [str ]
38+ self , event_ids : Collection [str ]
3639 ) -> Dict [str , Dict [str , bytes ]]:
3740 """Get all hashes for given events.
3841
@@ -42,17 +45,22 @@ async def get_event_reference_hashes(
4245 Returns:
4346 A mapping of event ID to a mapping of algorithm to hash.
4447 """
48+ events = await self .get_events (
49+ event_ids ,
50+ redact_behaviour = EventRedactBehaviour .AS_IS ,
51+ allow_rejected = True ,
52+ )
4553
46- def f ( txn ):
47- return {
48- event_id : self . _get_event_reference_hashes_txn ( txn , event_id )
49- for event_id in event_ids
50- }
54+ hashes : Dict [ str , Dict [ str , bytes ]] = {}
55+ for event_id in event_ids :
56+ event = events [ event_id ]
57+ ref_alg , ref_hash_bytes = compute_event_reference_hash ( event )
58+ hashes [ event . event_id ] = { ref_alg : ref_hash_bytes }
5159
52- return await self . db_pool . runInteraction ( "get_event_reference_hashes" , f )
60+ return hashes
5361
5462 async def add_event_hashes (
55- self , event_ids : Iterable [str ]
63+ self , event_ids : Collection [str ]
5664 ) -> List [Tuple [str , Dict [str , str ]]]:
5765 """
5866
@@ -70,24 +78,6 @@ async def add_event_hashes(
7078
7179 return list (encoded_hashes .items ())
7280
73- def _get_event_reference_hashes_txn (
74- self , txn : Cursor , event_id : str
75- ) -> Dict [str , bytes ]:
76- """Get all the hashes for a given PDU.
77- Args:
78- txn:
79- event_id: Id for the Event.
80- Returns:
81- A mapping of algorithm -> hash.
82- """
83- query = (
84- "SELECT algorithm, hash"
85- " FROM event_reference_hashes"
86- " WHERE event_id = ?"
87- )
88- txn .execute (query , (event_id ,))
89- return {k : v for k , v in txn }
90-
9181
9282class SignatureStore (SignatureWorkerStore ):
9383 """Persistence for event signatures and hashes"""
0 commit comments