Skip to content

Commit b3a07d8

Browse files
authored
Use require.IsType for type assertions in tests (ava-labs#1458)
1 parent eb8b52a commit b3a07d8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+379
-438
lines changed

api/auth/auth_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ func TestNewTokenHappyPath(t *testing.T) {
6969
})
7070
require.NoError(t, err, "couldn't parse new token")
7171

72-
claims, ok := token.Claims.(*endpointClaims)
73-
require.True(t, ok, "expected auth token's claims to be type endpointClaims but is different type")
72+
require.IsType(t, &endpointClaims{}, token.Claims)
73+
claims := token.Claims.(*endpointClaims)
7474
require.ElementsMatch(t, endpoints, claims.Endpoints, "token has wrong endpoint claims")
7575

7676
shouldExpireAt := jwt.NewNumericDate(now.Add(defaultTokenLifespan))

database/manager/manager_test.go

+6-12
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,9 @@ func TestMeterDBManager(t *testing.T) {
306306
dbs := manager.GetDatabases()
307307
require.Len(dbs, 3)
308308

309-
_, ok := dbs[0].Database.(*meterdb.Database)
310-
require.True(ok)
311-
_, ok = dbs[1].Database.(*meterdb.Database)
312-
require.False(ok)
313-
_, ok = dbs[2].Database.(*meterdb.Database)
314-
require.False(ok)
309+
require.IsType(&meterdb.Database{}, dbs[0].Database)
310+
require.IsType(&memdb.Database{}, dbs[1].Database)
311+
require.IsType(&memdb.Database{}, dbs[2].Database)
315312

316313
// Confirm that the error from a name conflict is handled correctly
317314
_, err = m.NewMeterDBManager("", registry)
@@ -355,12 +352,9 @@ func TestCompleteMeterDBManager(t *testing.T) {
355352
dbs := manager.GetDatabases()
356353
require.Len(dbs, 3)
357354

358-
_, ok := dbs[0].Database.(*meterdb.Database)
359-
require.True(ok)
360-
_, ok = dbs[1].Database.(*meterdb.Database)
361-
require.True(ok)
362-
_, ok = dbs[2].Database.(*meterdb.Database)
363-
require.True(ok)
355+
require.IsType(&meterdb.Database{}, dbs[0].Database)
356+
require.IsType(&meterdb.Database{}, dbs[1].Database)
357+
require.IsType(&meterdb.Database{}, dbs[2].Database)
364358

365359
// Confirm that the error from a name conflict is handled correctly
366360
_, err = m.NewCompleteMeterDBManager("", registry)

indexer/indexer_test.go

+19-20
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ func TestNewIndexer(t *testing.T) {
6767

6868
idxrIntf, err := NewIndexer(config)
6969
require.NoError(err)
70-
idxr, ok := idxrIntf.(*indexer)
71-
require.True(ok)
70+
require.IsType(&indexer{}, idxrIntf)
71+
idxr := idxrIntf.(*indexer)
7272
require.NotNil(idxr.codec)
7373
require.NotNil(idxr.log)
7474
require.NotNil(idxr.db)
@@ -118,8 +118,8 @@ func TestMarkHasRunAndShutdown(t *testing.T) {
118118
config.DB = versiondb.New(baseDB)
119119
idxrIntf, err = NewIndexer(config)
120120
require.NoError(err)
121-
idxr, ok := idxrIntf.(*indexer)
122-
require.True(ok)
121+
require.IsType(&indexer{}, idxrIntf)
122+
idxr := idxrIntf.(*indexer)
123123
require.True(idxr.hasRunBefore)
124124
require.NoError(idxr.Close())
125125
shutdown.Wait()
@@ -150,8 +150,8 @@ func TestIndexer(t *testing.T) {
150150
// Create indexer
151151
idxrIntf, err := NewIndexer(config)
152152
require.NoError(err)
153-
idxr, ok := idxrIntf.(*indexer)
154-
require.True(ok)
153+
require.IsType(&indexer{}, idxrIntf)
154+
idxr := idxrIntf.(*indexer)
155155
now := time.Now()
156156
idxr.clock.Set(now)
157157

@@ -232,10 +232,10 @@ func TestIndexer(t *testing.T) {
232232
config.DB = versiondb.New(baseDB)
233233
idxrIntf, err = NewIndexer(config)
234234
require.NoError(err)
235-
idxr, ok = idxrIntf.(*indexer)
235+
require.IsType(&indexer{}, idxrIntf)
236+
idxr = idxrIntf.(*indexer)
236237
now = time.Now()
237238
idxr.clock.Set(now)
238-
require.True(ok)
239239
require.Len(idxr.blockIndices, 0)
240240
require.Len(idxr.txIndices, 0)
241241
require.Len(idxr.vtxIndices, 0)
@@ -389,8 +389,8 @@ func TestIndexer(t *testing.T) {
389389
config.DB = versiondb.New(baseDB)
390390
idxrIntf, err = NewIndexer(config)
391391
require.NoError(err)
392-
idxr, ok = idxrIntf.(*indexer)
393-
require.True(ok)
392+
require.IsType(&indexer{}, idxrIntf)
393+
idxr = idxrIntf.(*indexer)
394394
idxr.RegisterChain("chain1", chain1Ctx, chainVM)
395395
idxr.RegisterChain("chain2", chain2Ctx, dagVM)
396396

@@ -427,8 +427,8 @@ func TestIncompleteIndex(t *testing.T) {
427427
}
428428
idxrIntf, err := NewIndexer(config)
429429
require.NoError(err)
430-
idxr, ok := idxrIntf.(*indexer)
431-
require.True(ok)
430+
require.IsType(&indexer{}, idxrIntf)
431+
idxr := idxrIntf.(*indexer)
432432
require.False(idxr.indexingEnabled)
433433

434434
// Register a chain
@@ -454,8 +454,8 @@ func TestIncompleteIndex(t *testing.T) {
454454
config.DB = versiondb.New(baseDB)
455455
idxrIntf, err = NewIndexer(config)
456456
require.NoError(err)
457-
idxr, ok = idxrIntf.(*indexer)
458-
require.True(ok)
457+
require.IsType(&indexer{}, idxrIntf)
458+
idxr = idxrIntf.(*indexer)
459459
require.True(idxr.indexingEnabled)
460460

461461
// Register the chain again. Should die due to incomplete index.
@@ -470,8 +470,8 @@ func TestIncompleteIndex(t *testing.T) {
470470
config.DB = versiondb.New(baseDB)
471471
idxrIntf, err = NewIndexer(config)
472472
require.NoError(err)
473-
idxr, ok = idxrIntf.(*indexer)
474-
require.True(ok)
473+
require.IsType(&indexer{}, idxrIntf)
474+
idxr = idxrIntf.(*indexer)
475475
require.True(idxr.allowIncompleteIndex)
476476

477477
// Register the chain again. Should be OK
@@ -486,8 +486,7 @@ func TestIncompleteIndex(t *testing.T) {
486486
config.DB = versiondb.New(baseDB)
487487
idxrIntf, err = NewIndexer(config)
488488
require.NoError(err)
489-
_, ok = idxrIntf.(*indexer)
490-
require.True(ok)
489+
require.IsType(&indexer{}, idxrIntf)
491490
}
492491

493492
// Ensure we only index chains in the primary network
@@ -513,8 +512,8 @@ func TestIgnoreNonDefaultChains(t *testing.T) {
513512
// Create indexer
514513
idxrIntf, err := NewIndexer(config)
515514
require.NoError(err)
516-
idxr, ok := idxrIntf.(*indexer)
517-
require.True(ok)
515+
require.IsType(&indexer{}, idxrIntf)
516+
idxr := idxrIntf.(*indexer)
518517

519518
// Assert state is right
520519
chain1Ctx := snow.DefaultConsensusContextTest()

message/inbound_msg_builder_test.go

+26-26
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ func TestInboundMsgBuilder(t *testing.T) {
6565
require.Equal(nodeID, msg.NodeID())
6666
require.False(msg.Expiration().Before(start.Add(deadline)))
6767
require.False(end.Add(deadline).Before(msg.Expiration()))
68-
innerMsg, ok := msg.Message().(*p2p.GetStateSummaryFrontier)
69-
require.True(ok)
68+
require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message())
69+
innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier)
7070
require.Equal(chainID[:], innerMsg.ChainId)
7171
require.Equal(requestID, innerMsg.RequestId)
7272
},
@@ -87,8 +87,8 @@ func TestInboundMsgBuilder(t *testing.T) {
8787
require.Equal(StateSummaryFrontierOp, msg.Op())
8888
require.Equal(nodeID, msg.NodeID())
8989
require.Equal(mockable.MaxTime, msg.Expiration())
90-
innerMsg, ok := msg.Message().(*p2p.StateSummaryFrontier)
91-
require.True(ok)
90+
require.IsType(&p2p.StateSummaryFrontier{}, msg.Message())
91+
innerMsg := msg.Message().(*p2p.StateSummaryFrontier)
9292
require.Equal(chainID[:], innerMsg.ChainId)
9393
require.Equal(requestID, innerMsg.RequestId)
9494
require.Equal(summary, innerMsg.Summary)
@@ -114,8 +114,8 @@ func TestInboundMsgBuilder(t *testing.T) {
114114
require.Equal(nodeID, msg.NodeID())
115115
require.False(msg.Expiration().Before(start.Add(deadline)))
116116
require.False(end.Add(deadline).Before(msg.Expiration()))
117-
innerMsg, ok := msg.Message().(*p2p.GetAcceptedStateSummary)
118-
require.True(ok)
117+
require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message())
118+
innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary)
119119
require.Equal(chainID[:], innerMsg.ChainId)
120120
require.Equal(requestID, innerMsg.RequestId)
121121
require.Equal(heights, innerMsg.Heights)
@@ -137,8 +137,8 @@ func TestInboundMsgBuilder(t *testing.T) {
137137
require.Equal(AcceptedStateSummaryOp, msg.Op())
138138
require.Equal(nodeID, msg.NodeID())
139139
require.Equal(mockable.MaxTime, msg.Expiration())
140-
innerMsg, ok := msg.Message().(*p2p.AcceptedStateSummary)
141-
require.True(ok)
140+
require.IsType(&p2p.AcceptedStateSummary{}, msg.Message())
141+
innerMsg := msg.Message().(*p2p.AcceptedStateSummary)
142142
require.Equal(chainID[:], innerMsg.ChainId)
143143
require.Equal(requestID, innerMsg.RequestId)
144144
summaryIDsBytes := make([][]byte, len(summaryIDs))
@@ -169,8 +169,8 @@ func TestInboundMsgBuilder(t *testing.T) {
169169
require.Equal(nodeID, msg.NodeID())
170170
require.False(msg.Expiration().Before(start.Add(deadline)))
171171
require.False(end.Add(deadline).Before(msg.Expiration()))
172-
innerMsg, ok := msg.Message().(*p2p.GetAcceptedFrontier)
173-
require.True(ok)
172+
require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message())
173+
innerMsg := msg.Message().(*p2p.GetAcceptedFrontier)
174174
require.Equal(chainID[:], innerMsg.ChainId)
175175
require.Equal(requestID, innerMsg.RequestId)
176176
require.Equal(engineType, innerMsg.EngineType)
@@ -192,8 +192,8 @@ func TestInboundMsgBuilder(t *testing.T) {
192192
require.Equal(AcceptedFrontierOp, msg.Op())
193193
require.Equal(nodeID, msg.NodeID())
194194
require.Equal(mockable.MaxTime, msg.Expiration())
195-
innerMsg, ok := msg.Message().(*p2p.AcceptedFrontier)
196-
require.True(ok)
195+
require.IsType(&p2p.AcceptedFrontier{}, msg.Message())
196+
innerMsg := msg.Message().(*p2p.AcceptedFrontier)
197197
require.Equal(chainID[:], innerMsg.ChainId)
198198
require.Equal(requestID, innerMsg.RequestId)
199199
containerIDsBytes := make([][]byte, len(containerIDs))
@@ -225,8 +225,8 @@ func TestInboundMsgBuilder(t *testing.T) {
225225
require.Equal(nodeID, msg.NodeID())
226226
require.False(msg.Expiration().Before(start.Add(deadline)))
227227
require.False(end.Add(deadline).Before(msg.Expiration()))
228-
innerMsg, ok := msg.Message().(*p2p.GetAccepted)
229-
require.True(ok)
228+
require.IsType(&p2p.GetAccepted{}, msg.Message())
229+
innerMsg := msg.Message().(*p2p.GetAccepted)
230230
require.Equal(chainID[:], innerMsg.ChainId)
231231
require.Equal(requestID, innerMsg.RequestId)
232232
require.Equal(engineType, innerMsg.EngineType)
@@ -248,8 +248,8 @@ func TestInboundMsgBuilder(t *testing.T) {
248248
require.Equal(AcceptedOp, msg.Op())
249249
require.Equal(nodeID, msg.NodeID())
250250
require.Equal(mockable.MaxTime, msg.Expiration())
251-
innerMsg, ok := msg.Message().(*p2p.Accepted)
252-
require.True(ok)
251+
require.IsType(&p2p.Accepted{}, msg.Message())
252+
innerMsg := msg.Message().(*p2p.Accepted)
253253
require.Equal(chainID[:], innerMsg.ChainId)
254254
require.Equal(requestID, innerMsg.RequestId)
255255
containerIDsBytes := make([][]byte, len(containerIDs))
@@ -281,8 +281,8 @@ func TestInboundMsgBuilder(t *testing.T) {
281281
require.Equal(nodeID, msg.NodeID())
282282
require.False(msg.Expiration().Before(start.Add(deadline)))
283283
require.False(end.Add(deadline).Before(msg.Expiration()))
284-
innerMsg, ok := msg.Message().(*p2p.PushQuery)
285-
require.True(ok)
284+
require.IsType(&p2p.PushQuery{}, msg.Message())
285+
innerMsg := msg.Message().(*p2p.PushQuery)
286286
require.Equal(chainID[:], innerMsg.ChainId)
287287
require.Equal(requestID, innerMsg.RequestId)
288288
require.Equal(container, innerMsg.Container)
@@ -310,8 +310,8 @@ func TestInboundMsgBuilder(t *testing.T) {
310310
require.Equal(nodeID, msg.NodeID())
311311
require.False(msg.Expiration().Before(start.Add(deadline)))
312312
require.False(end.Add(deadline).Before(msg.Expiration()))
313-
innerMsg, ok := msg.Message().(*p2p.PullQuery)
314-
require.True(ok)
313+
require.IsType(&p2p.PullQuery{}, msg.Message())
314+
innerMsg := msg.Message().(*p2p.PullQuery)
315315
require.Equal(chainID[:], innerMsg.ChainId)
316316
require.Equal(requestID, innerMsg.RequestId)
317317
require.Equal(containerIDs[0][:], innerMsg.ContainerId)
@@ -335,8 +335,8 @@ func TestInboundMsgBuilder(t *testing.T) {
335335
require.Equal(ChitsOp, msg.Op())
336336
require.Equal(nodeID, msg.NodeID())
337337
require.Equal(mockable.MaxTime, msg.Expiration())
338-
innerMsg, ok := msg.Message().(*p2p.Chits)
339-
require.True(ok)
338+
require.IsType(&p2p.Chits{}, msg.Message())
339+
innerMsg := msg.Message().(*p2p.Chits)
340340
require.Equal(chainID[:], innerMsg.ChainId)
341341
require.Equal(requestID, innerMsg.RequestId)
342342
containerIDsBytes := make([][]byte, len(containerIDs))
@@ -373,8 +373,8 @@ func TestInboundMsgBuilder(t *testing.T) {
373373
require.Equal(nodeID, msg.NodeID())
374374
require.False(msg.Expiration().Before(start.Add(deadline)))
375375
require.False(end.Add(deadline).Before(msg.Expiration()))
376-
innerMsg, ok := msg.Message().(*p2p.AppRequest)
377-
require.True(ok)
376+
require.IsType(&p2p.AppRequest{}, msg.Message())
377+
innerMsg := msg.Message().(*p2p.AppRequest)
378378
require.Equal(chainID[:], innerMsg.ChainId)
379379
require.Equal(requestID, innerMsg.RequestId)
380380
require.Equal(appBytes, innerMsg.AppBytes)
@@ -396,8 +396,8 @@ func TestInboundMsgBuilder(t *testing.T) {
396396
require.Equal(AppResponseOp, msg.Op())
397397
require.Equal(nodeID, msg.NodeID())
398398
require.Equal(mockable.MaxTime, msg.Expiration())
399-
innerMsg, ok := msg.Message().(*p2p.AppResponse)
400-
require.True(ok)
399+
require.IsType(&p2p.AppResponse{}, msg.Message())
400+
innerMsg := msg.Message().(*p2p.AppResponse)
401401
require.Equal(chainID[:], innerMsg.ChainId)
402402
require.Equal(requestID, innerMsg.RequestId)
403403
require.Equal(appBytes, innerMsg.AppBytes)

message/messages_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ func TestNilInboundMessage(t *testing.T) {
872872
parsedMsg, err := mb.parseInbound(msgBytes, ids.EmptyNodeID, func() {})
873873
require.NoError(err)
874874

875-
pingMsg, ok := parsedMsg.message.(*p2p.Ping)
876-
require.True(ok)
875+
require.IsType(&p2p.Ping{}, parsedMsg.message)
876+
pingMsg := parsedMsg.message.(*p2p.Ping)
877877
require.NotNil(pingMsg)
878878
}

network/throttling/bandwidth_throttler_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ func TestBandwidthThrottler(t *testing.T) {
2525
}
2626
throttlerIntf, err := newBandwidthThrottler(logging.NoLog{}, "", prometheus.NewRegistry(), config)
2727
require.NoError(err)
28-
throttler, ok := throttlerIntf.(*bandwidthThrottlerImpl)
29-
require.True(ok)
28+
require.IsType(&bandwidthThrottlerImpl{}, throttlerIntf)
29+
throttler := throttlerIntf.(*bandwidthThrottlerImpl)
3030
require.NotNil(throttler.log)
3131
require.NotNil(throttler.limiters)
3232
require.Equal(config.RefillRate, throttler.RefillRate)

network/throttling/inbound_resource_throttler_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func TestNewSystemThrottler(t *testing.T) {
4040
targeter := tracker.NewMockTargeter(ctrl)
4141
throttlerIntf, err := NewSystemThrottler("", reg, config, cpuTracker, targeter)
4242
require.NoError(err)
43-
throttler, ok := throttlerIntf.(*systemThrottler)
44-
require.True(ok)
43+
require.IsType(&systemThrottler{}, throttlerIntf)
44+
throttler := throttlerIntf.(*systemThrottler)
4545
require.Equal(clock, config.Clock)
4646
require.Equal(time.Second, config.MaxRecheckDelay)
4747
require.Equal(cpuTracker, throttler.tracker)

snow/consensus/snowball/unary_snowball_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package snowball
55

66
import (
77
"testing"
8+
9+
"github.com/stretchr/testify/require"
810
)
911

1012
func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessfulPolls, expectedConfidence int, expectedFinalized bool) {
@@ -18,6 +20,8 @@ func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessf
1820
}
1921

2022
func TestUnarySnowball(t *testing.T) {
23+
require := require.New(t)
24+
2125
beta := 2
2226

2327
sb := &unarySnowball{}
@@ -33,10 +37,8 @@ func TestUnarySnowball(t *testing.T) {
3337
UnarySnowballStateTest(t, sb, 2, 1, false)
3438

3539
sbCloneIntf := sb.Clone()
36-
sbClone, ok := sbCloneIntf.(*unarySnowball)
37-
if !ok {
38-
t.Fatalf("Unexpected clone type")
39-
}
40+
require.IsType(&unarySnowball{}, sbCloneIntf)
41+
sbClone := sbCloneIntf.(*unarySnowball)
4042

4143
UnarySnowballStateTest(t, sbClone, 2, 1, false)
4244

snow/consensus/snowball/unary_snowflake_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package snowball
55

66
import (
77
"testing"
8+
9+
"github.com/stretchr/testify/require"
810
)
911

1012
func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidence int, expectedFinalized bool) {
@@ -16,6 +18,8 @@ func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidenc
1618
}
1719

1820
func TestUnarySnowflake(t *testing.T) {
21+
require := require.New(t)
22+
1923
beta := 2
2024

2125
sf := &unarySnowflake{}
@@ -31,10 +35,8 @@ func TestUnarySnowflake(t *testing.T) {
3135
UnarySnowflakeStateTest(t, sf, 1, false)
3236

3337
sfCloneIntf := sf.Clone()
34-
sfClone, ok := sfCloneIntf.(*unarySnowflake)
35-
if !ok {
36-
t.Fatalf("Unexpected clone type")
37-
}
38+
require.IsType(&unarySnowflake{}, sfCloneIntf)
39+
sfClone := sfCloneIntf.(*unarySnowflake)
3840

3941
UnarySnowflakeStateTest(t, sfClone, 1, false)
4042

0 commit comments

Comments
 (0)