Skip to content

Commit

Permalink
refactor!: remove GetState() on connection interface (cosmos#5769)
Browse files Browse the repository at this point in the history
* rm: GetState() on connection interface

* lint

* lint
  • Loading branch information
colin-axner authored Jan 30, 2024
1 parent 7df8644 commit 2d210e2
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 79 deletions.
7 changes: 4 additions & 3 deletions e2e/tests/interchain_accounts/upgrades_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
"testing"
"time"

sdkmath "cosmossdk.io/math"
"github.com/cosmos/gogoproto/proto"
"github.com/strangelove-ventures/interchaintest/v8"
"github.com/strangelove-ventures/interchaintest/v8/ibc"
test "github.com/strangelove-ventures/interchaintest/v8/testutil"
testifysuite "github.com/stretchr/testify/suite"

sdkmath "cosmossdk.io/math"

sdk "github.com/cosmos/cosmos-sdk/types"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
"github.com/cosmos/gogoproto/proto"

"github.com/cosmos/ibc-go/e2e/testsuite"
"github.com/cosmos/ibc-go/e2e/testvalues"
Expand Down Expand Up @@ -162,7 +163,7 @@ func (s *InterchainAccountsChannelUpgradesTestSuite) TestChannelUpgrade_ICAChann
Memo: "e2e",
}

timeout := uint64(1)
timeout := uint64(1)
msgSendTx := controllertypes.NewMsgSendTx(controllerAddress, ibctesting.FirstConnectionID, timeout, packetData)

resp := s.BroadcastMessages(
Expand Down
3 changes: 1 addition & 2 deletions e2e/tests/upgrades/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"testing"
"time"

transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"

"github.com/cosmos/gogoproto/proto"
interchaintest "github.com/strangelove-ventures/interchaintest/v8"
"github.com/strangelove-ventures/interchaintest/v8/chain/cosmos"
Expand All @@ -29,6 +27,7 @@ import (
"github.com/cosmos/ibc-go/e2e/testsuite"
"github.com/cosmos/ibc-go/e2e/testvalues"
feetypes "github.com/cosmos/ibc-go/v8/modules/apps/29-fee/types"
transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
v7migrations "github.com/cosmos/ibc-go/v8/modules/core/02-client/migrations/v7"
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types"
Expand Down
3 changes: 1 addition & 2 deletions e2e/testsuite/sanitize/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import (
govtypesv1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1"
grouptypes "github.com/cosmos/cosmos-sdk/x/group"

"github.com/cosmos/ibc-go/e2e/semverutil"
icacontrollertypes "github.com/cosmos/ibc-go/v8/modules/apps/27-interchain-accounts/controller/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"

"github.com/cosmos/ibc-go/e2e/semverutil"
)

var (
Expand Down
5 changes: 0 additions & 5 deletions modules/core/03-connection/types/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ func NewConnectionEnd(state State, clientID string, counterparty Counterparty, v
}
}

// GetState implements the Connection interface
func (c ConnectionEnd) GetState() int32 {
return int32(c.State)
}

// GetClientID implements the Connection interface
func (c ConnectionEnd) GetClientID() string {
return c.ClientId
Expand Down
35 changes: 10 additions & 25 deletions modules/core/04-channel/keeper/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,8 @@ func (k Keeper) ChanOpenTry(
return "", nil, errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, connectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return "", nil, errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return "", nil, errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

getVersions := connectionEnd.GetVersions()
Expand Down Expand Up @@ -242,11 +239,8 @@ func (k Keeper) ChanOpenAck(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()}
Expand Down Expand Up @@ -321,11 +315,8 @@ func (k Keeper) ChanOpenConfirm(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()}
Expand Down Expand Up @@ -405,11 +396,8 @@ func (k Keeper) ChanCloseInit(
return errorsmod.Wrapf(clienttypes.ErrClientNotActive, "client (%s) status is %s", connectionEnd.ClientId, status)
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", channel.State.String(), "new-state", types.CLOSED.String())
Expand Down Expand Up @@ -453,11 +441,8 @@ func (k Keeper) ChanCloseConfirm(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

counterpartyHops := []string{connectionEnd.GetCounterparty().GetConnectionID()}
Expand Down
14 changes: 4 additions & 10 deletions modules/core/04-channel/keeper/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,8 @@ func (k Keeper) RecvPacket(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

// check if packet timed out by comparing it with the latest height of the chain
Expand Down Expand Up @@ -400,11 +397,8 @@ func (k Keeper) AcknowledgePacket(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connectionEnd.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connectionEnd.GetState()).String(),
)
if connectionEnd.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.State)
}

commitment := k.GetPacketCommitment(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence())
Expand Down
51 changes: 20 additions & 31 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,8 @@ func (k Keeper) ChanUpgradeTry(
return types.Channel{}, types.Upgrade{}, errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return types.Channel{}, types.Upgrade{}, errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
if connection.State != connectiontypes.OPEN {
return types.Channel{}, types.Upgrade{}, errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

// construct expected counterparty channel from information in state
Expand Down Expand Up @@ -276,8 +274,8 @@ func (k Keeper) ChanUpgradeAck(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()}
Expand Down Expand Up @@ -412,8 +410,8 @@ func (k Keeper) ChanUpgradeConfirm(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()}
Expand Down Expand Up @@ -507,8 +505,8 @@ func (k Keeper) ChanUpgradeOpen(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

var counterpartyChannel types.Channel
Expand All @@ -524,8 +522,8 @@ func (k Keeper) ChanUpgradeOpen(
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, upgrade.Fields.ConnectionHops[0])
}

if upgradeConnection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(upgradeConnection.GetState()).String())
if upgradeConnection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", upgradeConnection.State)
}

// The counterparty upgrade sequence must be greater than or equal to
Expand Down Expand Up @@ -675,11 +673,8 @@ func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, err
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

if err := k.connectionKeeper.VerifyChannelUpgradeError(
Expand Down Expand Up @@ -746,11 +741,8 @@ func (k Keeper) ChanUpgradeTimeout(
)
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

proofTimestamp, err := k.connectionKeeper.GetTimestampAtHeight(ctx, connection, proofHeight)
Expand Down Expand Up @@ -853,8 +845,8 @@ func (k Keeper) startFlushing(ctx sdk.Context, portID, channelID string, upgrade
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, channel.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

channel.State = types.FLUSHING
Expand Down Expand Up @@ -896,10 +888,10 @@ func (k Keeper) checkForUpgradeCompatibility(ctx sdk.Context, upgradeFields, cou
return errorsmod.Wrap(connectiontypes.ErrConnectionNotFound, upgradeFields.ConnectionHops[0])
}

if connection.GetState() != int32(connectiontypes.OPEN) {
if connection.State != connectiontypes.OPEN {
// NOTE: this error is expected to be unreachable as the proposed upgrade connectionID should have been
// validated in the upgrade INIT and TRY handlers
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "expected proposed connection to be OPEN (got %s)", connection.State)
}

// connectionHops can change in a channelUpgrade, however both sides must still be each other's counterparty.
Expand Down Expand Up @@ -930,11 +922,8 @@ func (k Keeper) validateSelfUpgradeFields(ctx sdk.Context, proposedUpgrade types
return errorsmod.Wrapf(connectiontypes.ErrConnectionNotFound, "failed to retrieve connection: %s", connectionID)
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
if connection.State != connectiontypes.OPEN {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connection.State)
}

getVersions := connection.GetVersions()
Expand Down
1 change: 0 additions & 1 deletion modules/core/exported/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ const LocalhostConnectionID string = "connection-localhost"
// ConnectionI describes the required methods for a connection.
type ConnectionI interface {
GetClientID() string
GetState() int32
GetCounterparty() CounterpartyConnectionI
GetDelayPeriod() uint64
ValidateBasic() error
Expand Down

0 comments on commit 2d210e2

Please sign in to comment.