Skip to content

Commit

Permalink
Merge pull request #3338 from matrix-org/kegan/drop-store
Browse files Browse the repository at this point in the history
bugfix: ensure the SessionStore is cleared when regenerating the OlmMachine
  • Loading branch information
andybalaam authored Apr 19, 2024
2 parents 4325812 + 381c02d commit a3e6a07
Show file tree
Hide file tree
Showing 8 changed files with 400 additions and 3 deletions.
322 changes: 321 additions & 1 deletion crates/matrix-sdk-base/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ impl BaseClient {
tracing::debug!("regenerating OlmMachine");
let session_meta = self.session_meta().ok_or(Error::OlmError(OlmError::MissingSession))?;

// Recreate it.
// Recreate the `OlmMachine` and wipe the in-memory cache in the store
// because we suspect it has stale data.
self.crypto_store.clear_caches().await;
let olm_machine = OlmMachine::with_store(
&session_meta.user_id,
&session_meta.device_id,
Expand Down Expand Up @@ -1695,6 +1697,36 @@ mod tests {
client.get_room(room_id).expect("Just-created room not found!")
}

#[cfg(feature = "e2e-encryption")]
#[async_test]
async fn test_regerating_olm_clears_store_caches() {
// See https://github.com/matrix-org/matrix-rust-sdk/issues/3110
// We must clear the store cache when we regenerate the OlmMachine
// to ensure we really get the new state.

use ruma::{owned_device_id, owned_user_id};

use crate::store::StoreConfig;

// Given a client using a fake store
let user_id = owned_user_id!("@u:m.o");
let device_id = owned_device_id!("DEVICE");
let fake_store = fake_crypto_store::FakeCryptoStore::default();
let store_config = StoreConfig::new().crypto_store(fake_store.clone());
let client = BaseClient::with_store_config(store_config);
client.set_session_meta(SessionMeta { user_id, device_id }).await.unwrap();
fake_store.clear_method_calls();

// When we regenerate the OlmMachine
client.regenerate_olm().await.expect("Failed to regenerate olm");

// Then we cleared the store cache
assert!(
fake_store.method_calls().contains(&"clear_caches".to_owned()),
"No clear_caches call!"
);
}

#[async_test]
async fn test_deserialization_failure() {
let user_id = user_id!("@alice:example.org");
Expand Down Expand Up @@ -1884,4 +1916,292 @@ mod tests {
assert_eq!(member.display_name().unwrap(), "Invited Alice");
assert_eq!(member.avatar_url().unwrap().to_string(), "mxc://localhost/fewjilfewjil42");
}

#[cfg(feature = "e2e-encryption")]
mod fake_crypto_store {
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, Mutex},
};

use async_trait::async_trait;
use matrix_sdk_crypto::{
olm::{InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity},
store::{
BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, RoomSettings,
},
types::events::room_key_withheld::RoomKeyWithheldEvent,
Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities,
SecretInfo, Session, TrackedUser,
};
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId,
UserId,
};

#[derive(Clone, Debug, Default)]
pub(crate) struct FakeCryptoStore {
pub method_calls: Arc<Mutex<Vec<String>>>,
}

impl FakeCryptoStore {
pub fn method_calls(&self) -> Vec<String> {
self.method_calls.lock().unwrap().clone()
}

pub fn clear_method_calls(&self) {
self.method_calls.lock().unwrap().clear();
}

fn call(&self, method_name: &str) {
self.method_calls.lock().unwrap().push(method_name.to_owned());
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl CryptoStore for FakeCryptoStore {
type Error = Infallible;

async fn clear_caches(&self) {
self.call("clear_caches");
}

async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
self.call("load_account");
Ok(None)
}

async fn load_identity(
&self,
) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
self.call("load_identity");
Ok(None)
}

async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
self.call("next_batch_token");
Ok(None)
}

async fn save_pending_changes(
&self,
_changes: PendingChanges,
) -> Result<(), Self::Error> {
self.call("save_pending_changes");
Ok(())
}

async fn save_changes(&self, _changes: Changes) -> Result<(), Self::Error> {
self.call("save_changes");
Ok(())
}

async fn get_sessions(
&self,
_sender_key: &str,
) -> Result<Option<Arc<tokio::sync::Mutex<Vec<Session>>>>, Self::Error> {
self.call("get_sessions");
Ok(None)
}

async fn get_inbound_group_session(
&self,
_room_id: &RoomId,
_session_id: &str,
) -> Result<Option<InboundGroupSession>, Self::Error> {
self.call("get_inbound_group_session");
Ok(None)
}

async fn get_withheld_info(
&self,
_room_id: &RoomId,
_session_id: &str,
) -> Result<Option<RoomKeyWithheldEvent>, Self::Error> {
self.call("get_withheld_info");
Ok(None)
}

async fn get_inbound_group_sessions(
&self,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.call("get_inbound_group_sessions");
Ok(vec![])
}

async fn inbound_group_session_counts(
&self,
_backup_version: Option<&str>,
) -> Result<RoomKeyCounts, Self::Error> {
self.call("inbound_group_session_counts");
Ok(RoomKeyCounts { total: 0, backed_up: 0 })
}

async fn inbound_group_sessions_for_backup(
&self,
_backup_version: &str,
_limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.call("inbound_group_sessions_for_backup");
Ok(vec![])
}

async fn mark_inbound_group_sessions_as_backed_up(
&self,
_backup_version: &str,
_room_and_session_ids: &[(&RoomId, &str)],
) -> Result<(), Self::Error> {
self.call("mark_inbound_group_sessions_as_backed_up");
Ok(())
}

async fn reset_backup_state(&self) -> Result<(), Self::Error> {
self.call("reset_backup_state");
Ok(())
}

async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
self.call("load_backup_keys");
Ok(BackupKeys::default())
}

async fn get_outbound_group_session(
&self,
_room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>, Self::Error> {
self.call("get_outbound_group_session");
Ok(None)
}

async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
self.call("load_tracked_users");
Ok(vec![])
}

async fn save_tracked_users(
&self,
_tracked_users: &[(&UserId, bool)],
) -> Result<(), Self::Error> {
self.call("save_tracked_users");
Ok(())
}

async fn get_device(
&self,
_user_id: &UserId,
_device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>, Self::Error> {
self.call("get_device");
Ok(None)
}

async fn get_user_devices(
&self,
_user_id: &UserId,
) -> Result<HashMap<OwnedDeviceId, ReadOnlyDevice>, Self::Error> {
self.call("get_user_devices");
Ok(HashMap::default())
}

async fn get_user_identity(
&self,
_user_id: &UserId,
) -> Result<Option<ReadOnlyUserIdentities>, Self::Error> {
self.call("get_user_identity");
Ok(None)
}

async fn is_message_known(
&self,
_message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
) -> Result<bool, Self::Error> {
self.call("is_message_known");
Ok(false)
}

async fn get_outgoing_secret_requests(
&self,
_request_id: &TransactionId,
) -> Result<Option<GossipRequest>, Self::Error> {
self.call("get_outgoing_secret_requests");
Ok(None)
}

async fn get_secret_request_by_info(
&self,
_key_info: &SecretInfo,
) -> Result<Option<GossipRequest>, Self::Error> {
self.call("get_secret_request_by_info");
Ok(None)
}

async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
self.call("get_unsent_secret_requests");
Ok(vec![])
}

async fn delete_outgoing_secret_requests(
&self,
_request_id: &TransactionId,
) -> Result<(), Self::Error> {
self.call("delete_outgoing_secret_requests");
Ok(())
}

async fn get_secrets_from_inbox(
&self,
_secret_name: &SecretName,
) -> Result<Vec<GossippedSecret>, Self::Error> {
self.call("get_secrets_from_inbox");
Ok(vec![])
}

async fn delete_secrets_from_inbox(
&self,
_secret_name: &SecretName,
) -> Result<(), Self::Error> {
self.call("delete_secrets_from_inbox");
Ok(())
}

async fn get_room_settings(
&self,
_room_id: &RoomId,
) -> Result<Option<RoomSettings>, Self::Error> {
self.call("get_room_settings");
Ok(None)
}

async fn get_custom_value(&self, _key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
self.call("get_custom_value");
Ok(None)
}

async fn set_custom_value(
&self,
_key: &str,
_value: Vec<u8>,
) -> Result<(), Self::Error> {
self.call("set_custom_value");
Ok(())
}

async fn remove_custom_value(&self, _key: &str) -> Result<(), Self::Error> {
self.call("remove_custom_value");
Ok(())
}

async fn try_take_leased_lock(
&self,
_lease_duration_ms: u32,
_key: &str,
_holder: &str,
) -> Result<bool, Self::Error> {
self.call("try_take_leased_lock");
Ok(true)
}
}
}
}
3 changes: 3 additions & 0 deletions crates/matrix-sdk-crypto/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Breaking changes:

Additions:

- Expose new method `CryptoStore::clear_caches`.
([#3338](https://github.com/matrix-org/matrix-rust-sdk/pull/3338))

- Expose new method `OlmMachine::device_creation_time`.
([#3275](https://github.com/matrix-org/matrix-rust-sdk/pull/3275))

Expand Down
7 changes: 7 additions & 0 deletions crates/matrix-sdk-crypto/src/store/caches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ impl SessionStore {
Self::default()
}

/// Clear all entries in the session store.
///
/// This is intended to be used when regenerating olm machines.
pub fn clear(&self) {
self.entries.write().unwrap().clear()
}

/// Add a session to the store.
///
/// Returns true if the session was added, false if the session was
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ macro_rules! cryptostore_integration_tests {
Account::with_device_id(alice_id(), alice_device_id())
}

async fn get_account_and_session() -> (Account, Session) {
pub(crate) async fn get_account_and_session() -> (Account, Session) {
let alice = Account::with_device_id(alice_id(), alice_device_id());
let mut bob = Account::with_device_id(bob_id(), bob_device_id());

Expand Down
12 changes: 12 additions & 0 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ type Result<T> = std::result::Result<T, Infallible>;
impl CryptoStore for MemoryStore {
type Error = Infallible;

async fn clear_caches(&self) {
// no-op: it makes no sense to delete fields here as we would forget our
// identity, etc Effectively we have no caches as the fields
// *are* the underlying store. Calling this method only makes
// sense if there is some other layer (e.g disk) persistence
// happening.
}

async fn load_account(&self) -> Result<Option<Account>> {
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone()))
}
Expand Down Expand Up @@ -718,6 +726,10 @@ mod integration_tests {
impl CryptoStore for PersistentMemoryStore {
type Error = <MemoryStore as CryptoStore>::Error;

async fn clear_caches(&self) {
self.0.clear_caches().await
}

async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
self.0.load_account().await
}
Expand Down
Loading

0 comments on commit a3e6a07

Please sign in to comment.