Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert application state changes on failed acknowledgements #107

Merged
merged 13 commits into from
Apr 12, 2021
Next Next commit
make changes and start fixing tests
  • Loading branch information
AdityaSripal authored and colin-axner committed Apr 7, 2021
commit 5849d5a11837310c98612792586eebb9c20519e8
2 changes: 1 addition & 1 deletion modules/apps/transfer/keeper/mbt_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func (suite *KeeperTestSuite) TestModelBasedRelay() {
0)
}
case "OnRecvPacket":
err = suite.chainB.App.TransferKeeper.OnRecvPacket(suite.chainB.GetContext(), packet, tc.packet.Data)
_, err = suite.chainB.App.TransferKeeper.OnRecvPacket(suite.chainB.GetContext(), packet, tc.packet.Data)
case "OnTimeoutPacket":
registerDenom()
err = suite.chainB.App.TransferKeeper.OnTimeoutPacket(suite.chainB.GetContext(), packet, tc.packet.Data)
Expand Down
22 changes: 13 additions & 9 deletions modules/apps/transfer/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,20 @@ func (k Keeper) SendTransfer(
// and sent to the receiving address. Otherwise if the sender chain is sending
// back tokens this chain originally transferred to it, the tokens are
// unescrowed and sent to the receiving address.
func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data types.FungibleTokenPacketData) error {
func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data types.FungibleTokenPacketData) (revert bool, err error) {
// validate packet data upon receiving
if err := data.ValidateBasic(); err != nil {
return err
return false, err
}

if !k.GetReceiveEnabled(ctx) {
return types.ErrReceiveDisabled
return false, types.ErrReceiveDisabled
}

// decode the receiver address
receiver, err := sdk.AccAddressFromBech32(data.Receiver)
if err != nil {
return err
return false, err
}

labels := []metrics.Label{
Expand Down Expand Up @@ -237,7 +237,7 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data t
// counterparty module. The bug may occur in bank or any part of the code that allows
// the escrow address to be drained. A malicious counterparty module could drain the
// escrow address by allowing more tokens to be sent back then were escrowed.
return sdkerrors.Wrap(err, "unable to unescrow tokens, this may be caused by a malicious counterparty module or a bug: please open an issue on counterparty module")
return false, sdkerrors.Wrap(err, "unable to unescrow tokens, this may be caused by a malicious counterparty module or a bug: please open an issue on counterparty module")
}

defer func() {
Expand All @@ -256,7 +256,7 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data t
)
}()

return nil
return false, nil
}

// sender chain is the source, mint vouchers
Expand Down Expand Up @@ -289,14 +289,18 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data t
if err := k.bankKeeper.MintCoins(
ctx, types.ModuleName, sdk.NewCoins(voucher),
); err != nil {
return err
// error returned by bank keeper, revert all state changes by callback
// and continue tx processing
return true, err
}

// send to receiver
if err := k.bankKeeper.SendCoinsFromModuleToAccount(
ctx, types.ModuleName, receiver, sdk.NewCoins(voucher),
); err != nil {
panic(fmt.Sprintf("unable to send coins from module to account despite previously minting coins to module account: %v", err))
// error returned by bank keeper, revert all state changes by callback
// and continue tx processing
return true, err
}

defer func() {
Expand All @@ -315,7 +319,7 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data t
)
}()

return nil
return false, nil
}

// OnAcknowledgementPacket responds to the the success or failure of a packet
Expand Down
2 changes: 1 addition & 1 deletion modules/apps/transfer/keeper/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (suite *KeeperTestSuite) TestOnRecvPacket() {
data := types.NewFungibleTokenPacketData(trace.GetFullDenomPath(), amount.Uint64(), suite.chainA.SenderAccount.GetAddress().String(), receiver)
packet := channeltypes.NewPacket(data.GetBytes(), seq, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, clienttypes.NewHeight(0, 100), 0)

err = suite.chainB.App.TransferKeeper.OnRecvPacket(suite.chainB.GetContext(), packet, data)
_, err = suite.chainB.App.TransferKeeper.OnRecvPacket(suite.chainB.GetContext(), packet, data)

if tc.expPass {
suite.Require().NoError(err)
Expand Down
8 changes: 4 additions & 4 deletions modules/apps/transfer/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,15 @@ func (am AppModule) OnChanCloseConfirm(
func (am AppModule) OnRecvPacket(
ctx sdk.Context,
packet channeltypes.Packet,
) (*sdk.Result, []byte, error) {
) (*sdk.Result, []byte, bool, error) {
var data types.FungibleTokenPacketData
if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
return nil, nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error())
return nil, nil, false, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error())
}

acknowledgement := channeltypes.NewResultAcknowledgement([]byte{byte(1)})

err := am.keeper.OnRecvPacket(ctx, packet, data)
revert, err := am.keeper.OnRecvPacket(ctx, packet, data)
if err != nil {
acknowledgement = channeltypes.NewErrorAcknowledgement(err.Error())
}
Expand All @@ -349,7 +349,7 @@ func (am AppModule) OnRecvPacket(
// NOTE: acknowledgement will be written synchronously during IBC handler execution.
return &sdk.Result{
Events: ctx.EventManager().Events().ToABCIEvents(),
}, acknowledgement.GetBytes(), nil
}, acknowledgement.GetBytes(), revert, nil
}

// OnAcknowledgementPacket implements the IBCModule interface
Expand Down
8 changes: 4 additions & 4 deletions modules/core/03-connection/keeper/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketCommitment() {
connection.ClientId = ibctesting.InvalidID
}

packet := channeltypes.NewPacket(ibctesting.TestHash, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
packet := channeltypes.NewPacket(ibctesting.MockPacketData, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
err := suite.coordinator.SendPacket(suite.chainA, suite.chainB, packet, clientB)
suite.Require().NoError(err)

Expand Down Expand Up @@ -344,7 +344,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketAcknowledgement() {
}

// send and receive packet
packet := channeltypes.NewPacket(ibctesting.TestHash, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
packet := channeltypes.NewPacket(ibctesting.MockPacketData, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
err := suite.coordinator.SendPacket(suite.chainA, suite.chainB, packet, clientB)
suite.Require().NoError(err)

Expand Down Expand Up @@ -412,7 +412,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketReceiptAbsence() {
}

// send, only receive if specified
packet := channeltypes.NewPacket(ibctesting.TestHash, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
packet := channeltypes.NewPacket(ibctesting.MockPacketData, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
err := suite.coordinator.SendPacket(suite.chainA, suite.chainB, packet, clientB)
suite.Require().NoError(err)

Expand Down Expand Up @@ -481,7 +481,7 @@ func (suite *KeeperTestSuite) TestVerifyNextSequenceRecv() {
}

// send and receive packet
packet := channeltypes.NewPacket(ibctesting.TestHash, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
packet := channeltypes.NewPacket(ibctesting.MockPacketData, 1, channelA.PortID, channelA.ID, channelB.PortID, channelB.ID, defaultTimeoutHeight, 0)
err := suite.coordinator.SendPacket(suite.chainA, suite.chainB, packet, clientB)
suite.Require().NoError(err)

Expand Down
Loading