From 484b2b59e53ce187b40ed5d285dcc2577d4be55e Mon Sep 17 00:00:00 2001 From: Doug Richar Date: Fri, 27 Sep 2024 18:27:59 -0400 Subject: [PATCH] feat(wallets): add WalletConnect v1 session management (#275) * feat(wallets): add WalletConnect v1 session management This commit introduces WalletConnect v1 session management functionality for Pera (v1) and Defly wallets to improve multi-wallet support. - Added `manageWalletConnectSession` method to Pera and Defly wallet classes - Implemented session backup and restore logic in `connect` and `setActive` methods - Updated `StorageAdapter` mock in tests to handle WalletConnect data - Added new test cases for WalletConnect session management * refactor(wallets): improve WalletConnect session management - Move `manageWalletConnectSession` to `BaseWallet` - Update `connect`, `disconnect`, and `setActive` methods - Add delay after disconnect to prevent race condition - Adjust tests to reflect new behavior --- .../use-wallet/src/__tests__/manager.test.ts | 3 +- .../src/__tests__/wallets/custom.test.ts | 3 +- .../src/__tests__/wallets/defly.test.ts | 214 +++++++++++++++++- .../src/__tests__/wallets/exodus.test.ts | 3 +- .../src/__tests__/wallets/kibisis.test.ts | 3 +- .../src/__tests__/wallets/kmd.test.ts | 3 +- .../src/__tests__/wallets/lute.test.ts | 3 +- .../src/__tests__/wallets/magic.test.ts | 3 +- .../src/__tests__/wallets/mnemonic.test.ts | 3 +- .../src/__tests__/wallets/pera.test.ts | 214 +++++++++++++++++- .../src/__tests__/wallets/pera2.test.ts | 3 +- .../__tests__/wallets/walletconnect.test.ts | 3 +- packages/use-wallet/src/storage.ts | 7 + packages/use-wallet/src/wallets/base.ts | 23 ++ packages/use-wallet/src/wallets/defly.ts | 34 ++- packages/use-wallet/src/wallets/pera.ts | 32 ++- 16 files changed, 529 insertions(+), 25 deletions(-) diff --git a/packages/use-wallet/src/__tests__/manager.test.ts b/packages/use-wallet/src/__tests__/manager.test.ts index 259ad23f..b1ea00b5 100644 --- a/packages/use-wallet/src/__tests__/manager.test.ts +++ b/packages/use-wallet/src/__tests__/manager.test.ts @@ -38,7 +38,8 @@ vi.mock('src/logger', () => { vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/custom.test.ts b/packages/use-wallet/src/__tests__/wallets/custom.test.ts index 744fc856..1dce7b6e 100644 --- a/packages/use-wallet/src/__tests__/wallets/custom.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/custom.test.ts @@ -22,7 +22,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/defly.test.ts b/packages/use-wallet/src/__tests__/wallets/defly.test.ts index add388d2..c8f4cc64 100644 --- a/packages/use-wallet/src/__tests__/wallets/defly.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/defly.test.ts @@ -23,7 +23,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) @@ -87,10 +88,15 @@ describe('DeflyWallet', () => { beforeEach(() => { vi.clearAllMocks() + let mockWalletConnectData: string | null = null + vi.mocked(StorageAdapter.getItem).mockImplementation((key: string) => { if (key === LOCAL_STORAGE_KEY && mockInitialState !== null) { return JSON.stringify(mockInitialState) } + if (key === 'walletconnect') { + return mockWalletConnectData + } return null }) @@ -98,6 +104,15 @@ describe('DeflyWallet', () => { if (key === LOCAL_STORAGE_KEY) { mockInitialState = JSON.parse(value) } + if (key.startsWith('walletconnect-')) { + mockWalletConnectData = value + } + }) + + vi.mocked(StorageAdapter.removeItem).mockImplementation((key: string) => { + if (key === 'walletconnect') { + mockWalletConnectData = null + } }) mockLogger = { @@ -154,6 +169,29 @@ describe('DeflyWallet', () => { expect(mockDeflyWallet.connector.on).toHaveBeenCalledWith('disconnect', expect.any(Function)) }) + + it('should backup WalletConnect session when connecting and another wallet is active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + + // Set Pera as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.connect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.PERA) + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).toHaveBeenCalledWith( + `walletconnect-${WalletId.PERA}`, + mockWalletConnectData + ) + }) }) describe('disconnect', () => { @@ -168,6 +206,87 @@ describe('DeflyWallet', () => { expect(wallet.isConnected).toBe(false) }) + it('should backup and restore active wallet session when disconnecting non-active wallet', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Pera as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.PERA) + expect(mockDeflyWallet.disconnect).toHaveBeenCalled() + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.PERA) + expect(store.state.wallets[WalletId.DEFLY]).toBeUndefined() + }) + + it('should not backup or restore session when disconnecting active wallet', async () => { + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Defly as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).not.toHaveBeenCalled() + expect(mockDeflyWallet.disconnect).toHaveBeenCalled() + expect(store.state.wallets[WalletId.DEFLY]).toBeUndefined() + }) + + it('should backup active wallet, restore inactive wallet, disconnect, and restore active wallet when disconnecting an inactive wallet', async () => { + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Pera as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.PERA) + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.DEFLY) + expect(mockDeflyWallet.disconnect).toHaveBeenCalled() + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.PERA) + expect(store.state.wallets[WalletId.DEFLY]).toBeUndefined() + }) + + it('should not remove backup when disconnecting the active wallet', async () => { + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Defly as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).not.toHaveBeenCalled() + expect(mockDeflyWallet.disconnect).toHaveBeenCalled() + expect(store.state.wallets[WalletId.DEFLY]).toBeUndefined() + }) + it('should throw an error if client.disconnect fails', async () => { mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) mockDeflyWallet.disconnect.mockRejectedValueOnce(new Error('Disconnect error')) @@ -176,9 +295,8 @@ describe('DeflyWallet', () => { await expect(wallet.disconnect()).rejects.toThrow('Disconnect error') - // Should still update store/state - expect(store.state.wallets[WalletId.DEFLY]).toBeUndefined() - expect(wallet.isConnected).toBe(false) + expect(store.state.wallets[WalletId.DEFLY]).toBeDefined() + expect(wallet.isConnected).toBe(true) }) }) @@ -333,6 +451,94 @@ describe('DeflyWallet', () => { }) }) + describe('setActive', () => { + it('should set the wallet as active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockDeflyWallet.connect.mockResolvedValueOnce([account1.address]) + + await wallet.connect() + wallet.setActive() + + expect(store.state.activeWallet).toBe(WalletId.DEFLY) + expect(StorageAdapter.setItem).toHaveBeenCalledWith('walletconnect', mockWalletConnectData) + expect(StorageAdapter.removeItem).toHaveBeenCalledWith(`walletconnect-${WalletId.DEFLY}`) + }) + + it('should backup current active wallet session and restore Pera session when setting active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA, + wallets: { + ...state.wallets, + [WalletId.DEFLY]: { accounts: [account1], activeAccount: account1 }, + [WalletId.PERA]: { accounts: [account2], activeAccount: account2 } + } + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + wallet.setActive() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.PERA) + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore') + expect(store.state.activeWallet).toBe(WalletId.DEFLY) + }) + }) + + describe('manageWalletConnectSession', () => { + it('should backup WalletConnect session', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + // @ts-expect-error - Accessing protected method for testing + wallet.manageWalletConnectSession('backup', WalletId.PERA) + + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).toHaveBeenCalledWith( + `walletconnect-${WalletId.PERA}`, + mockWalletConnectData + ) + }) + + it('should not backup WalletConnect session if no data exists', async () => { + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(null) + + // @ts-expect-error - Accessing protected method for testing + wallet.manageWalletConnectSession('backup', WalletId.PERA) + + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).not.toHaveBeenCalled() + expect(StorageAdapter.removeItem).not.toHaveBeenCalled() + }) + + it('should restore WalletConnect session', () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + // @ts-expect-error - Accessing protected method for testing + wallet.manageWalletConnectSession('restore') + + expect(StorageAdapter.getItem).toHaveBeenCalledWith(`walletconnect-${WalletId.DEFLY}`) + expect(StorageAdapter.setItem).toHaveBeenCalledWith('walletconnect', mockWalletConnectData) + expect(StorageAdapter.removeItem).toHaveBeenCalledWith(`walletconnect-${WalletId.DEFLY}`) + }) + + it('should not restore WalletConnect session if no backup exists', () => { + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(null) + + // @ts-expect-error - Accessing protected method for testing + wallet.manageWalletConnectSession('restore') + + expect(StorageAdapter.getItem).toHaveBeenCalledWith(`walletconnect-${WalletId.DEFLY}`) + expect(StorageAdapter.setItem).not.toHaveBeenCalled() + expect(StorageAdapter.removeItem).not.toHaveBeenCalled() + }) + }) + describe('signing transactions', () => { // Connected accounts const connectedAcct1 = '7ZUECA7HFLZTXENRV24SHLU4AVPUTMTTDUFUBNBD64C73F3UHRTHAIOF6Q' diff --git a/packages/use-wallet/src/__tests__/wallets/exodus.test.ts b/packages/use-wallet/src/__tests__/wallets/exodus.test.ts index 5b2e60e9..9cd134b6 100644 --- a/packages/use-wallet/src/__tests__/wallets/exodus.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/exodus.test.ts @@ -24,7 +24,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/kibisis.test.ts b/packages/use-wallet/src/__tests__/wallets/kibisis.test.ts index e1fa8584..77785ff2 100644 --- a/packages/use-wallet/src/__tests__/wallets/kibisis.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/kibisis.test.ts @@ -33,7 +33,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/kmd.test.ts b/packages/use-wallet/src/__tests__/wallets/kmd.test.ts index 34dcac66..4550d27e 100644 --- a/packages/use-wallet/src/__tests__/wallets/kmd.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/kmd.test.ts @@ -23,7 +23,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/lute.test.ts b/packages/use-wallet/src/__tests__/wallets/lute.test.ts index 20bed857..16d1d5fe 100644 --- a/packages/use-wallet/src/__tests__/wallets/lute.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/lute.test.ts @@ -24,7 +24,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/magic.test.ts b/packages/use-wallet/src/__tests__/wallets/magic.test.ts index 9f6fdcdf..f3c17042 100644 --- a/packages/use-wallet/src/__tests__/wallets/magic.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/magic.test.ts @@ -24,7 +24,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/mnemonic.test.ts b/packages/use-wallet/src/__tests__/wallets/mnemonic.test.ts index 843fc308..3873764d 100644 --- a/packages/use-wallet/src/__tests__/wallets/mnemonic.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/mnemonic.test.ts @@ -29,7 +29,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/pera.test.ts b/packages/use-wallet/src/__tests__/wallets/pera.test.ts index 90a416fb..d5b290aa 100644 --- a/packages/use-wallet/src/__tests__/wallets/pera.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/pera.test.ts @@ -23,7 +23,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) @@ -87,10 +88,15 @@ describe('PeraWallet', () => { beforeEach(() => { vi.clearAllMocks() + let mockWalletConnectData: string | null = null + vi.mocked(StorageAdapter.getItem).mockImplementation((key: string) => { if (key === LOCAL_STORAGE_KEY && mockInitialState !== null) { return JSON.stringify(mockInitialState) } + if (key === 'walletconnect') { + return mockWalletConnectData + } return null }) @@ -98,6 +104,15 @@ describe('PeraWallet', () => { if (key === LOCAL_STORAGE_KEY) { mockInitialState = JSON.parse(value) } + if (key.startsWith('walletconnect-')) { + mockWalletConnectData = value + } + }) + + vi.mocked(StorageAdapter.removeItem).mockImplementation((key: string) => { + if (key === 'walletconnect') { + mockWalletConnectData = null + } }) mockLogger = { @@ -154,6 +169,29 @@ describe('PeraWallet', () => { expect(mockPeraWallet.connector.on).toHaveBeenCalledWith('disconnect', expect.any(Function)) }) + + it('should backup WalletConnect session when connecting and another wallet is active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + + // Set Defly as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.connect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.DEFLY) + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).toHaveBeenCalledWith( + `walletconnect-${WalletId.DEFLY}`, + mockWalletConnectData + ) + }) }) describe('disconnect', () => { @@ -168,6 +206,87 @@ describe('PeraWallet', () => { expect(wallet.isConnected).toBe(false) }) + it('should backup and restore active wallet session when disconnecting non-active wallet', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Defly as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.DEFLY) + expect(mockPeraWallet.disconnect).toHaveBeenCalled() + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.DEFLY) + expect(store.state.wallets[WalletId.PERA]).toBeUndefined() + }) + + it('should not backup or restore session when disconnecting active wallet', async () => { + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Pera as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).not.toHaveBeenCalled() + expect(mockPeraWallet.disconnect).toHaveBeenCalled() + expect(store.state.wallets[WalletId.PERA]).toBeUndefined() + }) + + it('should backup active wallet, restore inactive wallet, disconnect, and restore active wallet when disconnecting an inactive wallet', async () => { + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Defly as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.DEFLY) + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.PERA) + expect(mockPeraWallet.disconnect).toHaveBeenCalled() + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore', WalletId.DEFLY) + expect(store.state.wallets[WalletId.PERA]).toBeUndefined() + }) + + it('should not remove backup when disconnecting the active wallet', async () => { + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + await wallet.connect() + + // Set Pera as the active wallet + store.setState((state) => ({ + ...state, + activeWallet: WalletId.PERA + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + await wallet.disconnect() + + expect(manageWalletConnectSessionSpy).not.toHaveBeenCalled() + expect(mockPeraWallet.disconnect).toHaveBeenCalled() + expect(store.state.wallets[WalletId.PERA]).toBeUndefined() + }) + it('should throw an error if client.disconnect fails', async () => { mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) mockPeraWallet.disconnect.mockRejectedValueOnce(new Error('Disconnect error')) @@ -176,9 +295,8 @@ describe('PeraWallet', () => { await expect(wallet.disconnect()).rejects.toThrow('Disconnect error') - // Should still update store/state - expect(store.state.wallets[WalletId.PERA]).toBeUndefined() - expect(wallet.isConnected).toBe(false) + expect(store.state.wallets[WalletId.PERA]).toBeDefined() + expect(wallet.isConnected).toBe(true) }) }) @@ -333,6 +451,94 @@ describe('PeraWallet', () => { }) }) + describe('setActive', () => { + it('should set the wallet as active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + mockPeraWallet.connect.mockResolvedValueOnce([account1.address]) + + await wallet.connect() + wallet.setActive() + + expect(store.state.activeWallet).toBe(WalletId.PERA) + expect(StorageAdapter.setItem).toHaveBeenCalledWith('walletconnect', mockWalletConnectData) + expect(StorageAdapter.removeItem).toHaveBeenCalledWith(`walletconnect-${WalletId.PERA}`) + }) + + it('should backup current active wallet session and restore Pera session when setting active', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + store.setState((state) => ({ + ...state, + activeWallet: WalletId.DEFLY, + wallets: { + ...state.wallets, + [WalletId.DEFLY]: { accounts: [account1], activeAccount: account1 }, + [WalletId.PERA]: { accounts: [account2], activeAccount: account2 } + } + })) + + const manageWalletConnectSessionSpy = vi.spyOn(wallet, 'manageWalletConnectSession' as any) + + wallet.setActive() + + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('backup', WalletId.DEFLY) + expect(manageWalletConnectSessionSpy).toHaveBeenCalledWith('restore') + expect(store.state.activeWallet).toBe(WalletId.PERA) + }) + }) + + describe('manageWalletConnectSession', () => { + it('should backup WalletConnect session', async () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + // @ts-expect-error - Accessing private method for testing + wallet.manageWalletConnectSession('backup', WalletId.DEFLY) + + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).toHaveBeenCalledWith( + `walletconnect-${WalletId.DEFLY}`, + mockWalletConnectData + ) + }) + + it('should not backup WalletConnect session if no data exists', async () => { + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(null) + + // @ts-expect-error - Accessing private method for testing + wallet.manageWalletConnectSession('backup', WalletId.DEFLY) + + expect(StorageAdapter.getItem).toHaveBeenCalledWith('walletconnect') + expect(StorageAdapter.setItem).not.toHaveBeenCalled() + expect(StorageAdapter.removeItem).not.toHaveBeenCalled() + }) + + it('should restore WalletConnect session', () => { + const mockWalletConnectData = 'mockWalletConnectData' + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(mockWalletConnectData) + + // @ts-expect-error - Accessing private method for testing + wallet.manageWalletConnectSession('restore') + + expect(StorageAdapter.getItem).toHaveBeenCalledWith(`walletconnect-${WalletId.PERA}`) + expect(StorageAdapter.setItem).toHaveBeenCalledWith('walletconnect', mockWalletConnectData) + expect(StorageAdapter.removeItem).toHaveBeenCalledWith(`walletconnect-${WalletId.PERA}`) + }) + + it('should not restore WalletConnect session if no backup exists', () => { + vi.mocked(StorageAdapter.getItem).mockReturnValueOnce(null) + + // @ts-expect-error - Accessing private method for testing + wallet.manageWalletConnectSession('restore') + + expect(StorageAdapter.getItem).toHaveBeenCalledWith(`walletconnect-${WalletId.PERA}`) + expect(StorageAdapter.setItem).not.toHaveBeenCalled() + expect(StorageAdapter.removeItem).not.toHaveBeenCalled() + }) + }) + describe('signing transactions', () => { // Connected accounts const connectedAcct1 = '7ZUECA7HFLZTXENRV24SHLU4AVPUTMTTDUFUBNBD64C73F3UHRTHAIOF6Q' diff --git a/packages/use-wallet/src/__tests__/wallets/pera2.test.ts b/packages/use-wallet/src/__tests__/wallets/pera2.test.ts index 3158c581..171760cf 100644 --- a/packages/use-wallet/src/__tests__/wallets/pera2.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/pera2.test.ts @@ -23,7 +23,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/__tests__/wallets/walletconnect.test.ts b/packages/use-wallet/src/__tests__/wallets/walletconnect.test.ts index 0ee238a2..b3c5ba0c 100644 --- a/packages/use-wallet/src/__tests__/wallets/walletconnect.test.ts +++ b/packages/use-wallet/src/__tests__/wallets/walletconnect.test.ts @@ -27,7 +27,8 @@ vi.mock('src/logger', () => ({ vi.mock('src/storage', () => ({ StorageAdapter: { getItem: vi.fn(), - setItem: vi.fn() + setItem: vi.fn(), + removeItem: vi.fn() } })) diff --git a/packages/use-wallet/src/storage.ts b/packages/use-wallet/src/storage.ts index a5bd5b21..842005f4 100644 --- a/packages/use-wallet/src/storage.ts +++ b/packages/use-wallet/src/storage.ts @@ -12,4 +12,11 @@ export class StorageAdapter { } localStorage.setItem(key, value) } + + static removeItem(key: string): void { + if (typeof window === 'undefined') { + return + } + localStorage.removeItem(key) + } } diff --git a/packages/use-wallet/src/wallets/base.ts b/packages/use-wallet/src/wallets/base.ts index d3e2d50c..aabdddd3 100644 --- a/packages/use-wallet/src/wallets/base.ts +++ b/packages/use-wallet/src/wallets/base.ts @@ -1,4 +1,5 @@ import { logger } from 'src/logger' +import { StorageAdapter } from 'src/storage' import { setActiveWallet, setActiveAccount, removeWallet, type State } from 'src/store' import type { Store } from '@tanstack/store' import type algosdk from 'algosdk' @@ -130,4 +131,26 @@ export abstract class BaseWallet { this.logger.debug(`Removing wallet from store...`) removeWallet(this.store, { walletId: this.id }) } + + protected manageWalletConnectSession = ( + action: 'backup' | 'restore', + targetWalletId?: WalletId + ): void => { + const walletId = targetWalletId || this.id + if (action === 'backup') { + const data = StorageAdapter.getItem('walletconnect') + if (data) { + StorageAdapter.setItem(`walletconnect-${walletId}`, data) + StorageAdapter.removeItem('walletconnect') + this.logger.debug(`Backed up WalletConnect session for ${walletId}`) + } + } else if (action === 'restore') { + const data = StorageAdapter.getItem(`walletconnect-${walletId}`) + if (data) { + StorageAdapter.setItem('walletconnect', data) + StorageAdapter.removeItem(`walletconnect-${walletId}`) + this.logger.debug(`Restored WalletConnect session for ${walletId}`) + } + } + } } diff --git a/packages/use-wallet/src/wallets/defly.ts b/packages/use-wallet/src/wallets/defly.ts index fe13ed13..420b8d54 100644 --- a/packages/use-wallet/src/wallets/defly.ts +++ b/packages/use-wallet/src/wallets/defly.ts @@ -1,5 +1,5 @@ import algosdk from 'algosdk' -import { WalletState, addWallet, setAccounts, type State } from 'src/store' +import { WalletState, addWallet, setAccounts, setActiveWallet, type State } from 'src/store' import { compareAccounts, flattenTxnGroup, isSignedTxn, isTransactionArray } from 'src/utils' import { BaseWallet } from 'src/wallets/base' import type { DeflyWalletConnect } from '@blockshake/defly-connect' @@ -64,6 +64,10 @@ export class DeflyWallet extends BaseWallet { public connect = async (): Promise => { this.logger.info('Connecting...') + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + } const client = this.client || (await this.initializeClient()) const accounts = await client.connect() @@ -98,10 +102,32 @@ export class DeflyWallet extends BaseWallet { public disconnect = async (): Promise => { this.logger.info('Disconnecting...') - this.onDisconnect() const client = this.client || (await this.initializeClient()) - await client.disconnect() - this.logger.info('Disconnected.') + + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + this.manageWalletConnectSession('restore', this.id) + await client.disconnect() + // Wait for the disconnect to complete (race condition) + await new Promise((resolve) => setTimeout(resolve, 500)) + this.manageWalletConnectSession('restore', currentActiveWallet) + } else { + await client.disconnect() + } + + this.onDisconnect() + this.logger.info('Disconnected') + } + + public setActive = (): void => { + this.logger.info(`Set active wallet: ${this.id}`) + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + } + this.manageWalletConnectSession('restore') + setActiveWallet(this.store, { walletId: this.id }) } public resumeSession = async (): Promise => { diff --git a/packages/use-wallet/src/wallets/pera.ts b/packages/use-wallet/src/wallets/pera.ts index a97cc392..72e9a27e 100644 --- a/packages/use-wallet/src/wallets/pera.ts +++ b/packages/use-wallet/src/wallets/pera.ts @@ -1,5 +1,5 @@ import algosdk from 'algosdk' -import { WalletState, addWallet, setAccounts, type State } from 'src/store' +import { WalletState, addWallet, setAccounts, setActiveWallet, type State } from 'src/store' import { compareAccounts, flattenTxnGroup, isSignedTxn, isTransactionArray } from 'src/utils' import { BaseWallet } from 'src/wallets/base' import type { PeraWalletConnect } from '@perawallet/connect' @@ -69,6 +69,10 @@ export class PeraWallet extends BaseWallet { public connect = async (): Promise => { this.logger.info('Connecting...') + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + } const client = this.client || (await this.initializeClient()) const accounts = await client.connect() @@ -103,12 +107,34 @@ export class PeraWallet extends BaseWallet { public disconnect = async (): Promise => { this.logger.info('Disconnecting...') - this.onDisconnect() const client = this.client || (await this.initializeClient()) - await client.disconnect() + + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + this.manageWalletConnectSession('restore', this.id) + await client.disconnect() + // Wait for the disconnect to complete (race condition) + await new Promise((resolve) => setTimeout(resolve, 500)) + this.manageWalletConnectSession('restore', currentActiveWallet) + } else { + await client.disconnect() + } + + this.onDisconnect() this.logger.info('Disconnected') } + public setActive = (): void => { + this.logger.info(`Set active wallet: ${this.id}`) + const currentActiveWallet = this.store.state.activeWallet + if (currentActiveWallet && currentActiveWallet !== this.id) { + this.manageWalletConnectSession('backup', currentActiveWallet) + } + this.manageWalletConnectSession('restore') + setActiveWallet(this.store, { walletId: this.id }) + } + public resumeSession = async (): Promise => { try { const state = this.store.state