Skip to content

Commit 30b1641

Browse files
committed
arkd-wallet: unit tests outpointLocker
1 parent 755d051 commit 30b1641

File tree

3 files changed

+148
-4
lines changed

3 files changed

+148
-4
lines changed

pkg/arkd-wallet/core/application/wallet/key_manager.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func newKeyManager(seed []byte, network *chaincfg.Params) (*keyManager, error) {
6767
return nil, err
6868
}
6969

70-
forfeitPrvkey, err := deriveForfeitPrvkey(mainAccount, network)
70+
forfeitPrvkey, err := deriveForfeitPrvkey(mainAccount)
7171
if err != nil {
7272
return nil, err
7373
}
@@ -114,7 +114,7 @@ func computeTaprootDerivationScheme(accountKey *hdkeychain.ExtendedKey) (string,
114114
return neutered.String() + "-[taproot]", nil
115115
}
116116

117-
func deriveForfeitPrvkey(xpub *hdkeychain.ExtendedKey, network *chaincfg.Params) (*btcec.PrivateKey, error) {
117+
func deriveForfeitPrvkey(xpub *hdkeychain.ExtendedKey) (*btcec.PrivateKey, error) {
118118
key, err := xpub.Derive(0)
119119
if err != nil {
120120
return nil, err

pkg/arkd-wallet/core/application/wallet/outpoint_locker.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ func newOutpointLocker(lockFor time.Duration) *outpointLocker {
2222
}
2323
}
2424

25-
func (l *outpointLocker) lock(ctx context.Context, outpoints ...wire.OutPoint) error {
25+
func (l *outpointLocker) lock(_ context.Context, outpoints ...wire.OutPoint) error {
26+
if len(outpoints) == 0 {
27+
return nil
28+
}
29+
2630
l.locker.Lock()
2731
defer l.locker.Unlock()
2832

@@ -36,7 +40,7 @@ func (l *outpointLocker) lock(ctx context.Context, outpoints ...wire.OutPoint) e
3640
return nil
3741
}
3842

39-
func (l *outpointLocker) get(ctx context.Context) (map[wire.OutPoint]struct{}, error) {
43+
func (l *outpointLocker) get(_ context.Context) (map[wire.OutPoint]struct{}, error) {
4044
l.locker.Lock()
4145
defer l.locker.Unlock()
4246

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package wallet
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/btcsuite/btcd/wire"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestNewOutpointLocker(t *testing.T) {
15+
lockDuration := 5 * time.Minute
16+
locker := newOutpointLocker(lockDuration)
17+
18+
require.NotNil(t, locker)
19+
require.Equal(t, lockDuration, locker.lockExpiry)
20+
require.NotNil(t, locker.lockedOutpoints)
21+
require.Empty(t, locker.lockedOutpoints)
22+
}
23+
24+
func TestOutpointLocker_Lock(t *testing.T) {
25+
lockDuration := 1 * time.Hour
26+
locker := newOutpointLocker(lockDuration)
27+
28+
hash0 := random32Bytes()
29+
hash1 := random32Bytes()
30+
outpoint1 := wire.OutPoint{Hash: hash0, Index: 0}
31+
outpoint2 := wire.OutPoint{Hash: hash1, Index: 1}
32+
33+
// test locking single outpoint
34+
err := locker.lock(context.Background(), outpoint1)
35+
require.NoError(t, err)
36+
37+
// verify outpoint is locked
38+
lockedOutpoints, err := locker.get(context.Background())
39+
require.NoError(t, err)
40+
require.Len(t, lockedOutpoints, 1)
41+
require.Contains(t, lockedOutpoints, outpoint1)
42+
43+
// test locking multiple outpoints
44+
err = locker.lock(context.Background(), outpoint2)
45+
require.NoError(t, err)
46+
47+
// verify both outpoints are locked
48+
lockedOutpoints, err = locker.get(context.Background())
49+
require.NoError(t, err)
50+
require.Len(t, lockedOutpoints, 2)
51+
require.Contains(t, lockedOutpoints, outpoint1)
52+
require.Contains(t, lockedOutpoints, outpoint2)
53+
54+
// test locking same outpoint again (should update expiry)
55+
time.Sleep(10 * time.Millisecond) // Small delay to ensure different timestamps
56+
err = locker.lock(context.Background(), outpoint1)
57+
require.NoError(t, err)
58+
59+
// verify outpoint is still locked with updated expiry
60+
lockedOutpoints, err = locker.get(context.Background())
61+
require.NoError(t, err)
62+
require.Len(t, lockedOutpoints, 2)
63+
require.Contains(t, lockedOutpoints, outpoint1)
64+
require.Contains(t, lockedOutpoints, outpoint2)
65+
}
66+
67+
func TestOutpointLocker_Get(t *testing.T) {
68+
lockDuration := 100 * time.Millisecond
69+
locker := newOutpointLocker(lockDuration)
70+
71+
hash0 := random32Bytes()
72+
hash1 := random32Bytes()
73+
outpoint1 := wire.OutPoint{Hash: hash0, Index: 0}
74+
outpoint2 := wire.OutPoint{Hash: hash1, Index: 1}
75+
76+
// lock outpoints
77+
err := locker.lock(context.Background(), outpoint1, outpoint2)
78+
require.NoError(t, err)
79+
80+
lockedOutpoints, err := locker.get(context.Background())
81+
require.NoError(t, err)
82+
require.Len(t, lockedOutpoints, 2)
83+
require.Contains(t, lockedOutpoints, outpoint1)
84+
require.Contains(t, lockedOutpoints, outpoint2)
85+
86+
// wait for locks to expire
87+
time.Sleep(lockDuration + 50*time.Millisecond)
88+
89+
lockedOutpoints, err = locker.get(context.Background())
90+
require.NoError(t, err)
91+
require.Empty(t, lockedOutpoints)
92+
}
93+
94+
func TestOutpointLocker_ConcurrentGetAndLock(t *testing.T) {
95+
// half lock, half get
96+
numberOfRoutines := 100
97+
lockDuration := 100 * time.Millisecond
98+
locker := newOutpointLocker(lockDuration)
99+
100+
outpoints := make([]wire.OutPoint, 0, 10)
101+
for index := range numberOfRoutines / 2 {
102+
outpoints = append(outpoints, wire.OutPoint{Hash: random32Bytes(), Index: uint32(index)})
103+
}
104+
105+
wg := sync.WaitGroup{}
106+
wg.Add(numberOfRoutines)
107+
108+
// start 10 goroutines that lock the outpoint
109+
for _, outpoint := range outpoints {
110+
go func() {
111+
err := locker.lock(context.Background(), outpoint)
112+
require.NoError(t, err)
113+
wg.Done()
114+
}()
115+
}
116+
117+
// start 10 goroutines that get locked outpoints
118+
for range numberOfRoutines / 2 {
119+
go func() {
120+
_, err := locker.get(context.Background())
121+
require.NoError(t, err)
122+
wg.Done()
123+
}()
124+
}
125+
126+
wg.Wait()
127+
128+
lockedOutpoints, err := locker.get(context.Background())
129+
require.NoError(t, err)
130+
require.Len(t, lockedOutpoints, len(outpoints))
131+
for _, outpoint := range outpoints {
132+
require.Contains(t, lockedOutpoints, outpoint)
133+
}
134+
}
135+
136+
func random32Bytes() [32]byte {
137+
var b [32]byte
138+
rand.Read(b[:])
139+
return b
140+
}

0 commit comments

Comments
 (0)