Skip to content

Commit

Permalink
Merge pull request #1356 from WalletConnect/jack/clear-requests-on-se…
Browse files Browse the repository at this point in the history
…ssion-expiration

Remove pending requests on session expiration
  • Loading branch information
jackpooleywc authored May 17, 2024
2 parents c280211 + 787c8a9 commit 92bfb50
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 11 deletions.
5 changes: 3 additions & 2 deletions Sources/WalletConnectSign/Engine/Common/SessionEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ final class SessionEngine {

private let sessionStore: WCSessionStorage
private let networkingInteractor: NetworkInteracting
private let historyService: HistoryService
private let historyService: HistoryServiceProtocol
private let verifyContextStore: CodableStore<VerifyContext>
private let verifyClient: VerifyClientProtocol
private let kms: KeyManagementServiceProtocol
Expand All @@ -30,7 +30,7 @@ final class SessionEngine {

init(
networkingInteractor: NetworkInteracting,
historyService: HistoryService,
historyService: HistoryServiceProtocol,
verifyContextStore: CodableStore<VerifyContext>,
verifyClient: VerifyClientProtocol,
kms: KeyManagementServiceProtocol,
Expand Down Expand Up @@ -202,6 +202,7 @@ private extension SessionEngine {

func setupExpirationSubscriptions() {
sessionStore.onSessionExpiration = { [weak self] session in
self?.historyService.removePendingRequest(topic: session.topic)
self?.kms.deletePrivateKey(for: session.selfParticipant.publicKey)
self?.kms.deleteAgreementSecret(for: session.topic)
}
Expand Down
39 changes: 33 additions & 6 deletions Sources/WalletConnectSign/Services/HistoryService.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import Foundation

protocol HistoryServiceProtocol {

func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)?

func removePendingRequest(topic: String)

func getPendingRequests() -> [(request: Request, context: VerifyContext?)]

func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)]
}

final class HistoryService: HistoryServiceProtocol {
Expand All @@ -16,21 +23,25 @@ final class HistoryService: HistoryServiceProtocol {
self.history = history
self.verifyContextStore = verifyContextStore
}

public func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)? {
func getSessionRequest(id: RPCID) -> (request: Request, context: VerifyContext?)? {
guard let record = history.get(recordId: id) else { return nil }
guard let (request, recordId, _) = mapRequestRecord(record) else {
return nil
}
return (request, try? verifyContextStore.get(key: recordId.string))
}

func removePendingRequest(topic: String) {
DispatchQueue.global(qos: .background).async { [unowned self] in
history.deleteAll(forTopic: topic)
}
}

func getPendingRequests() -> [(request: Request, context: VerifyContext?)] {
getPendingRequestsSortedByTimestamp()
}



func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] {
let requests = history.getPending()
.compactMap { mapRequestRecord($0) }
Expand Down Expand Up @@ -88,11 +99,27 @@ private extension HistoryService {
}

#if DEBUG
class MockHistoryService: HistoryServiceProtocol {
final class MockHistoryService: HistoryServiceProtocol {

var removePendingRequestCalled: (String) -> Void = { _ in }

var pendingRequests: [(request: Request, context: VerifyContext?)] = []

func removePendingRequest(topic: String) {
pendingRequests.removeAll(where: { $0.request.topic == topic })
removePendingRequestCalled(topic)
}

func getSessionRequest(id: JSONRPC.RPCID) -> (request: Request, context: VerifyContext?)? {
fatalError("Unimplemented")
}

func getPendingRequests() -> [(request: Request, context: VerifyContext?)] {
return pendingRequests
pendingRequests
}

func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] {
fatalError("Unimplemented")
}
}
#endif
4 changes: 2 additions & 2 deletions Sources/WalletConnectSign/Sign/SessionRequestsProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Combine
import Foundation

class SessionRequestsProvider {
private let historyService: HistoryService
private let historyService: HistoryServiceProtocol
private var sessionRequestPublisherSubject = PassthroughSubject<(request: Request, context: VerifyContext?), Never>()
private var lastEmitTime: Date?
private let debounceInterval: TimeInterval = 1
Expand All @@ -11,7 +11,7 @@ class SessionRequestsProvider {
sessionRequestPublisherSubject.eraseToAnyPublisher()
}

init(historyService: HistoryService) {
init(historyService: HistoryServiceProtocol) {
self.historyService = historyService
}

Expand Down
7 changes: 7 additions & 0 deletions Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import Foundation

public protocol RPCHistoryProtocol {

func deleteAll(forTopic topic: String)

func deleteAll(forTopics topics: [String])
}

Expand Down Expand Up @@ -152,6 +155,10 @@ extension RPCHistory {
#if DEBUG
class MockRPCHistory: RPCHistoryProtocol {
var deletedTopics: [String] = []

func deleteAll(forTopic topic: String) {
deletedTopics.append(topic)
}

func deleteAll(forTopics topics: [String]) {
deletedTopics.append(contentsOf: topics)
Expand Down
45 changes: 44 additions & 1 deletion Tests/WalletConnectSignTests/SessionEngineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ final class SessionEngineTests: XCTestCase {
var networkingInteractor: NetworkingInteractorMock!
var sessionStorage: WCSessionStorageMock!
var verifyContextStore: CodableStore<VerifyContext>!
var rpcHistory: RPCHistory!
var engine: SessionEngine!

override func setUp() {
networkingInteractor = NetworkingInteractorMock()
sessionStorage = WCSessionStorageMock()
let defaults = RuntimeKeyValueStorage()
let rpcHistory = RPCHistory(
rpcHistory = RPCHistory(
keyValueStore: .init(
defaults: defaults,
identifier: ""
Expand Down Expand Up @@ -62,4 +63,46 @@ final class SessionEngineTests: XCTestCase {

wait(for: [expectation], timeout: 0.5)
}

func testRemovePendingRequestsOnSessionExpiration() {
let expectation = expectation(
description: "Remove pending requests on session expiration"
)

let historyService = MockHistoryService()

engine = SessionEngine(
networkingInteractor: networkingInteractor,
historyService: historyService,
verifyContextStore: verifyContextStore,
verifyClient: VerifyClientMock(),
kms: KeyManagementServiceMock(),
sessionStore: sessionStorage,
logger: ConsoleLoggerMock(),
sessionRequestsProvider: SessionRequestsProvider(
historyService: historyService),
invalidRequestsSanitiser: InvalidRequestsSanitiser(
historyService: historyService,
history: rpcHistory
)
)

let expectedTopic = "topic"

let session = WCSession.stub(
topic: expectedTopic,
namespaces: SessionNamespace.stubDictionary()
)

sessionStorage.setSession(session)

historyService.removePendingRequestCalled = { topic in
XCTAssertEqual(topic, expectedTopic)
expectation.fulfill()
}

sessionStorage.onSessionExpiration!(session)

wait(for: [expectation], timeout: 0.5)
}
}

0 comments on commit 92bfb50

Please sign in to comment.