diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 621c6e65086..babd5d38edb 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -1697,6 +1697,36 @@ mod tests { client.get_room(room_id).expect("Just-created room not found!") } + #[cfg(feature = "e2e-encryption")] + #[async_test] + async fn test_regerating_olm_clears_store_caches() { + // See https://github.com/matrix-org/matrix-rust-sdk/issues/3110 + // We must clear the store cache when we regenerate the OlmMachine + // to ensure we really get the new state. + + use ruma::{owned_device_id, owned_user_id}; + + use crate::store::StoreConfig; + + // Given a client using a fake store + let user_id = owned_user_id!("@u:m.o"); + let device_id = owned_device_id!("DEVICE"); + let fake_store = fake_crypto_store::FakeCryptoStore::default(); + let store_config = StoreConfig::new().crypto_store(fake_store.clone()); + let client = BaseClient::with_store_config(store_config); + client.set_session_meta(SessionMeta { user_id, device_id }).await.unwrap(); + fake_store.clear_method_calls(); + + // When we regenerate the OlmMachine + client.regenerate_olm().await.expect("Failed to regenerate olm"); + + // Then we cleared the store cache + assert!( + fake_store.method_calls().contains(&"clear_caches".to_owned()), + "No clear_caches call!" + ); + } + #[async_test] async fn test_deserialization_failure() { let user_id = user_id!("@alice:example.org"); @@ -1886,4 +1916,292 @@ mod tests { assert_eq!(member.display_name().unwrap(), "Invited Alice"); assert_eq!(member.avatar_url().unwrap().to_string(), "mxc://localhost/fewjilfewjil42"); } + + #[cfg(feature = "e2e-encryption")] + mod fake_crypto_store { + use std::{ + collections::HashMap, + convert::Infallible, + sync::{Arc, Mutex}, + }; + + use async_trait::async_trait; + use matrix_sdk_crypto::{ + olm::{InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity}, + store::{ + BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, RoomSettings, + }, + types::events::room_key_withheld::RoomKeyWithheldEvent, + Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities, + SecretInfo, Session, TrackedUser, + }; + use ruma::{ + events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId, + UserId, + }; + + #[derive(Clone, Debug, Default)] + pub(crate) struct FakeCryptoStore { + pub method_calls: Arc>>, + } + + impl FakeCryptoStore { + pub fn method_calls(&self) -> Vec { + self.method_calls.lock().unwrap().clone() + } + + pub fn clear_method_calls(&self) { + self.method_calls.lock().unwrap().clear(); + } + + fn call(&self, method_name: &str) { + self.method_calls.lock().unwrap().push(method_name.to_owned()); + } + } + + #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] + #[cfg_attr(not(target_arch = "wasm32"), async_trait)] + impl CryptoStore for FakeCryptoStore { + type Error = Infallible; + + async fn clear_caches(&self) { + self.call("clear_caches"); + } + + async fn load_account(&self) -> Result, Self::Error> { + self.call("load_account"); + Ok(None) + } + + async fn load_identity( + &self, + ) -> Result, Self::Error> { + self.call("load_identity"); + Ok(None) + } + + async fn next_batch_token(&self) -> Result, Self::Error> { + self.call("next_batch_token"); + Ok(None) + } + + async fn save_pending_changes( + &self, + _changes: PendingChanges, + ) -> Result<(), Self::Error> { + self.call("save_pending_changes"); + Ok(()) + } + + async fn save_changes(&self, _changes: Changes) -> Result<(), Self::Error> { + self.call("save_changes"); + Ok(()) + } + + async fn get_sessions( + &self, + _sender_key: &str, + ) -> Result>>>, Self::Error> { + self.call("get_sessions"); + Ok(None) + } + + async fn get_inbound_group_session( + &self, + _room_id: &RoomId, + _session_id: &str, + ) -> Result, Self::Error> { + self.call("get_inbound_group_session"); + Ok(None) + } + + async fn get_withheld_info( + &self, + _room_id: &RoomId, + _session_id: &str, + ) -> Result, Self::Error> { + self.call("get_withheld_info"); + Ok(None) + } + + async fn get_inbound_group_sessions( + &self, + ) -> Result, Self::Error> { + self.call("get_inbound_group_sessions"); + Ok(vec![]) + } + + async fn inbound_group_session_counts( + &self, + _backup_version: Option<&str>, + ) -> Result { + self.call("inbound_group_session_counts"); + Ok(RoomKeyCounts { total: 0, backed_up: 0 }) + } + + async fn inbound_group_sessions_for_backup( + &self, + _backup_version: &str, + _limit: usize, + ) -> Result, Self::Error> { + self.call("inbound_group_sessions_for_backup"); + Ok(vec![]) + } + + async fn mark_inbound_group_sessions_as_backed_up( + &self, + _backup_version: &str, + _room_and_session_ids: &[(&RoomId, &str)], + ) -> Result<(), Self::Error> { + self.call("mark_inbound_group_sessions_as_backed_up"); + Ok(()) + } + + async fn reset_backup_state(&self) -> Result<(), Self::Error> { + self.call("reset_backup_state"); + Ok(()) + } + + async fn load_backup_keys(&self) -> Result { + self.call("load_backup_keys"); + Ok(BackupKeys::default()) + } + + async fn get_outbound_group_session( + &self, + _room_id: &RoomId, + ) -> Result, Self::Error> { + self.call("get_outbound_group_session"); + Ok(None) + } + + async fn load_tracked_users(&self) -> Result, Self::Error> { + self.call("load_tracked_users"); + Ok(vec![]) + } + + async fn save_tracked_users( + &self, + _tracked_users: &[(&UserId, bool)], + ) -> Result<(), Self::Error> { + self.call("save_tracked_users"); + Ok(()) + } + + async fn get_device( + &self, + _user_id: &UserId, + _device_id: &DeviceId, + ) -> Result, Self::Error> { + self.call("get_device"); + Ok(None) + } + + async fn get_user_devices( + &self, + _user_id: &UserId, + ) -> Result, Self::Error> { + self.call("get_user_devices"); + Ok(HashMap::default()) + } + + async fn get_user_identity( + &self, + _user_id: &UserId, + ) -> Result, Self::Error> { + self.call("get_user_identity"); + Ok(None) + } + + async fn is_message_known( + &self, + _message_hash: &matrix_sdk_crypto::olm::OlmMessageHash, + ) -> Result { + self.call("is_message_known"); + Ok(false) + } + + async fn get_outgoing_secret_requests( + &self, + _request_id: &TransactionId, + ) -> Result, Self::Error> { + self.call("get_outgoing_secret_requests"); + Ok(None) + } + + async fn get_secret_request_by_info( + &self, + _key_info: &SecretInfo, + ) -> Result, Self::Error> { + self.call("get_secret_request_by_info"); + Ok(None) + } + + async fn get_unsent_secret_requests(&self) -> Result, Self::Error> { + self.call("get_unsent_secret_requests"); + Ok(vec![]) + } + + async fn delete_outgoing_secret_requests( + &self, + _request_id: &TransactionId, + ) -> Result<(), Self::Error> { + self.call("delete_outgoing_secret_requests"); + Ok(()) + } + + async fn get_secrets_from_inbox( + &self, + _secret_name: &SecretName, + ) -> Result, Self::Error> { + self.call("get_secrets_from_inbox"); + Ok(vec![]) + } + + async fn delete_secrets_from_inbox( + &self, + _secret_name: &SecretName, + ) -> Result<(), Self::Error> { + self.call("delete_secrets_from_inbox"); + Ok(()) + } + + async fn get_room_settings( + &self, + _room_id: &RoomId, + ) -> Result, Self::Error> { + self.call("get_room_settings"); + Ok(None) + } + + async fn get_custom_value(&self, _key: &str) -> Result>, Self::Error> { + self.call("get_custom_value"); + Ok(None) + } + + async fn set_custom_value( + &self, + _key: &str, + _value: Vec, + ) -> Result<(), Self::Error> { + self.call("set_custom_value"); + Ok(()) + } + + async fn remove_custom_value(&self, _key: &str) -> Result<(), Self::Error> { + self.call("remove_custom_value"); + Ok(()) + } + + async fn try_take_leased_lock( + &self, + _lease_duration_ms: u32, + _key: &str, + _holder: &str, + ) -> Result { + self.call("try_take_leased_lock"); + Ok(true) + } + } + } }