Skip to content

use require.IsType for type assertions in tests #1458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 3, 2023
4 changes: 2 additions & 2 deletions api/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func TestNewTokenHappyPath(t *testing.T) {
})
require.NoError(t, err, "couldn't parse new token")

claims, ok := token.Claims.(*endpointClaims)
require.True(t, ok, "expected auth token's claims to be type endpointClaims but is different type")
require.IsType(t, &endpointClaims{}, token.Claims)
claims := token.Claims.(*endpointClaims)
require.ElementsMatch(t, endpoints, claims.Endpoints, "token has wrong endpoint claims")

shouldExpireAt := jwt.NewNumericDate(now.Add(defaultTokenLifespan))
Expand Down
18 changes: 6 additions & 12 deletions database/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,9 @@ func TestMeterDBManager(t *testing.T) {
dbs := manager.GetDatabases()
require.Len(dbs, 3)

_, ok := dbs[0].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[1].Database.(*meterdb.Database)
require.False(ok)
_, ok = dbs[2].Database.(*meterdb.Database)
require.False(ok)
require.IsType(&meterdb.Database{}, dbs[0].Database)
require.IsType(&memdb.Database{}, dbs[1].Database)
require.IsType(&memdb.Database{}, dbs[2].Database)

// Confirm that the error from a name conflict is handled correctly
_, err = m.NewMeterDBManager("", registry)
Expand Down Expand Up @@ -355,12 +352,9 @@ func TestCompleteMeterDBManager(t *testing.T) {
dbs := manager.GetDatabases()
require.Len(dbs, 3)

_, ok := dbs[0].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[1].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[2].Database.(*meterdb.Database)
require.True(ok)
require.IsType(&meterdb.Database{}, dbs[0].Database)
require.IsType(&meterdb.Database{}, dbs[1].Database)
require.IsType(&meterdb.Database{}, dbs[2].Database)

// Confirm that the error from a name conflict is handled correctly
_, err = m.NewCompleteMeterDBManager("", registry)
Expand Down
39 changes: 19 additions & 20 deletions indexer/indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func TestNewIndexer(t *testing.T) {

idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.NotNil(idxr.codec)
require.NotNil(idxr.log)
require.NotNil(idxr.db)
Expand Down Expand Up @@ -118,8 +118,8 @@ func TestMarkHasRunAndShutdown(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.True(idxr.hasRunBefore)
require.NoError(idxr.Close())
shutdown.Wait()
Expand Down Expand Up @@ -150,8 +150,8 @@ func TestIndexer(t *testing.T) {
// Create indexer
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
now := time.Now()
idxr.clock.Set(now)

Expand Down Expand Up @@ -232,10 +232,10 @@ func TestIndexer(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
now = time.Now()
idxr.clock.Set(now)
require.True(ok)
require.Len(idxr.blockIndices, 0)
require.Len(idxr.txIndices, 0)
require.Len(idxr.vtxIndices, 0)
Expand Down Expand Up @@ -389,8 +389,8 @@ func TestIndexer(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
idxr.RegisterChain("chain1", chain1Ctx, chainVM)
idxr.RegisterChain("chain2", chain2Ctx, dagVM)

Expand Down Expand Up @@ -427,8 +427,8 @@ func TestIncompleteIndex(t *testing.T) {
}
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.False(idxr.indexingEnabled)

// Register a chain
Expand All @@ -454,8 +454,8 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
require.True(idxr.indexingEnabled)

// Register the chain again. Should die due to incomplete index.
Expand All @@ -470,8 +470,8 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
require.True(idxr.allowIncompleteIndex)

// Register the chain again. Should be OK
Expand All @@ -486,8 +486,7 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
_, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
}

// Ensure we only index chains in the primary network
Expand All @@ -513,8 +512,8 @@ func TestIgnoreNonDefaultChains(t *testing.T) {
// Create indexer
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)

// Assert state is right
chain1Ctx := snow.DefaultConsensusContextTest()
Expand Down
52 changes: 26 additions & 26 deletions message/inbound_msg_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetStateSummaryFrontier)
require.True(ok)
require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
},
Expand All @@ -87,8 +87,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(StateSummaryFrontierOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.StateSummaryFrontier)
require.True(ok)
require.IsType(&p2p.StateSummaryFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.StateSummaryFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(summary, innerMsg.Summary)
Expand All @@ -114,8 +114,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAcceptedStateSummary)
require.True(ok)
require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(heights, innerMsg.Heights)
Expand All @@ -137,8 +137,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedStateSummaryOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AcceptedStateSummary)
require.True(ok)
require.IsType(&p2p.AcceptedStateSummary{}, msg.Message())
innerMsg := msg.Message().(*p2p.AcceptedStateSummary)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
summaryIDsBytes := make([][]byte, len(summaryIDs))
Expand Down Expand Up @@ -169,8 +169,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAcceptedFrontier)
require.True(ok)
require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAcceptedFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(engineType, innerMsg.EngineType)
Expand All @@ -192,8 +192,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedFrontierOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AcceptedFrontier)
require.True(ok)
require.IsType(&p2p.AcceptedFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.AcceptedFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -225,8 +225,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAccepted)
require.True(ok)
require.IsType(&p2p.GetAccepted{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAccepted)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(engineType, innerMsg.EngineType)
Expand All @@ -248,8 +248,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.Accepted)
require.True(ok)
require.IsType(&p2p.Accepted{}, msg.Message())
innerMsg := msg.Message().(*p2p.Accepted)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -281,8 +281,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.PushQuery)
require.True(ok)
require.IsType(&p2p.PushQuery{}, msg.Message())
innerMsg := msg.Message().(*p2p.PushQuery)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(container, innerMsg.Container)
Expand Down Expand Up @@ -310,8 +310,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.PullQuery)
require.True(ok)
require.IsType(&p2p.PullQuery{}, msg.Message())
innerMsg := msg.Message().(*p2p.PullQuery)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(containerIDs[0][:], innerMsg.ContainerId)
Expand All @@ -335,8 +335,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(ChitsOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.Chits)
require.True(ok)
require.IsType(&p2p.Chits{}, msg.Message())
innerMsg := msg.Message().(*p2p.Chits)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -373,8 +373,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.AppRequest)
require.True(ok)
require.IsType(&p2p.AppRequest{}, msg.Message())
innerMsg := msg.Message().(*p2p.AppRequest)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(appBytes, innerMsg.AppBytes)
Expand All @@ -396,8 +396,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AppResponseOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AppResponse)
require.True(ok)
require.IsType(&p2p.AppResponse{}, msg.Message())
innerMsg := msg.Message().(*p2p.AppResponse)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(appBytes, innerMsg.AppBytes)
Expand Down
4 changes: 2 additions & 2 deletions message/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ func TestNilInboundMessage(t *testing.T) {
parsedMsg, err := mb.parseInbound(msgBytes, ids.EmptyNodeID, func() {})
require.NoError(err)

pingMsg, ok := parsedMsg.message.(*p2p.Ping)
require.True(ok)
require.IsType(&p2p.Ping{}, parsedMsg.message)
pingMsg := parsedMsg.message.(*p2p.Ping)
require.NotNil(pingMsg)
}
4 changes: 2 additions & 2 deletions network/throttling/bandwidth_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ func TestBandwidthThrottler(t *testing.T) {
}
throttlerIntf, err := newBandwidthThrottler(logging.NoLog{}, "", prometheus.NewRegistry(), config)
require.NoError(err)
throttler, ok := throttlerIntf.(*bandwidthThrottlerImpl)
require.True(ok)
require.IsType(&bandwidthThrottlerImpl{}, throttlerIntf)
throttler := throttlerIntf.(*bandwidthThrottlerImpl)
require.NotNil(throttler.log)
require.NotNil(throttler.limiters)
require.Equal(config.RefillRate, throttler.RefillRate)
Expand Down
4 changes: 2 additions & 2 deletions network/throttling/inbound_resource_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func TestNewSystemThrottler(t *testing.T) {
targeter := tracker.NewMockTargeter(ctrl)
throttlerIntf, err := NewSystemThrottler("", reg, config, cpuTracker, targeter)
require.NoError(err)
throttler, ok := throttlerIntf.(*systemThrottler)
require.True(ok)
require.IsType(&systemThrottler{}, throttlerIntf)
throttler := throttlerIntf.(*systemThrottler)
require.Equal(clock, config.Clock)
require.Equal(time.Second, config.MaxRecheckDelay)
require.Equal(cpuTracker, throttler.tracker)
Expand Down
10 changes: 6 additions & 4 deletions snow/consensus/snowball/unary_snowball_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package snowball

import (
"testing"

"github.com/stretchr/testify/require"
)

func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessfulPolls, expectedConfidence int, expectedFinalized bool) {
Expand All @@ -18,6 +20,8 @@ func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessf
}

func TestUnarySnowball(t *testing.T) {
require := require.New(t)

beta := 2

sb := &unarySnowball{}
Expand All @@ -33,10 +37,8 @@ func TestUnarySnowball(t *testing.T) {
UnarySnowballStateTest(t, sb, 2, 1, false)

sbCloneIntf := sb.Clone()
sbClone, ok := sbCloneIntf.(*unarySnowball)
if !ok {
t.Fatalf("Unexpected clone type")
}
require.IsType(&unarySnowball{}, sbCloneIntf)
sbClone := sbCloneIntf.(*unarySnowball)

UnarySnowballStateTest(t, sbClone, 2, 1, false)

Expand Down
10 changes: 6 additions & 4 deletions snow/consensus/snowball/unary_snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package snowball

import (
"testing"

"github.com/stretchr/testify/require"
)

func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidence int, expectedFinalized bool) {
Expand All @@ -16,6 +18,8 @@ func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidenc
}

func TestUnarySnowflake(t *testing.T) {
require := require.New(t)

beta := 2

sf := &unarySnowflake{}
Expand All @@ -31,10 +35,8 @@ func TestUnarySnowflake(t *testing.T) {
UnarySnowflakeStateTest(t, sf, 1, false)

sfCloneIntf := sf.Clone()
sfClone, ok := sfCloneIntf.(*unarySnowflake)
if !ok {
t.Fatalf("Unexpected clone type")
}
require.IsType(&unarySnowflake{}, sfCloneIntf)
sfClone := sfCloneIntf.(*unarySnowflake)

UnarySnowflakeStateTest(t, sfClone, 1, false)

Expand Down
Loading