Skip to content

Commit

Permalink
Add invalid requests sanitiser
Browse files Browse the repository at this point in the history
  • Loading branch information
llbartekll committed May 15, 2024
1 parent 744c76d commit 0391948
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 23 deletions.
14 changes: 12 additions & 2 deletions Sources/WalletConnectSign/Engine/Common/SessionEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ final class SessionEngine {
private var publishers = [AnyCancellable]()
private let logger: ConsoleLogging
private let sessionRequestsProvider: SessionRequestsProvider
private let invalidRequestsSanitiser: InvalidRequestsSanitiser

init(
networkingInteractor: NetworkInteracting,
Expand All @@ -35,7 +36,8 @@ final class SessionEngine {
kms: KeyManagementServiceProtocol,
sessionStore: WCSessionStorage,
logger: ConsoleLogging,
sessionRequestsProvider: SessionRequestsProvider
sessionRequestsProvider: SessionRequestsProvider,
invalidRequestsSanitiser: InvalidRequestsSanitiser
) {
self.networkingInteractor = networkingInteractor
self.historyService = historyService
Expand All @@ -45,15 +47,23 @@ final class SessionEngine {
self.sessionStore = sessionStore
self.logger = logger
self.sessionRequestsProvider = sessionRequestsProvider
self.invalidRequestsSanitiser = invalidRequestsSanitiser

setupConnectionSubscriptions()
setupRequestSubscriptions()
setupResponseSubscriptions()
setupUpdateSubscriptions()
setupExpirationSubscriptions()
DispatchQueue.main.asyncAfter(deadline: .now() + 1) { [weak self] in
sessionRequestsProvider.emitRequestIfPending()
self?.sessionRequestsProvider.emitRequestIfPending()
}

removeInvalidSessionRequests()
}

private func removeInvalidSessionRequests() {
let sessionTopics = Set(sessionStore.getAll().map {$0.topic})
invalidRequestsSanitiser.removeInvalidSessionRequests(validSessionTopics: sessionTopics)
}

func hasSession(for topic: String) -> Bool {
Expand Down
18 changes: 17 additions & 1 deletion Sources/WalletConnectSign/Services/HistoryService.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import Foundation

final class HistoryService {
protocol HistoryServiceProtocol {
func getPendingRequests() -> [(request: Request, context: VerifyContext?)]
}

final class HistoryService: HistoryServiceProtocol {

private let history: RPCHistory
private let verifyContextStore: CodableStore<VerifyContext>
Expand All @@ -25,6 +29,8 @@ final class HistoryService {
getPendingRequestsSortedByTimestamp()
}



func getPendingRequestsSortedByTimestamp() -> [(request: Request, context: VerifyContext?)] {
let requests = history.getPending()
.compactMap { mapRequestRecord($0) }
Expand Down Expand Up @@ -80,3 +86,13 @@ private extension HistoryService {
return (mappedRequest, record.id, record.timestamp)
}
}

#if DEBUG
class MockHistoryService: HistoryServiceProtocol {
var pendingRequests: [(request: Request, context: VerifyContext?)] = []

func getPendingRequests() -> [(request: Request, context: VerifyContext?)] {
return pendingRequests
}
}
#endif
20 changes: 20 additions & 0 deletions Sources/WalletConnectSign/Services/InvalidRequestsSanitiser.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

import Foundation

class InvalidRequestsSanitiser {
let historyService: HistoryServiceProtocol
private let history: RPCHistoryProtocol

init(historyService: HistoryServiceProtocol, history: RPCHistoryProtocol) {
self.historyService = historyService
self.history = history
}

func removeInvalidSessionRequests(validSessionTopics: Set<String>) {
let pendingRequests = historyService.getPendingRequests()
let invalidTopics = Set(pendingRequests.map { $0.request.topic }).subtracting(validSessionTopics)
if !invalidTopics.isEmpty {
history.deleteAll(forTopics: Array(invalidTopics))
}
}
}
3 changes: 2 additions & 1 deletion Sources/WalletConnectSign/Sign/SignClientFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public struct SignClientFactory {
let historyService = HistoryService(history: rpcHistory, verifyContextStore: verifyContextStore)
let verifyClient = VerifyClientFactory.create()
let sessionRequestsProvider = SessionRequestsProvider(historyService: historyService)
let sessionEngine = SessionEngine(networkingInteractor: networkingClient, historyService: historyService, verifyContextStore: verifyContextStore, verifyClient: verifyClient, kms: kms, sessionStore: sessionStore, logger: logger, sessionRequestsProvider: sessionRequestsProvider)
let invalidRequestsSanitiser = InvalidRequestsSanitiser(historyService: historyService, history: rpcHistory)
let sessionEngine = SessionEngine(networkingInteractor: networkingClient, historyService: historyService, verifyContextStore: verifyContextStore, verifyClient: verifyClient, kms: kms, sessionStore: sessionStore, logger: logger, sessionRequestsProvider: sessionRequestsProvider, invalidRequestsSanitiser: invalidRequestsSanitiser)
let nonControllerSessionStateMachine = NonControllerSessionStateMachine(networkingInteractor: networkingClient, kms: kms, sessionStore: sessionStore, logger: logger)
let controllerSessionStateMachine = ControllerSessionStateMachine(networkingInteractor: networkingClient, kms: kms, sessionStore: sessionStore, logger: logger)
let sessionExtendRequester = SessionExtendRequester(sessionStore: sessionStore, networkingInteractor: networkingClient)
Expand Down
16 changes: 15 additions & 1 deletion Sources/WalletConnectUtils/RPCHistory/RPCHistory.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import Foundation

public final class RPCHistory {
public protocol RPCHistoryProtocol {
func deleteAll(forTopics topics: [String])
}

public final class RPCHistory: RPCHistoryProtocol {

public struct Record: Codable {
public enum Origin: String, Codable {
Expand Down Expand Up @@ -144,3 +148,13 @@ extension RPCHistory {
}
}
}

#if DEBUG
class MockRPCHistory: RPCHistoryProtocol {
var deletedTopics: [String] = []

func deleteAll(forTopics topics: [String]) {
deletedTopics.append(contentsOf: topics)
}
}
#endif
69 changes: 69 additions & 0 deletions Tests/WalletConnectSignTests/InvalidRequestsSanitiserTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import XCTest
@testable import WalletConnectSign
@testable import WalletConnectUtils

class InvalidRequestsSanitiserTests: XCTestCase {
var sanitiser: InvalidRequestsSanitiser!
var mockHistoryService: MockHistoryService!
var mockRPCHistory: MockRPCHistory!

override func setUp() {
super.setUp()
mockHistoryService = MockHistoryService()
mockRPCHistory = MockRPCHistory()
sanitiser = InvalidRequestsSanitiser(historyService: mockHistoryService, history: mockRPCHistory)
}

override func tearDown() {
sanitiser = nil
mockHistoryService = nil
mockRPCHistory = nil
super.tearDown()
}

func testRemoveInvalidSessionRequests_noPendingRequests() {
let validSessionTopics: Set<String> = ["validTopic1", "validTopic2"]

sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics)

XCTAssertTrue(mockRPCHistory.deletedTopics.isEmpty)
}

func testRemoveInvalidSessionRequests_allRequestsValid() {
let validSessionTopics: Set<String> = ["validTopic1", "validTopic2"]
mockHistoryService.pendingRequests = [
(request: try! Request(topic: "validTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil),
(request: try! Request(topic: "validTopic2", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil)
]

sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics)

XCTAssertTrue(mockRPCHistory.deletedTopics.isEmpty)
}

func testRemoveInvalidSessionRequests_someRequestsInvalid() {
let validSessionTopics: Set<String> = ["validTopic1", "validTopic2"]
mockHistoryService.pendingRequests = [
(request: try! Request(topic: "validTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil),
(request: try! Request(topic: "invalidTopic1", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil),
(request: try! Request(topic: "invalidTopic2", method: "method3", params: AnyCodable("params3"), chainId: Blockchain("eip155:1")!), context: nil)
]

sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics)

XCTAssertEqual(mockRPCHistory.deletedTopics.sorted(), ["invalidTopic1", "invalidTopic2"])
}

func testRemoveInvalidSessionRequests_withEmptyValidSessionTopics() {
let validSessionTopics: Set<String> = []

mockHistoryService.pendingRequests = [
(request: try! Request(topic: "invalidTopic1", method: "method1", params: AnyCodable("params1"), chainId: Blockchain("eip155:1")!), context: nil),
(request: try! Request(topic: "invalidTopic2", method: "method2", params: AnyCodable("params2"), chainId: Blockchain("eip155:1")!), context: nil)
]

sanitiser.removeInvalidSessionRequests(validSessionTopics: validSessionTopics)

XCTAssertEqual(mockRPCHistory.deletedTopics.sorted(), ["invalidTopic1", "invalidTopic2"])
}
}
32 changes: 14 additions & 18 deletions Tests/WalletConnectSignTests/SessionEngineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,29 @@ final class SessionEngineTests: XCTestCase {
override func setUp() {
networkingInteractor = NetworkingInteractorMock()
sessionStorage = WCSessionStorageMock()
let defaults = RuntimeKeyValueStorage()
let rpcHistory = RPCHistory(
keyValueStore: .init(
defaults: defaults,
identifier: ""
)
)
verifyContextStore = CodableStore<VerifyContext>(defaults: RuntimeKeyValueStorage(), identifier: "")
let historyService = HistoryService(
history: rpcHistory,
verifyContextStore: verifyContextStore
)
engine = SessionEngine(
networkingInteractor: networkingInteractor,
historyService: HistoryService(
history: RPCHistory(
keyValueStore: .init(
defaults: RuntimeKeyValueStorage(),
identifier: ""
)
),
verifyContextStore: verifyContextStore
),
historyService: historyService,
verifyContextStore: verifyContextStore,
verifyClient: VerifyClientMock(),
kms: KeyManagementServiceMock(),
sessionStore: sessionStorage,
logger: ConsoleLoggerMock(),
sessionRequestsProvider: SessionRequestsProvider(
historyService: HistoryService(
history: RPCHistory(
keyValueStore: .init(
defaults: RuntimeKeyValueStorage(),
identifier: ""
)
),
verifyContextStore: verifyContextStore
))
historyService: historyService),
invalidRequestsSanitiser: InvalidRequestsSanitiser(historyService: historyService, history: rpcHistory)
)
}

Expand Down

0 comments on commit 0391948

Please sign in to comment.