diff --git a/crates/matrix-sdk-crypto/CHANGELOG.md b/crates/matrix-sdk-crypto/CHANGELOG.md index 1bfa77fad8b..10757de14e0 100644 --- a/crates/matrix-sdk-crypto/CHANGELOG.md +++ b/crates/matrix-sdk-crypto/CHANGELOG.md @@ -1,5 +1,8 @@ # unreleased +- Add method to mark a list of inbound group sessions as backed up: + `CyrptoStore::mark_inbound_group_sessions_as_backed_up` + - `OlmMachine::toggle_room_key_forwarding` is replaced by two separate methods: * `OlmMachine::set_room_key_requests_enabled`, which controls whether diff --git a/crates/matrix-sdk-crypto/src/backups/mod.rs b/crates/matrix-sdk-crypto/src/backups/mod.rs index 14f9f8c8288..86943becbe7 100644 --- a/crates/matrix-sdk-crypto/src/backups/mod.rs +++ b/crates/matrix-sdk-crypto/src/backups/mod.rs @@ -30,7 +30,7 @@ use std::{ use ruma::{ api::client::backup::RoomKeyBackup, serde::Raw, DeviceId, DeviceKeyAlgorithm, OwnedDeviceId, - OwnedRoomId, OwnedTransactionId, TransactionId, + OwnedRoomId, OwnedTransactionId, RoomId, TransactionId, }; use tokio::sync::RwLock; use tracing::{debug, info, instrument, trace, warn}; @@ -64,20 +64,10 @@ pub struct BackupMachine { struct PendingBackup { request_id: OwnedTransactionId, request: KeysBackupRequest, + /// Room ID : Sender Key : [Session IDs] sessions: BTreeMap>>, } -impl PendingBackup { - fn session_was_part_of_the_backup(&self, session: &InboundGroupSession) -> bool { - self.sessions - .get(session.room_id()) - .and_then(|r| { - r.get(&session.sender_key().to_base64()).map(|s| s.contains(session.session_id())) - }) - .unwrap_or(false) - } -} - impl From for OutgoingRequest { fn from(b: PendingBackup) -> Self { OutgoingRequest { request_id: b.request_id, request: Arc::new(b.request.into()) } @@ -477,25 +467,20 @@ impl BackupMachine { request_id: &TransactionId, ) -> Result<(), CryptoStoreError> { let mut request = self.pending_backup.write().await; - if let Some(r) = &*request { if r.request_id == request_id { - let sessions: Vec<_> = self - .store - .get_inbound_group_sessions() - .await? - .into_iter() - .filter(|s| r.session_was_part_of_the_backup(s)) + let room_and_session_ids: Vec<(&RoomId, &str)> = r + .sessions + .iter() + .flat_map(|(room_id, sender_key_to_session_ids)| { + std::iter::repeat(room_id).zip(sender_key_to_session_ids.values().flatten()) + }) + .map(|(room_id, session_id)| (room_id.as_ref(), session_id.as_str())) .collect(); - for session in &sessions { - session.mark_as_backed_up(); - } - trace!(request_id = ?r.request_id, keys = ?r.sessions, "Marking room keys as backed up"); - let changes = Changes { inbound_group_sessions: sessions, ..Default::default() }; - self.store.save_changes(changes).await?; + self.store.mark_inbound_group_sessions_as_backed_up(&room_and_session_ids).await?; let counts = self.store.inbound_group_session_counts().await?; @@ -563,6 +548,10 @@ impl BackupMachine { } /// Backup all the non-backed up room keys we know about + /// returns a tuple: ( + /// map of Room ID : RoomKeyBackup, + /// map of Room ID : map of Sender Key : set of Session IDs + /// ) async fn backup_keys( sessions: Vec, backup_key: &MegolmV1BackupKey, diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 0ddbc8077da..fe940ff27d8 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -14,7 +14,7 @@ macro_rules! cryptostore_integration_tests { room_id, serde::{Base64, Raw}, to_device::DeviceIdOrAllDevices, - user_id, DeviceId, JsOption, OwnedDeviceId, OwnedUserId, TransactionId, UserId, + user_id, DeviceId, JsOption, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId }; use serde_json::value::to_raw_value; use $crate::{ @@ -292,6 +292,54 @@ macro_rules! cryptostore_integration_tests { assert_eq!(to_back_up, vec![session]) } + #[async_test] + async fn mark_inbound_group_sessions_as_backed_up() { + // Given a store exists with multiple unbacked-up sessions + let (account, store) = + get_loaded_store("mark_inbound_group_sessions_as_backed_up").await; + let room_id = &room_id!("!test:localhost"); + let mut sessions: Vec = Vec::with_capacity(10); + for i in 0..10 { + sessions.push(account.create_group_session_pair_with_defaults(room_id).await.1); + } + let changes = Changes { inbound_group_sessions: sessions.clone(), ..Default::default() }; + store.save_changes(changes).await.expect("Can't save group session"); + assert_eq!(store.inbound_group_sessions_for_backup(100).await.unwrap().len(), 10); + + fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) { + (&session.room_id(), &session.session_id()) + } + + // When I mark some as backed up + let x = store.mark_inbound_group_sessions_as_backed_up(&[ + session_info(&sessions[1]), + session_info(&sessions[3]), + session_info(&sessions[5]), + session_info(&sessions[7]), + session_info(&sessions[9]), + ]).await.expect("Failed to mark sessions as backed up"); + + + // And ask which still need backing up + let to_back_up = store.inbound_group_sessions_for_backup(10).await.unwrap(); + let needs_backing_up = |i: usize| to_back_up.iter().any(|s| s.session_id() == sessions[i].session_id()); + + // Then the sessions we said were backed up no longer need backing up + assert!(!needs_backing_up(1)); + assert!(!needs_backing_up(3)); + assert!(!needs_backing_up(5)); + assert!(!needs_backing_up(7)); + assert!(!needs_backing_up(9)); + + // And the sessions we didn't mention still need backing up + assert!(needs_backing_up(0)); + assert!(needs_backing_up(2)); + assert!(needs_backing_up(4)); + assert!(needs_backing_up(6)); + assert!(needs_backing_up(8)); + assert_eq!(to_back_up.len(), 5); + } + #[async_test] async fn reset_inbound_group_session_for_backup() { let (account, store) = diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 163d93b6b5f..8d7cdbb01fe 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -265,6 +265,20 @@ impl CryptoStore for MemoryStore { .collect()) } + async fn mark_inbound_group_sessions_as_backed_up( + &self, + room_and_session_ids: &[(&RoomId, &str)], + ) -> Result<()> { + for (room_id, session_id) in room_and_session_ids { + let session = self.inbound_group_sessions.get(room_id, session_id); + if let Some(session) = session { + session.mark_as_backed_up(); + self.inbound_group_sessions.add(session); + } + } + Ok(()) + } + async fn reset_backup_state(&self) -> Result<()> { for session in self.get_inbound_group_sessions().await? { session.reset_backup_state(); diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index 08924cac606..fbf867afe3f 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -113,6 +113,13 @@ pub trait CryptoStore: AsyncTraitDeps { limit: usize, ) -> Result, Self::Error>; + /// Mark the inbound group sessions with the supplied room and session IDs + /// as backed up + async fn mark_inbound_group_sessions_as_backed_up( + &self, + room_and_session_ids: &[(&RoomId, &str)], + ) -> Result<(), Self::Error>; + /// Reset the backup state of all the stored inbound group sessions. async fn reset_backup_state(&self) -> Result<(), Self::Error>; @@ -332,6 +339,13 @@ impl CryptoStore for EraseCryptoStoreError { self.0.inbound_group_sessions_for_backup(limit).await.map_err(Into::into) } + async fn mark_inbound_group_sessions_as_backed_up( + &self, + room_and_session_ids: &[(&RoomId, &str)], + ) -> Result<()> { + self.0.mark_inbound_group_sessions_as_backed_up(room_and_session_ids).await.map_err(Into::into) + } + async fn reset_backup_state(&self) -> Result<()> { self.0.reset_backup_state().await.map_err(Into::into) } diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs index 79d7f1f7526..5fca6dfcf46 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs @@ -836,6 +836,28 @@ impl_crypto_store! { Ok(result) } + async fn mark_inbound_group_sessions_as_backed_up(&self, room_and_session_ids: &[(&RoomId, &str)]) -> Result<()> { + let tx = self + .inner + .transaction_on_one_with_mode( + keys::INBOUND_GROUP_SESSIONS_V2, + IdbTransactionMode::Readwrite, + )?; + + let object_store = tx.object_store(keys::INBOUND_GROUP_SESSIONS_V2)?; + + for (room_id, session_id) in room_and_session_ids { + let key = self.serializer.encode_key(keys::INBOUND_GROUP_SESSIONS_V2, (room_id, session_id)); + if let Some(idb_object_js) = object_store.get(&key)?.await? { + let mut idb_object: InboundGroupSessionIndexedDbObject = serde_wasm_bindgen::from_value(idb_object_js)?; + idb_object.needs_backup = false; + object_store.put_key_val(&key, &serde_wasm_bindgen::to_value(&idb_object)?)?; + } + } + + Ok(tx.await.into_result()?) + } + async fn reset_backup_state(&self) -> Result<()> { let tx = self .inner diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index cb815873a7a..65a8171ed79 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -22,6 +22,7 @@ use std::{ use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime}; +use itertools::Itertools; use matrix_sdk_crypto::{ olm::{ InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession, @@ -40,7 +41,7 @@ use ruma::{ events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, RoomId, TransactionId, UserId, }; -use rusqlite::OptionalExtension; +use rusqlite::{limits::Limit, params_from_iter, OptionalExtension}; use serde::{de::DeserializeOwned, Serialize}; use tokio::{fs, sync::Mutex}; use tracing::{debug, instrument, warn}; @@ -499,6 +500,32 @@ trait SqliteObjectCryptoStoreExt: SqliteObjectExt { .await?) } + async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: Vec) -> Result<()> { + let max_chunk_size = usize::try_from(self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER).await) + .expect("SQLITE_LIMIT_VARIABLE_NUMBER was not a usize!"); + + if session_ids.is_empty() { + // We are not expecting to be called with an empty list of sessions + warn!("No sessions to mark as backed up!"); + return Ok(()); + } + + self.with_transaction(move |tx| { + for chunk in session_ids.chunks(max_chunk_size) { + // Safety: placeholders is not generated using any user input except the number of + // session IDs, so it is safe from injection. + let placeholders = generate_placeholders(chunk.len()); + let query = format!( + "UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({})", + placeholders + ); + tx.execute(&query, params_from_iter(chunk.into_iter()))?; + } + Ok(()) + }) + .await + } + async fn reset_inbound_group_session_backup_state(&self) -> Result<()> { self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?; Ok(()) @@ -936,6 +963,19 @@ impl CryptoStore for SqliteCryptoStore { .collect() } + async fn mark_inbound_group_sessions_as_backed_up(&self, session_ids: &[(&RoomId, &str)]) -> Result<()> { + Ok(self + .acquire() + .await? + .mark_inbound_group_sessions_as_backed_up( + session_ids + .iter() + .map(|(_, s)| self.encode_key("inbound_group_session", s)) + .collect(), + ) + .await?) + } + async fn reset_backup_state(&self) -> Result<()> { Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?) } @@ -1232,6 +1272,31 @@ impl CryptoStore for SqliteCryptoStore { } } +fn generate_placeholders(number: usize) -> String { + if number == 0 { + panic!("Can't generate zero placeholders"); + } + (std::iter::repeat("?").take(number)).join(", ") +} + +#[cfg(test)] +mod placeholder_tests { + use super::*; + + #[test] + fn can_generate_placeholders() { + assert_eq!(generate_placeholders(1), "?"); + assert_eq!(generate_placeholders(2), "?, ?"); + assert_eq!(generate_placeholders(5), "?, ?, ?, ?, ?"); + } + + #[test] + #[should_panic(expected = "Can't generate zero placeholders")] + fn generating_zero_placeholders_panics() { + generate_placeholders(0); + } +} + #[cfg(test)] mod tests { use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time}; diff --git a/crates/matrix-sdk-sqlite/src/utils.rs b/crates/matrix-sdk-sqlite/src/utils.rs index bfbd41d9d02..06d06c21107 100644 --- a/crates/matrix-sdk-sqlite/src/utils.rs +++ b/crates/matrix-sdk-sqlite/src/utils.rs @@ -85,6 +85,8 @@ pub(crate) trait SqliteObjectExt { T: Send + 'static, E: From + Send + 'static, F: FnOnce(&Transaction<'_>) -> Result + Send + 'static; + + async fn limit(&self, limit: rusqlite::limits::Limit) -> i32; } #[async_trait] @@ -145,6 +147,10 @@ impl SqliteObjectExt for deadpool_sqlite::Object { .await .unwrap() } + + async fn limit(&self, limit: rusqlite::limits::Limit) -> i32 { + self.interact(move |conn| conn.limit(limit)).await.expect("Failed to fetch limit") + } } pub(crate) trait SqliteConnectionExt {