Skip to content

Commit

Permalink
Added salt to starknet transaction for cairo contract address generat…
Browse files Browse the repository at this point in the history
…ion (erigontech#3178)

Co-authored-by: Aleksandr Borodulin <a.borodulin@axioma.lv>
  • Loading branch information
Cript and aleksandrborodulin authored Jan 3, 2022
1 parent 1bfc2ff commit 8203cdf
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 41 deletions.
11 changes: 9 additions & 2 deletions cmd/starknet/cmd/generate_raw_tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

type Flags struct {
Contract string
Salt string
PrivateKey string
Datadir string
Output string
}

Expand All @@ -27,17 +29,22 @@ func init() {
generateRawTxCmd.Flags().StringVarP(&flags.Contract, "contract", "c", "", "Path to compiled cairo contract in JSON format")
generateRawTxCmd.MarkFlagRequired("contract")

generateRawTxCmd.Flags().StringVarP(&flags.Salt, "salt", "s", "", "Cairo contract address salt")
generateRawTxCmd.MarkFlagRequired("salt")

generateRawTxCmd.Flags().StringVarP(&flags.PrivateKey, "private_key", "k", "", "Private key")
generateRawTxCmd.MarkFlagRequired("private_key")

generateRawTxCmd.Flags().StringVarP(&flags.Output, "output", "o", "", "Path to file where sign transaction will be saved")
rootCmd.PersistentFlags().StringVar(&flags.Datadir, "datadir", "", "path to Erigon working directory")

generateRawTxCmd.Flags().StringVarP(&flags.Output, "output", "o", "", "Path to file where sign transaction will be saved. Print to stdout if empty.")

generateRawTxCmd.RunE = func(cmd *cobra.Command, args []string) error {
rawTxGenerator := services.NewRawTxGenerator(flags.PrivateKey)

fs := os.DirFS("/")
buf := bytes.NewBuffer(nil)
err := rawTxGenerator.CreateFromFS(fs, strings.Trim(flags.Contract, "/"), buf)
err := rawTxGenerator.CreateFromFS(fs, strings.Trim(flags.Contract, "/"), []byte(flags.Salt), buf)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/starknet/services/raw_tx_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type RawTxGenerator struct {
privateKey string
}

func (g RawTxGenerator) CreateFromFS(fileSystem fs.FS, contractFileName string, writer io.Writer) error {
func (g RawTxGenerator) CreateFromFS(fileSystem fs.FS, contractFileName string, salt []byte, writer io.Writer) error {
privateKey, err := crypto.HexToECDSA(g.privateKey)
if err != nil {
return ErrInvalidPrivateKey
Expand All @@ -47,6 +47,7 @@ func (g RawTxGenerator) CreateFromFS(fileSystem fs.FS, contractFileName string,
Value: uint256.NewInt(1),
Gas: 1,
Data: enc,
Salt: salt,
},
}

Expand Down
14 changes: 8 additions & 6 deletions cmd/starknet/services/raw_tx_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ func TestCreate(t *testing.T) {
name string
privateKey string
fileName string
salt string
want string
}{
{name: "success", privateKey: privateKey, fileName: "contract_test.json", want: "03f86583127ed80180800180019637623232363136323639323233613230356235643764c080a0ceb955e6039bf37dbf77e4452a10b4a47906bbbd2f6dcf0c15bccb052d3bbb60a03de24d584a0a20523f55a137ebc651e2b092fbc3728d67c9fda09da9f0edd154"},
{name: "success", privateKey: privateKey, fileName: "contract_test.json", salt: "contract_address_salt", want: "03f87b83127ed8018080018001963762323236313632363932323361323035623564376495636f6e74726163745f616464726573735f73616c74c080a0ceb955e6039bf37dbf77e4452a10b4a47906bbbd2f6dcf0c15bccb052d3bbb60a03de24d584a0a20523f55a137ebc651e2b092fbc3728d67c9fda09da9f0edd154"},
}

fs := fstest.MapFS{
Expand All @@ -31,12 +32,13 @@ func TestCreate(t *testing.T) {
}

buf := bytes.NewBuffer(nil)
err := rawTxGenerator.CreateFromFS(fs, tt.fileName, buf)

err := rawTxGenerator.CreateFromFS(fs, tt.fileName, []byte(tt.salt), buf)
assertNoError(t, err)

if hex.EncodeToString(buf.Bytes()) != tt.want {
t.Error("got not equals want")
got := hex.EncodeToString(buf.Bytes())

if got != tt.want {
t.Errorf("got %q not equals want %q", got, tt.want)
}
})
}
Expand All @@ -62,7 +64,7 @@ func TestErrorCreate(t *testing.T) {
}

buf := bytes.NewBuffer(nil)
err := rawTxGenerator.CreateFromFS(fs, tt.fileName, buf)
err := rawTxGenerator.CreateFromFS(fs, tt.fileName, []byte{}, buf)

if tt.error != nil {
assertError(t, err, tt.error)
Expand Down
5 changes: 5 additions & 0 deletions core/types/legacy_tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type CommonTx struct {
To *common.Address `rlp:"nil"` // nil means contract creation
Value *uint256.Int // wei amount
Data []byte // contract invocation input data
Salt []byte // cairo contract address salt
V, R, S uint256.Int // signature values
}

Expand Down Expand Up @@ -65,6 +66,10 @@ func (ct CommonTx) GetData() []byte {
return ct.Data
}

func (ct CommonTx) GetSalt() []byte {
return ct.Salt
}

func (ct CommonTx) GetSender() (common.Address, bool) {
if sc := ct.from.Load(); sc != nil {
return sc.(common.Address), true
Expand Down
68 changes: 36 additions & 32 deletions core/types/starknet_tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,20 @@ func (tx *StarknetTransaction) DecodeRLP(s *rlp.Stream) error {
return err
}
var b []byte
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for ChainID: %d", len(b))
}
//tx.ChainID = new(uint256.Int).SetBytes(b)
tx.ChainID = new(uint256.Int).SetBytes(b)
if tx.Nonce, err = s.Uint(); err != nil {
return err
}
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for MaxPriorityFeePerGas: %d", len(b))
}
tx.Tip = new(uint256.Int).SetBytes(b)
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for MaxFeePerGas: %d", len(b))
}
tx.FeeCap = new(uint256.Int).SetBytes(b)
if tx.Gas, err = s.Uint(); err != nil {
return err
Expand All @@ -70,42 +61,33 @@ func (tx *StarknetTransaction) DecodeRLP(s *rlp.Stream) error {
tx.To = &common.Address{}
copy((*tx.To)[:], b)
}
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for Value: %d", len(b))
}
tx.Value = new(uint256.Int).SetBytes(b)
if tx.Data, err = s.Bytes(); err != nil {
return err
}
if tx.Salt, err = s.Bytes(); err != nil {
return err
}
// decode AccessList
tx.AccessList = AccessList{}
if err = decodeAccessList(&tx.AccessList, s); err != nil {
return err
}
// decode V
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for V: %d", len(b))
}
tx.V.SetBytes(b)
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for R: %d", len(b))
}
tx.R.SetBytes(b)
if b, err = s.Bytes(); err != nil {
if b, err = s.Uint256Bytes(); err != nil {
return err
}
if len(b) > 32 {
return fmt.Errorf("wrong size for S: %d", len(b))
}
tx.S.SetBytes(b)
return s.ListEnd()
}
Expand All @@ -115,15 +97,15 @@ func (tx StarknetTransaction) GetPrice() *uint256.Int {
}

func (tx StarknetTransaction) GetTip() *uint256.Int {
panic("implement me")
return tx.Tip
}

func (tx StarknetTransaction) GetEffectiveGasTip(baseFee *uint256.Int) *uint256.Int {
panic("implement me")
}

func (tx StarknetTransaction) GetFeeCap() *uint256.Int {
panic("implement me")
return tx.FeeCap
}

func (tx StarknetTransaction) Cost() *uint256.Int {
Expand Down Expand Up @@ -184,7 +166,7 @@ func (tx StarknetTransaction) GetAccessList() AccessList {
}

func (tx StarknetTransaction) RawSignatureValues() (*uint256.Int, *uint256.Int, *uint256.Int) {
panic("implement me")
return &tx.V, &tx.R, &tx.S
}

func (tx StarknetTransaction) MarshalBinary(w io.Writer) error {
Expand Down Expand Up @@ -269,6 +251,10 @@ func (tx StarknetTransaction) encodePayload(w io.Writer, b []byte, payloadSize,
if err := EncodeString(tx.Data, w, b); err != nil {
return err
}
// encode cairo contract address salt
if err := EncodeString(tx.Salt, w, b); err != nil {
return err
}
// prefix
if err := EncodeStructSizePrefix(accessListLen, w, b); err != nil {
return err
Expand Down Expand Up @@ -338,6 +324,7 @@ func (tx StarknetTransaction) payloadSize() (payloadSize int, nonceLen, gasLen,
valueLen = (tx.Value.BitLen() + 7) / 8
}
payloadSize += valueLen

// size of Data
payloadSize++
switch len(tx.Data) {
Expand All @@ -352,6 +339,22 @@ func (tx StarknetTransaction) payloadSize() (payloadSize int, nonceLen, gasLen,
}
payloadSize += len(tx.Data)
}

// size of cairo contract address salt
payloadSize++
switch len(tx.Salt) {
case 0:
case 1:
if tx.Salt[0] >= 128 {
payloadSize++
}
default:
if len(tx.Salt) >= 56 {
payloadSize += (bits.Len(uint(len(tx.Salt))) + 7) / 8
}
payloadSize += len(tx.Salt)
}

// size of AccessList
payloadSize++
accessListLen = accessListSize(tx.AccessList)
Expand Down Expand Up @@ -393,6 +396,7 @@ func (tx StarknetTransaction) copy() *StarknetTransaction {
Nonce: tx.Nonce,
To: tx.To,
Data: common.CopyBytes(tx.Data),
Salt: common.CopyBytes(tx.Salt),
Gas: tx.Gas,
Value: new(uint256.Int),
},
Expand Down
124 changes: 124 additions & 0 deletions core/types/starknet_tx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package types

import (
"bytes"
"encoding/hex"
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon/common"
"github.com/ledgerwatch/erigon/crypto"
"github.com/ledgerwatch/erigon/params"
"github.com/ledgerwatch/erigon/rlp"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/sha3"
"math/big"
"testing"
)

var (
chainConfig = params.AllEthashProtocolChanges
address = common.HexToAddress("b94f5374fce5edbc8e2a8697c15331677e6ebf0b")
)

func TestStarknetTxDecodeRLP(t *testing.T) {
require := require.New(t)

privateKey, _ := crypto.HexToECDSA(generatePrivateKey(t))

signature, _ := crypto.Sign(sha3.New256().Sum(nil), privateKey)
signer := MakeSigner(chainConfig, 1)

cases := []struct {
name string
tx *StarknetTransaction
}{
{name: "with data, without salt", tx: &StarknetTransaction{
CommonTx: CommonTx{
ChainID: uint256.NewInt(chainConfig.ChainID.Uint64()),
Nonce: 1,
Value: uint256.NewInt(20),
Gas: 1,
To: &address,
Data: []byte("{\"abi\": []}"),
Salt: []byte("contract_salt"),
},
Tip: uint256.NewInt(1),
FeeCap: uint256.NewInt(1),
}},
{name: "with data and salt", tx: &StarknetTransaction{
CommonTx: CommonTx{
ChainID: uint256.NewInt(chainConfig.ChainID.Uint64()),
Nonce: 1,
Value: uint256.NewInt(20),
Gas: 1,
To: &address,
Data: []byte("{\"abi\": []}"),
Salt: []byte{},
},
Tip: uint256.NewInt(1),
FeeCap: uint256.NewInt(1),
}},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tx := tt.tx

signedTx, err := tx.WithSignature(*signer, signature)
require.NoError(err)

buf := bytes.NewBuffer(nil)

err = signedTx.MarshalBinary(buf)
require.NoError(err)

encodedTx := buf.Bytes()

txn, err := DecodeTransaction(rlp.NewStream(bytes.NewReader(encodedTx), uint64(len(encodedTx))))
require.NoError(err)

require.Equal(signedTx.GetChainID(), txn.GetChainID())
require.Equal(signedTx.GetNonce(), txn.GetNonce())
require.Equal(signedTx.GetTip(), txn.GetTip())
require.Equal(signedTx.GetFeeCap(), txn.GetFeeCap())
require.Equal(signedTx.GetGas(), txn.GetGas())
require.Equal(signedTx.GetTo(), txn.GetTo())
require.Equal(signedTx.GetValue(), txn.GetValue())
require.Equal(signedTx.GetData(), txn.GetData())
require.Equal(signedTx.GetSalt(), txn.GetSalt())

txV, txR, txS := signedTx.RawSignatureValues()
txnV, txnR, txnS := txn.RawSignatureValues()

require.Equal(txV, txnV)
require.Equal(txR, txnR)
require.Equal(txS, txnS)
})
}
}

func generatePrivateKey(t testing.TB) string {
t.Helper()

privateKey, err := crypto.GenerateKey()
if err != nil {
t.Error(err)
}

return hex.EncodeToString(crypto.FromECDSA(privateKey))
}

func starknetTransaction(chainId *big.Int, address common.Address) *StarknetTransaction {
return &StarknetTransaction{
CommonTx: CommonTx{
ChainID: uint256.NewInt(chainId.Uint64()),
Nonce: 1,
Value: uint256.NewInt(20),
Gas: 1,
To: &address,
Data: []byte("{\"abi\": []}"),
Salt: []byte("contract_salt"),
},
Tip: uint256.NewInt(1),
FeeCap: uint256.NewInt(1),
}
}
Loading

0 comments on commit 8203cdf

Please sign in to comment.