Skip to content

Commit

Permalink
Merge branch 'master' into sai/9684-pretty-json-to-tx-commands
Browse files Browse the repository at this point in the history
  • Loading branch information
anilcse authored Jul 22, 2021
2 parents dfb5b50 + adff7d8 commit 8557221
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 7 deletions.
86 changes: 85 additions & 1 deletion x/auth/tx/decoder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package tx

import (
"fmt"

"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/codec/unknownproto"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -11,10 +15,16 @@ import (
// DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler.
func DefaultTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder {
return func(txBytes []byte) (sdk.Tx, error) {
// Make sure txBytes follow ADR-027.
err := rejectNonADR027(txBytes)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}

var raw tx.TxRaw

// reject all unknown proto fields in the root TxRaw
err := unknownproto.RejectUnknownFieldsStrict(txBytes, &raw, cdc.InterfaceRegistry())
err = unknownproto.RejectUnknownFieldsStrict(txBytes, &raw, cdc.InterfaceRegistry())
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
Expand Down Expand Up @@ -79,3 +89,77 @@ func DefaultJSONTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder {
}, nil
}
}

// rejectNonADR027 rejects txBytes that do not follow ADR-027. This function
// only checks that:
// - field numbers are in ascending order (1, 2, and potentially multiple 3s),
// - and varints as as short as possible.
// All other ADR-027 edge cases (e.g. TxRaw fields having default values) will
// not happen with TxRaw.
func rejectNonADR027(txBytes []byte) error {
// Make sure all fields are ordered in ascending order with this variable.
prevTagNum := protowire.Number(0)

for len(txBytes) > 0 {
tagNum, wireType, m := protowire.ConsumeTag(txBytes)
if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
if wireType != protowire.BytesType {
return fmt.Errorf("expected %d wire type, got %d", protowire.VarintType, wireType)
}
if tagNum < prevTagNum {
return fmt.Errorf("txRaw must follow ADR-027, got tagNum %d after tagNum %d", tagNum, prevTagNum)
}
prevTagNum = tagNum

// All 3 fields of TxRaw have wireType == 2, so their next component
// is a varint.
// We make sure that the varint is as short as possible.
lengthPrefix, m := protowire.ConsumeVarint(txBytes[m:])
if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
n := varintMinLength(lengthPrefix)
if n != m {
return fmt.Errorf("length prefix varint for tagNum %d is not as short as possible, read %d, only need %d", tagNum, m, n)
}

// Skip over the bytes that store fieldNumber and wireType bytes.
_, _, m = protowire.ConsumeField(txBytes)
if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
txBytes = txBytes[m:]
}

return nil
}

// varintMinLength returns the minimum number of bytes necessary to encode an
// uint using varint encoding.
func varintMinLength(n uint64) int {
switch {
// Note: 1<<N == 2**N.
case n < 1<<7:
return 1
case n < 1<<14:
return 2
case n < 1<<21:
return 3
case n < 1<<28:
return 4
case n < 1<<35:
return 5
case n < 1<<42:
return 6
case n < 1<<49:
return 7
case n < 1<<56:
return 8
case n < 1<<63:
return 9
default:
return 10
}
}
134 changes: 128 additions & 6 deletions x/auth/tx/encode_decode_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package tx

import (
"encoding/binary"
"fmt"
"math"
"testing"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

"github.com/cosmos/cosmos-sdk/types/tx"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)

func TestDefaultTxDecoderError(t *testing.T) {
Expand Down Expand Up @@ -159,3 +160,124 @@ func TestUnknownFields(t *testing.T) {
_, err = decoder(txBz)
require.Error(t, err)
}

func TestRejectNonADR027(t *testing.T) {
registry := codectypes.NewInterfaceRegistry()
cdc := codec.NewProtoCodec(registry)
decoder := DefaultTxDecoder(cdc)

body := &testdata.TestUpdatedTxBody{Memo: "AAA"} // Look for "65 65 65" when debugging the bytes stream.
bodyBz, err := body.Marshal()
require.NoError(t, err)
authInfo := &testdata.TestUpdatedAuthInfo{Fee: &tx.Fee{GasLimit: 127}} // Look for "127" when debugging the bytes stream.
authInfoBz, err := authInfo.Marshal()
txRaw := &tx.TxRaw{
BodyBytes: bodyBz,
AuthInfoBytes: authInfoBz,
Signatures: [][]byte{{41}, {42}, {43}}, // Look for "42" when debugging the bytes stream.
}

// We know these bytes are ADR-027-compliant.
txBz, err := txRaw.Marshal()

// From the `txBz`, we extract the 3 components:
// bodyBz, authInfoBz, sigsBz.
// In our tests, we will try to decode txs with those 3 components in all
// possible orders.
//
// Consume "BodyBytes" field.
_, _, m := protowire.ConsumeField(txBz)
bodyBz = append([]byte{}, txBz[:m]...)
txBz = txBz[m:] // Skip over "BodyBytes" bytes.
// Consume "AuthInfoBytes" field.
_, _, m = protowire.ConsumeField(txBz)
authInfoBz = append([]byte{}, txBz[:m]...)
txBz = txBz[m:] // Skip over "AuthInfoBytes" bytes.
// Consume "Signature" field, it's the remaining bytes.
sigsBz := append([]byte{}, txBz...)

// bodyBz's length prefix is 5, with `5` as varint encoding. We also try a
// longer varint encoding for 5: `133 00`.
longVarintBodyBz := append(append([]byte{bodyBz[0]}, byte(133), byte(00)), bodyBz[2:]...)

tests := []struct {
name string
txBz []byte
shouldErr bool
}{
{
"authInfo, body, sigs",
append(append(authInfoBz, bodyBz...), sigsBz...),
true,
},
{
"authInfo, sigs, body",
append(append(authInfoBz, sigsBz...), bodyBz...),
true,
},
{
"sigs, body, authInfo",
append(append(sigsBz, bodyBz...), authInfoBz...),
true,
},
{
"sigs, authInfo, body",
append(append(sigsBz, authInfoBz...), bodyBz...),
true,
},
{
"body, sigs, authInfo",
append(append(bodyBz, sigsBz...), authInfoBz...),
true,
},
{
"body, authInfo, sigs (valid txRaw)",
append(append(bodyBz, authInfoBz...), sigsBz...),
false,
},
{
"longer varint than needed",
append(append(longVarintBodyBz, authInfoBz...), sigsBz...),
true,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
_, err = decoder(tt.txBz)
if tt.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

func TestVarintMinLength(t *testing.T) {
tests := []struct {
n uint64
}{
{1<<7 - 1}, {1 << 7},
{1<<14 - 1}, {1 << 14},
{1<<21 - 1}, {1 << 21},
{1<<28 - 1}, {1 << 28},
{1<<35 - 1}, {1 << 35},
{1<<42 - 1}, {1 << 42},
{1<<49 - 1}, {1 << 49},
{1<<56 - 1}, {1 << 56},
{1<<63 - 1}, {1 << 63},
{math.MaxUint64},
}

for _, tt := range tests {
tt := tt
t.Run(fmt.Sprintf("test %d", tt.n), func(t *testing.T) {
l1 := varintMinLength(tt.n)
buf := make([]byte, binary.MaxVarintLen64)
l2 := binary.PutUvarint(buf, tt.n)
require.Equal(t, l2, l1)
})
}
}

0 comments on commit 8557221

Please sign in to comment.