Skip to content

Commit

Permalink
Provide CryptoStore::mark_sessions_as_backed_up on stores and use it …
Browse files Browse the repository at this point in the history
…in BackupMachine::mark_request_as_sent

Signed-off-by: Andy Balaam <andy.balaam@matrix.org>
  • Loading branch information
andybalaam committed Dec 12, 2023
1 parent e652069 commit 81dad3c
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 27 deletions.
39 changes: 14 additions & 25 deletions crates/matrix-sdk-crypto/src/backups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::{

use ruma::{
api::client::backup::RoomKeyBackup, serde::Raw, DeviceId, DeviceKeyAlgorithm, OwnedDeviceId,
OwnedRoomId, OwnedTransactionId, TransactionId,
OwnedRoomId, OwnedTransactionId, RoomId, TransactionId,
};
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, trace, warn};
Expand Down Expand Up @@ -64,20 +64,10 @@ pub struct BackupMachine {
struct PendingBackup {
request_id: OwnedTransactionId,
request: KeysBackupRequest,
/// Room ID : Sender Key : [Session IDs]
sessions: BTreeMap<OwnedRoomId, BTreeMap<String, BTreeSet<String>>>,
}

impl PendingBackup {
fn session_was_part_of_the_backup(&self, session: &InboundGroupSession) -> bool {
self.sessions
.get(session.room_id())
.and_then(|r| {
r.get(&session.sender_key().to_base64()).map(|s| s.contains(session.session_id()))
})
.unwrap_or(false)
}
}

impl From<PendingBackup> for OutgoingRequest {
fn from(b: PendingBackup) -> Self {
OutgoingRequest { request_id: b.request_id, request: Arc::new(b.request.into()) }
Expand Down Expand Up @@ -477,25 +467,20 @@ impl BackupMachine {
request_id: &TransactionId,
) -> Result<(), CryptoStoreError> {
let mut request = self.pending_backup.write().await;

if let Some(r) = &*request {
if r.request_id == request_id {
let sessions: Vec<_> = self
.store
.get_inbound_group_sessions()
.await?
.into_iter()
.filter(|s| r.session_was_part_of_the_backup(s))
let room_and_session_ids: Vec<(&RoomId, &str)> = r
.sessions
.iter()
.flat_map(|(room_id, sender_key_to_session_ids)| {
std::iter::repeat(room_id).zip(sender_key_to_session_ids.values().flatten())
})
.map(|(room_id, session_id)| (room_id.as_ref(), session_id.as_str()))
.collect();

for session in &sessions {
session.mark_as_backed_up();
}

trace!(request_id = ?r.request_id, keys = ?r.sessions, "Marking room keys as backed up");

let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
self.store.save_changes(changes).await?;
self.store.mark_sessions_as_backed_up(&room_and_session_ids).await?;

let counts = self.store.inbound_group_session_counts().await?;

Expand Down Expand Up @@ -563,6 +548,10 @@ impl BackupMachine {
}

/// Backup all the non-backed up room keys we know about
/// returns a tuple: (
/// map of Room ID : RoomKeyBackup,
/// map of Room ID : map of Sender Key : set of Session IDs
/// )
async fn backup_keys(
sessions: Vec<InboundGroupSession>,
backup_key: &MegolmV1BackupKey,
Expand Down
50 changes: 49 additions & 1 deletion crates/matrix-sdk-crypto/src/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ macro_rules! cryptostore_integration_tests {
room_id,
serde::{Base64, Raw},
to_device::DeviceIdOrAllDevices,
user_id, DeviceId, JsOption, OwnedDeviceId, OwnedUserId, TransactionId, UserId,
user_id, DeviceId, JsOption, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId
};
use serde_json::value::to_raw_value;
use $crate::{
Expand Down Expand Up @@ -292,6 +292,54 @@ macro_rules! cryptostore_integration_tests {
assert_eq!(to_back_up, vec![session])
}

#[async_test]
async fn mark_inbound_group_sessions_as_backed_up() {
// Given a store exists with multiple unbacked-up sessions
let (account, store) =
get_loaded_store("mark_inbound_group_sessions_as_backed_up").await;
let room_id = &room_id!("!test:localhost");
let mut sessions: Vec<InboundGroupSession> = Vec::with_capacity(10);
for i in 0..10 {
sessions.push(account.create_group_session_pair_with_defaults(room_id).await.1);
}
let changes = Changes { inbound_group_sessions: sessions.clone(), ..Default::default() };
store.save_changes(changes).await.expect("Can't save group session");
assert_eq!(store.inbound_group_sessions_for_backup(100).await.unwrap().len(), 10);

fn session_info(session: &InboundGroupSession) -> (&RoomId, &str) {
(&session.room_id(), &session.session_id())
}

// When I mark some as backed up
let x = store.mark_sessions_as_backed_up(&[
session_info(&sessions[1]),
session_info(&sessions[3]),
session_info(&sessions[5]),
session_info(&sessions[7]),
session_info(&sessions[9]),
]).await.expect("Failed to mark sessions as backed up");


// And ask which still need backing up
let to_back_up = store.inbound_group_sessions_for_backup(10).await.unwrap();
let needs_backing_up = |i: usize| to_back_up.iter().any(|s| s.session_id() == sessions[i].session_id());

// Then the sessions we said were backed up no longer need backing up
assert!(!needs_backing_up(1));
assert!(!needs_backing_up(2));
assert!(!needs_backing_up(3));
assert!(!needs_backing_up(4));
assert!(!needs_backing_up(5));

// And the sessions we didn't mention still need backing up
assert!(needs_backing_up(0));
assert!(needs_backing_up(2));
assert!(needs_backing_up(4));
assert!(needs_backing_up(6));
assert!(needs_backing_up(8));
assert_eq!(to_back_up.len(), 5);
}

#[async_test]
async fn reset_inbound_group_session_for_backup() {
let (account, store) =
Expand Down
14 changes: 14 additions & 0 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,20 @@ impl CryptoStore for MemoryStore {
.collect())
}

async fn mark_sessions_as_backed_up(
&self,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
for (room_id, session_id) in room_and_session_ids {
let session = self.inbound_group_sessions.get(room_id, session_id);
if let Some(session) = session {
session.mark_as_backed_up();
self.inbound_group_sessions.add(session);
}
}
Ok(())
}

async fn reset_backup_state(&self) -> Result<()> {
for session in self.get_inbound_group_sessions().await? {
session.reset_backup_state();
Expand Down
13 changes: 13 additions & 0 deletions crates/matrix-sdk-crypto/src/store/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ pub trait CryptoStore: AsyncTraitDeps {
limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error>;

/// Mark the sessions with the supplied room and session IDs as backed up
async fn mark_sessions_as_backed_up(
&self,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<(), Self::Error>;

/// Reset the backup state of all the stored inbound group sessions.
async fn reset_backup_state(&self) -> Result<(), Self::Error>;

Expand Down Expand Up @@ -332,6 +338,13 @@ impl<T: CryptoStore> CryptoStore for EraseCryptoStoreError<T> {
self.0.inbound_group_sessions_for_backup(limit).await.map_err(Into::into)
}

async fn mark_sessions_as_backed_up(
&self,
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
self.0.mark_sessions_as_backed_up(room_and_session_ids).await.map_err(Into::into)
}

async fn reset_backup_state(&self) -> Result<()> {
self.0.reset_backup_state().await.map_err(Into::into)
}
Expand Down
22 changes: 22 additions & 0 deletions crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,28 @@ impl_crypto_store! {
Ok(result)
}

async fn mark_sessions_as_backed_up(&self, room_and_session_ids: &[(&RoomId, &str)]) -> Result<()> {
let tx = self
.inner
.transaction_on_one_with_mode(
keys::INBOUND_GROUP_SESSIONS_V2,
IdbTransactionMode::Readwrite,
)?;

let object_store = tx.object_store(keys::INBOUND_GROUP_SESSIONS_V2)?;

for (room_id, session_id) in room_and_session_ids {
let key = self.serializer.encode_key(keys::INBOUND_GROUP_SESSIONS_V2, (room_id, session_id));
if let Some(idb_object_js) = object_store.get(&key)?.await? {
let mut idb_object: InboundGroupSessionIndexedDbObject = serde_wasm_bindgen::from_value(idb_object_js)?;
idb_object.needs_backup = false;
object_store.put_key_val(&key, &serde_wasm_bindgen::to_value(&idb_object)?)?;
}
}

Ok(tx.await.into_result()?)
}

async fn reset_backup_state(&self) -> Result<()> {
let tx = self
.inner
Expand Down
67 changes: 66 additions & 1 deletion crates/matrix-sdk-sqlite/src/crypto_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::{

use async_trait::async_trait;
use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime};
use itertools::Itertools;
use matrix_sdk_crypto::{
olm::{
InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
Expand All @@ -40,7 +41,7 @@ use ruma::{
events::secret::request::SecretName, DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
RoomId, TransactionId, UserId,
};
use rusqlite::OptionalExtension;
use rusqlite::{limits::Limit, params_from_iter, OptionalExtension};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{fs, sync::Mutex};
use tracing::{debug, instrument, warn};
Expand Down Expand Up @@ -499,6 +500,32 @@ trait SqliteObjectCryptoStoreExt: SqliteObjectExt {
.await?)
}

async fn mark_sessions_as_backed_up(&self, session_ids: Vec<Key>) -> Result<()> {
let max_chunk_size = usize::try_from(self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER).await)
.expect("SQLITE_LIMIT_VARIABLE_NUMBER was not a usize!");

if session_ids.is_empty() {
// We are not expecting to be called with an empty list of sessions
warn!("No sessions to mark as backed up!");
return Ok(());
}

self.with_transaction(move |tx| {
for chunk in session_ids.chunks(max_chunk_size) {
// Safety: placeholders is not generated using any user input except the number of
// session IDs, so it is safe from injection.
let placeholders = generate_placeholders(chunk.len());
let query = format!(
"UPDATE inbound_group_session SET backed_up = TRUE where session_id IN ({})",
placeholders
);
tx.execute(&query, params_from_iter(chunk.into_iter()))?;
}
Ok(())
})
.await
}

async fn reset_inbound_group_session_backup_state(&self) -> Result<()> {
self.execute("UPDATE inbound_group_session SET backed_up = FALSE", ()).await?;
Ok(())
Expand Down Expand Up @@ -936,6 +963,19 @@ impl CryptoStore for SqliteCryptoStore {
.collect()
}

async fn mark_sessions_as_backed_up(&self, session_ids: &[(&RoomId, &str)]) -> Result<()> {
Ok(self
.acquire()
.await?
.mark_sessions_as_backed_up(
session_ids
.iter()
.map(|(_, s)| self.encode_key("inbound_group_session", s))
.collect(),
)
.await?)
}

async fn reset_backup_state(&self) -> Result<()> {
Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?)
}
Expand Down Expand Up @@ -1232,6 +1272,31 @@ impl CryptoStore for SqliteCryptoStore {
}
}

fn generate_placeholders(number: usize) -> String {
if number == 0 {
panic!("Can't generate zero placeholders");
}
(std::iter::repeat("?").take(number)).join(", ")
}

#[cfg(test)]
mod placeholder_tests {
use super::*;

#[test]
fn can_generate_placeholders() {
assert_eq!(generate_placeholders(1), "?");
assert_eq!(generate_placeholders(2), "?, ?");
assert_eq!(generate_placeholders(5), "?, ?, ?, ?, ?");
}

#[test]
#[should_panic(expected = "Can't generate zero placeholders")]
fn generating_zero_placeholders_panics() {
generate_placeholders(0);
}
}

#[cfg(test)]
mod tests {
use matrix_sdk_crypto::{cryptostore_integration_tests, cryptostore_integration_tests_time};
Expand Down
6 changes: 6 additions & 0 deletions crates/matrix-sdk-sqlite/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ pub(crate) trait SqliteObjectExt {
T: Send + 'static,
E: From<rusqlite::Error> + Send + 'static,
F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;

async fn limit(&self, limit: rusqlite::limits::Limit) -> i32;
}

#[async_trait]
Expand Down Expand Up @@ -145,6 +147,10 @@ impl SqliteObjectExt for deadpool_sqlite::Object {
.await
.unwrap()
}

async fn limit(&self, limit: rusqlite::limits::Limit) -> i32 {
self.interact(move |conn| conn.limit(limit)).await.expect("Failed to fetch limit")
}
}

pub(crate) trait SqliteConnectionExt {
Expand Down

0 comments on commit 81dad3c

Please sign in to comment.