diff --git a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift index 08cd1d0f6..48964f333 100644 --- a/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift +++ b/Sources/WalletConnectSign/Engine/Common/SessionEngine.swift @@ -19,7 +19,7 @@ final class SessionEngine { private let sessionStore: WCSessionStorage private let networkingInteractor: NetworkInteracting - private let historyService: HistoryService + private let historyService: HistoryServiceProtocol private let verifyContextStore: CodableStore private let verifyClient: VerifyClientProtocol private let kms: KeyManagementServiceProtocol @@ -30,7 +30,7 @@ final class SessionEngine { init( networkingInteractor: NetworkInteracting, - historyService: HistoryService, + historyService: HistoryServiceProtocol, verifyContextStore: CodableStore, verifyClient: VerifyClientProtocol, kms: KeyManagementServiceProtocol, @@ -202,6 +202,7 @@ private extension SessionEngine { func setupExpirationSubscriptions() { sessionStore.onSessionExpiration = { [weak self] session in + self?.historyService.removePendingRequest(topic: session.topic) self?.kms.deletePrivateKey(for: session.selfParticipant.publicKey) self?.kms.deleteAgreementSecret(for: session.topic) } diff --git a/Sources/WalletConnectSign/Services/HistoryService.swift b/Sources/WalletConnectSign/Services/HistoryService.swift index ddb44c1ab..c3ac80a3a 100644 --- a/Sources/WalletConnectSign/Services/HistoryService.swift +++ b/Sources/WalletConnectSign/Services/HistoryService.swift @@ -1,7 +1,14 @@ import Foundation protocol HistoryServiceProtocol { + + func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)? + + func removePendingRequest(topic: String) + func getPendingRequests() -> [(request: Request, context: VerifyContext?)] + + func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] } final class HistoryService: HistoryServiceProtocol { @@ -16,21 +23,25 @@ final class HistoryService: HistoryServiceProtocol { self.history = history self.verifyContextStore = verifyContextStore } - - public func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)? { + + func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)? { guard let record = history.get(recordId: id) else { return nil } guard let (request, recordId, _) = mapRequestRecord(record) else { return nil } return (request, try? verifyContextStore.get(key: recordId.string)) } + + func removePendingRequest(topic: String) { + DispatchQueue.global(qos: .background).async { [unowned self] in + history.deleteAll(forTopic: topic) + } + } func getPendingRequests() -> [(request: Request, context: VerifyContext?)] { getPendingRequestsSortedByTimestamp() } - - func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] { let requests = history.getPending() .compactMap { mapRequestRecord($0) } @@ -88,11 +99,27 @@ private extension HistoryService { } #if DEBUG -class MockHistoryService: HistoryServiceProtocol { +final class MockHistoryService: HistoryServiceProtocol { + + var removePendingRequestCalled: (String) -> Void = { _ in } + var pendingRequests: [(request: Request, context: VerifyContext?)] = [] + func removePendingRequest(topic: String) { + pendingRequests.removeAll(where: { $0.request.topic == topic }) + removePendingRequestCalled(topic) + } + + func getSessionRequest(id: JSONRPC.RPCID) -> (request: Request, context: VerifyContext?)? { + fatalError("Unimplemented") + } + func getPendingRequests() -> [(request: Request, context: VerifyContext?)] { - return pendingRequests + pendingRequests + } + + func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] { + fatalError("Unimplemented") } } #endif diff --git a/Sources/WalletConnectSign/Sign/SessionRequestsProvider.swift b/Sources/WalletConnectSign/Sign/SessionRequestsProvider.swift index 89c8e9e79..3afa8290e 100644 --- a/Sources/WalletConnectSign/Sign/SessionRequestsProvider.swift +++ b/Sources/WalletConnectSign/Sign/SessionRequestsProvider.swift @@ -2,7 +2,7 @@ import Combine import Foundation class SessionRequestsProvider { - private let historyService: HistoryService + private let historyService: HistoryServiceProtocol private var sessionRequestPublisherSubject = PassthroughSubject<(request: Request, context: VerifyContext?), Never>() private var lastEmitTime: Date? private let debounceInterval: TimeInterval = 1 @@ -11,7 +11,7 @@ class SessionRequestsProvider { sessionRequestPublisherSubject.eraseToAnyPublisher() } - init(historyService: HistoryService) { + init(historyService: HistoryServiceProtocol) { self.historyService = historyService } diff --git a/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift b/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift index 7060c89ae..cd3b6265c 100644 --- a/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift +++ b/Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift @@ -1,6 +1,9 @@ import Foundation public protocol RPCHistoryProtocol { + + func deleteAll(forTopic topic: String) + func deleteAll(forTopics topics: [String]) } @@ -152,6 +155,10 @@ extension RPCHistory { #if DEBUG class MockRPCHistory: RPCHistoryProtocol { var deletedTopics: [String] = [] + + func deleteAll(forTopic topic: String) { + deletedTopics.append(topic) + } func deleteAll(forTopics topics: [String]) { deletedTopics.append(contentsOf: topics) diff --git a/Tests/WalletConnectSignTests/SessionEngineTests.swift b/Tests/WalletConnectSignTests/SessionEngineTests.swift index a6854c5bc..5cd4e1eea 100644 --- a/Tests/WalletConnectSignTests/SessionEngineTests.swift +++ b/Tests/WalletConnectSignTests/SessionEngineTests.swift @@ -8,13 +8,14 @@ final class SessionEngineTests: XCTestCase { var networkingInteractor: NetworkingInteractorMock! var sessionStorage: WCSessionStorageMock! var verifyContextStore: CodableStore! + var rpcHistory: RPCHistory! var engine: SessionEngine! override func setUp() { networkingInteractor = NetworkingInteractorMock() sessionStorage = WCSessionStorageMock() let defaults = RuntimeKeyValueStorage() - let rpcHistory = RPCHistory( + rpcHistory = RPCHistory( keyValueStore: .init( defaults: defaults, identifier: "" @@ -62,4 +63,52 @@ final class SessionEngineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } + + func testRemovePendingRequestsOnSessionExpiration() { + let expectation = expectation( + description: "Remove pending requests on session expiration" + ) + + let historyService = MockHistoryService() + + engine = SessionEngine( + networkingInteractor: networkingInteractor, + historyService: historyService, + verifyContextStore: verifyContextStore, + verifyClient: VerifyClientMock(), + kms: KeyManagementServiceMock(), + sessionStore: sessionStorage, + logger: ConsoleLoggerMock(), + sessionRequestsProvider: SessionRequestsProvider( + historyService: historyService), + invalidRequestsSanitiser: InvalidRequestsSanitiser( + historyService: historyService, + history: rpcHistory + ) + ) + + let expectedTopic = "topic" + + let session = WCSession.stub( + topic: "topic", + namespaces: SessionNamespace.stubDictionary() + ) + + sessionStorage.setSession(session) + + let request = RPCRequest.stubRequest( + method: "method", + chainId: Blockchain("eip155:1")!, + expiry: UInt64(Date().timeIntervalSince1970) + ) + + historyService.removePendingRequestCalled = { topic in + XCTAssertEqual(topic, expectedTopic) + expectation.fulfill() + } + + sessionStorage.onSessionExpiration!(session) + + wait(for: [expectation], timeout: 0.5) + } }