Skip to content

Commit

Permalink
fix: add replay protection on upgraded channels (cosmos#5651)
Browse files Browse the repository at this point in the history
* test: add integration test for double spend attack

* refactor: draft alternative approach to fixing double spend

* refactor: cleanup tests, deduplicate key storage, add documentation

* godoc

* test: add packet already recevied unit test case

* satisfy the linter

* imp: add additional comment to integration test

* imp: add a little more info to the test comment

* review suggestions + make setRecvStartSeqeuence private
  • Loading branch information
colin-axner authored Jan 22, 2024
1 parent 40564ed commit 8b6932b
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 34 deletions.
5 changes: 5 additions & 0 deletions modules/core/04-channel/keeper/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ func (k Keeper) CheckForUpgradeCompatibility(ctx sdk.Context, upgradeFields, cou
func (k Keeper) SetUpgradeErrorReceipt(ctx sdk.Context, portID, channelID string, errorReceipt types.ErrorReceipt) {
k.setUpgradeErrorReceipt(ctx, portID, channelID, errorReceipt)
}

// SetRecvStartSequence is a wrapper around setRecvStartSequence to allow the function to be directly called in tests.
func (k Keeper) SetRecvStartSequence(ctx sdk.Context, portID, channelID string, sequence uint64) {
k.setRecvStartSequence(ctx, portID, channelID, sequence)
}
19 changes: 11 additions & 8 deletions modules/core/04-channel/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -623,17 +623,20 @@ func (k Keeper) HasInflightPackets(ctx sdk.Context, portID, channelID string) bo
return iterator.Valid()
}

// SetPruningSequenceEnd sets the channel's pruning sequence end to the store.
func (k Keeper) SetPruningSequenceEnd(ctx sdk.Context, portID, channelID string, sequence uint64) {
// setRecvStartSequence sets the channel's recv start sequence to the store.
func (k Keeper) setRecvStartSequence(ctx sdk.Context, portID, channelID string, sequence uint64) {
store := ctx.KVStore(k.storeKey)
bz := sdk.Uint64ToBigEndian(sequence)
store.Set(host.PruningSequenceEndKey(portID, channelID), bz)
store.Set(host.RecvStartSequenceKey(portID, channelID), bz)
}

// GetPruningSequenceEnd gets a channel's pruning sequence end from the store.
func (k Keeper) GetPruningSequenceEnd(ctx sdk.Context, portID, channelID string) (uint64, bool) {
// GetRecvStartSequence gets a channel's recv start sequence from the store.
// The recv start sequence will be set to the counterparty's next sequence send
// upon a successful channel upgrade. It will be used for replay protection of
// historical packets and as the upper bound for pruning stale packet receives.
func (k Keeper) GetRecvStartSequence(ctx sdk.Context, portID, channelID string) (uint64, bool) {
store := ctx.KVStore(k.storeKey)
bz := store.Get(host.PruningSequenceEndKey(portID, channelID))
bz := store.Get(host.RecvStartSequenceKey(portID, channelID))
if len(bz) == 0 {
return 0, false
}
Expand Down Expand Up @@ -675,9 +678,9 @@ func (k Keeper) PruneAcknowledgements(ctx sdk.Context, portID, channelID string,
if !found {
return 0, 0, errorsmod.Wrapf(types.ErrPruningSequenceStartNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}
pruningSequenceEnd, found := k.GetPruningSequenceEnd(ctx, portID, channelID)
pruningSequenceEnd, found := k.GetRecvStartSequence(ctx, portID, channelID)
if !found {
return 0, 0, errorsmod.Wrapf(types.ErrPruningSequenceEndNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
return 0, 0, errorsmod.Wrapf(types.ErrRecvStartSequenceNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

start := pruningSequenceStart
Expand Down
10 changes: 5 additions & 5 deletions modules/core/04-channel/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ func (suite *KeeperTestSuite) TestPruneAcknowledgements() {
// Assert that PruneSequenceStart and PruneSequenceEnd are both set to 1.
start, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetPruningSequenceStart(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)
end, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetPruningSequenceEnd(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
end, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetRecvStartSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)

suite.Require().Equal(uint64(1), start)
Expand All @@ -600,7 +600,7 @@ func (suite *KeeperTestSuite) TestPruneAcknowledgements() {
},
func() {},
func(pruned, left uint64) {
sequenceEnd, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetPruningSequenceEnd(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
sequenceEnd, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetRecvStartSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)

// We expect nothing to be left and sequenceStart == sequenceEnd.
Expand Down Expand Up @@ -672,7 +672,7 @@ func (suite *KeeperTestSuite) TestPruneAcknowledgements() {
limit = 15
},
func(pruned, left uint64) {
sequenceEnd, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetPruningSequenceEnd(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
sequenceEnd, found := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetRecvStartSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)

// We expect nothing to be left and sequenceStart == sequenceEnd.
Expand Down Expand Up @@ -825,10 +825,10 @@ func (suite *KeeperTestSuite) TestPruneAcknowledgements() {
func() {},
func() {
store := suite.chainA.GetContext().KVStore(suite.chainA.GetSimApp().GetKey(exported.StoreKey))
store.Delete(host.PruningSequenceEndKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
store.Delete(host.RecvStartSequenceKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
},
func(_, _ uint64) {},
types.ErrPruningSequenceEndNotFound,
types.ErrRecvStartSequenceNotFound,
},
}

Expand Down
17 changes: 15 additions & 2 deletions modules/core/04-channel/keeper/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,20 @@ func (k Keeper) RecvPacket(
return errorsmod.Wrap(err, "couldn't verify counterparty packet commitment")
}

// REPLAY PROTECTION: The recvStartSequence will prevent historical proofs from allowing replay
// attacks on packets processed in previous lifecycles of a channel. After a successful channel
// upgrade all packets under the recvStartSequence will have been processed and thus should be
// rejected.
recvStartSequence, _ := k.GetRecvStartSequence(ctx, packet.GetDestPort(), packet.GetDestChannel())
if packet.GetSequence() < recvStartSequence {
return errorsmod.Wrap(types.ErrPacketReceived, "packet already processed in previous channel upgrade")
}

switch channel.Ordering {
case types.UNORDERED:
// check if the packet receipt has been received already for unordered channels
// REPLAY PROTECTION: Packet receipts will indicate that a packet has already been received
// on unordered channels. Packet receipts must not be pruned, unless it has been marked stale
// by the increase of the recvStartSequence.
_, found := k.GetPacketReceipt(ctx, packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence())
if found {
emitRecvPacketEvent(ctx, packet, channel)
Expand All @@ -212,7 +223,7 @@ func (k Keeper) RecvPacket(
// All verification complete, update state
// For unordered channels we must set the receipt so it can be verified on the other side.
// This receipt does not contain any data, since the packet has not yet been processed,
// it's just a single store key set to an empty string to indicate that the packet has been received
// it's just a single store key set to a single byte to indicate that the packet has been received
k.SetPacketReceipt(ctx, packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence())

case types.ORDERED:
Expand All @@ -233,6 +244,8 @@ func (k Keeper) RecvPacket(
return types.ErrNoOpMsg
}

// REPLAY PROTECTION: Ordered channels require packets to be received in a strict order.
// Any out of order or previously received packets are rejected.
if packet.GetSequence() != nextSequenceRecv {
return errorsmod.Wrapf(
types.ErrPacketSequenceOutOfOrder,
Expand Down
15 changes: 15 additions & 0 deletions modules/core/04-channel/keeper/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,21 @@ func (suite *KeeperTestSuite) TestRecvPacket() {
},
types.ErrSequenceReceiveNotFound,
},
{
"packet already received",
func() {
suite.coordinator.Setup(path)

sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
packet = types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp)
channelCap = suite.chainB.GetChannelCapability(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)

// set recv seq start to indicate packet was processed in previous upgrade
suite.chainB.App.GetIBCKeeper().ChannelKeeper.SetRecvStartSequence(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, sequence+1)
},
types.ErrPacketReceived,
},
{
"receipt already stored",
func() {
Expand Down
6 changes: 4 additions & 2 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,10 @@ func (k Keeper) WriteUpgradeOpenChannel(ctx sdk.Context, portID, channelID strin
k.SetNextSequenceAck(ctx, portID, channelID, upgrade.NextSequenceSend)
}

// set the counterparty next sequence send as pruning sequence end in order to have upper bound to prune to
k.SetPruningSequenceEnd(ctx, portID, channelID, counterpartyUpgrade.NextSequenceSend)
// Set the counterparty next sequence send as the recv start sequence.
// This will be the upper bound for pruning and it will allow for replay
// protection of historical packets.
k.setRecvStartSequence(ctx, portID, channelID, counterpartyUpgrade.NextSequenceSend)

// First upgrade for this channel will set the pruning sequence to 1, the starting sequence for pruning.
// Subsequent upgrades will not modify the pruning sequence thereby allowing pruning to continue from the last
Expand Down
16 changes: 8 additions & 8 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1404,8 +1404,8 @@ func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel_Ordering() {
// Assert that pruning sequence start has not been initialized.
suite.Require().False(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))

// Assert that pruning sequence end has not been set
counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceEnd(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
// Assert that recv start sequence has not been set
counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().False(found)
suite.Require().Equal(uint64(0), counterpartyNextSequenceSend)
},
Expand Down Expand Up @@ -1433,8 +1433,8 @@ func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel_Ordering() {
suite.Require().True(found)
suite.Require().Equal(uint64(1), pruningSeq)

// Assert that pruning sequence end has been set correctly
counterpartySequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceEnd(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
// Assert that the recv start sequence has been set correctly
counterpartySequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)
suite.Require().Equal(uint64(2), counterpartySequenceSend)
},
Expand Down Expand Up @@ -1464,8 +1464,8 @@ func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel_Ordering() {
// Assert that pruning sequence start has not been initialized.
suite.Require().False(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))

// Assert that pruning sequence end has not been set
counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceEnd(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
// Assert that recv start sequence has not been set
counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().False(found)
suite.Require().Equal(uint64(0), counterpartyNextSequenceSend)
},
Expand Down Expand Up @@ -1494,8 +1494,8 @@ func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel_Ordering() {
suite.Require().True(found)
suite.Require().Equal(uint64(1), pruningSeq)

// Assert that pruning sequence end has been set correctly
counterpartySequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceEnd(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
// Assert that the recv start sequence has been set correctly
counterpartySequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found)
suite.Require().Equal(uint64(2), counterpartySequenceSend)
},
Expand Down
2 changes: 1 addition & 1 deletion modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ var (
ErrTimeoutNotReached = errorsmod.Register(SubModuleName, 39, "timeout not reached")
ErrTimeoutElapsed = errorsmod.Register(SubModuleName, 40, "timeout elapsed")
ErrPruningSequenceStartNotFound = errorsmod.Register(SubModuleName, 41, "pruning sequence start not found")
ErrPruningSequenceEndNotFound = errorsmod.Register(SubModuleName, 42, "pruning sequence end not found")
ErrRecvStartSequenceNotFound = errorsmod.Register(SubModuleName, 42, "recv start sequence not found")
)
14 changes: 7 additions & 7 deletions modules/core/24-host/packet_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const (
KeyPacketAckPrefix = "acks"
KeyPacketReceiptPrefix = "receipts"
KeyPruningSequenceStart = "pruningSequenceStart"
KeyPruningSequenceEnd = "pruningSequenceEnd"
KeyRecvStartSequence = "recvStartSequence"
)

// ICS04
Expand Down Expand Up @@ -103,14 +103,14 @@ func PruningSequenceStartKey(portID, channelID string) []byte {
return []byte(PruningSequenceStartPath(portID, channelID))
}

// PruningSequenceEndPath defines the path under which the pruning sequence end is stored
func PruningSequenceEndPath(portID, channelID string) string {
return fmt.Sprintf("%s/%s", KeyPruningSequenceEnd, channelPath(portID, channelID))
// RecvStartSequencePath defines the path under which the recv start sequence is stored
func RecvStartSequencePath(portID, channelID string) string {
return fmt.Sprintf("%s/%s", KeyRecvStartSequence, channelPath(portID, channelID))
}

// PruningSequenceEndKey returns the store key for the pruning sequence end of a particular channel
func PruningSequenceEndKey(portID, channelID string) []byte {
return []byte(PruningSequenceEndPath(portID, channelID))
// RecvStartSequenceKey returns the store key for the recv start sequence of a particular channel
func RecvStartSequenceKey(portID, channelID string) []byte {
return []byte(RecvStartSequencePath(portID, channelID))
}

func sequencePath(sequence uint64) string {
Expand Down
98 changes: 98 additions & 0 deletions modules/core/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package ibc_test

import (
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
host "github.com/cosmos/ibc-go/v8/modules/core/24-host"
ibctesting "github.com/cosmos/ibc-go/v8/testing"
ibcmock "github.com/cosmos/ibc-go/v8/testing/mock"
)

// If packet receipts are pruned, it may be possible to double spend via a
// replay attack by resubmitting the same proof used to process the original receive.
// Core IBC performs an additional check to ensure that any packet being received
// MUST NOT be in the range of packet receipts which are allowed to be pruned thus
// adding replay protection for upgraded channels.
// This test has been added to ensure we have replay protection after
// pruning stale state upon the successful completion of a channel upgrade.
func (suite *IBCTestSuite) TestReplayProtectionAfterReceivePruning() {
var path *ibctesting.Path

testCases := []struct {
name string
malleate func()
}{
{
"unordered channel upgrades version",
func() {
path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion
},
},
{
"ordered channel upgrades to unordered channel",
func() {
path.EndpointA.ChannelConfig.Order = channeltypes.ORDERED
path.EndpointB.ChannelConfig.Order = channeltypes.ORDERED

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Ordering = channeltypes.UNORDERED
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Ordering = channeltypes.UNORDERED
},
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()
path = ibctesting.NewPath(suite.chainA, suite.chainB)

tc.malleate()

suite.coordinator.Setup(path)

// Setup replay attack by sending a packet. We will save the receive
// proof to replay relaying after the channel upgrade compeletes.
disabledTimeoutTimestamp := uint64(0)
timeoutHeight := clienttypes.NewHeight(1, 110)
sequence, err := path.EndpointA.SendPacket(timeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData)
suite.Require().NoError(err)
packet := channeltypes.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, timeoutHeight, disabledTimeoutTimestamp)

// save receive proof for replay submission
packetKey := host.PacketCommitmentKey(packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence())
proof, proofHeight := path.EndpointA.Chain.QueryProof(packetKey)
recvMsg := channeltypes.NewMsgRecvPacket(packet, proof, proofHeight, path.EndpointB.Chain.SenderAccount.GetAddress().String())

err = path.RelayPacket(packet)
suite.Require().NoError(err)

// perform upgrade
err = path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeTry()
suite.Require().NoError(err)

err = path.EndpointA.ChanUpgradeAck()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeConfirm()
suite.Require().NoError(err)

err = path.EndpointA.ChanUpgradeOpen()
suite.Require().NoError(err)

// prune stale receive state
msgPrune := channeltypes.NewMsgPruneAcknowledgements(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, 1, path.EndpointB.Chain.SenderAccount.GetAddress().String())
res, err := path.EndpointB.Chain.SendMsgs(msgPrune)
suite.Require().NotNil(res)
suite.Require().NoError(err)

// replay initial packet send
res, err = path.EndpointB.Chain.SendMsgs(recvMsg)
suite.Require().NotNil(res)
suite.Require().ErrorContains(err, channeltypes.ErrPacketReceived.Error(), "replay protection missing")
})
}
}
2 changes: 1 addition & 1 deletion modules/core/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2697,7 +2697,7 @@ func (suite *KeeperTestSuite) TestPruneAcknowledgements() {
func() {
msg.PortId = "portidone"
},
channeltypes.ErrPruningSequenceEndNotFound,
channeltypes.ErrRecvStartSequenceNotFound,
},
}

Expand Down

0 comments on commit 8b6932b

Please sign in to comment.