From b3c991f12ec52d308cf18ed426a5390440f75979 Mon Sep 17 00:00:00 2001 From: Tsachi Herman <24438559+tsachiherman@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:36:21 -0400 Subject: [PATCH] add simulateTransaction endpoint (#1610) Signed-off-by: Tsachi Herman <24438559+tsachiherman@users.noreply.github.com> --- api/jsonrpc/client.go | 27 ++ api/jsonrpc/server.go | 75 +++++- examples/vmwithcontracts/actions/call.go | 16 +- .../cmd/vmwithcontracts-cli/cmd/action.go | 29 +- examples/vmwithcontracts/storage/recorder.go | 70 ----- examples/vmwithcontracts/vm/client.go | 26 -- examples/vmwithcontracts/vm/server.go | 52 ---- state/keys.go | 42 ++- state/keys_test.go | 58 ++++ state/recorder.go | 97 +++++++ state/recorder_test.go | 247 ++++++++++++++++++ 11 files changed, 580 insertions(+), 159 deletions(-) delete mode 100644 examples/vmwithcontracts/storage/recorder.go create mode 100644 state/recorder.go create mode 100644 state/recorder_test.go diff --git a/api/jsonrpc/client.go b/api/jsonrpc/client.go index f35487d433..c433b1c87b 100644 --- a/api/jsonrpc/client.go +++ b/api/jsonrpc/client.go @@ -239,3 +239,30 @@ func Wait(ctx context.Context, interval time.Duration, check func(ctx context.Co } return ctx.Err() } + +func (cli *JSONRPCClient) SimulateActions(ctx context.Context, actions chain.Actions, actor codec.Address) ([]SimulateActionResult, error) { + args := &SimulatActionsArgs{ + Actor: actor, + } + + for _, action := range actions { + marshaledAction, err := chain.MarshalTyped(action) + if err != nil { + return nil, err + } + args.Actions = append(args.Actions, marshaledAction) + } + + resp := new(SimulateActionsReply) + err := cli.requester.SendRequest( + ctx, + "simulateActions", + args, + resp, + ) + if err != nil { + return nil, err + } + + return resp.ActionResults, nil +} diff --git a/api/jsonrpc/server.go b/api/jsonrpc/server.go index 936c595e98..9af967a0be 100644 --- a/api/jsonrpc/server.go +++ b/api/jsonrpc/server.go @@ -18,6 +18,7 @@ import ( "github.com/ava-labs/hypersdk/codec" "github.com/ava-labs/hypersdk/consts" "github.com/ava-labs/hypersdk/fees" + "github.com/ava-labs/hypersdk/state" "github.com/ava-labs/hypersdk/state/tstate" ) @@ -27,6 +28,11 @@ const ( var _ api.HandlerFactory[api.VM] = (*JSONRPCServerFactory)(nil) +var ( + errSimulateZeroActions = errors.New("simulateAction expects at least a single action, none found") + errTransactionExtraBytes = errors.New("transaction has extra bytes") +) + type JSONRPCServerFactory struct{} func (JSONRPCServerFactory) New(vm api.VM) (api.Handler, error) { @@ -95,7 +101,7 @@ func (j *JSONRPCServer) SubmitTx( return fmt.Errorf("%w: unable to unmarshal on public service", err) } if !rtx.Empty() { - return errors.New("tx has extra bytes") + return errTransactionExtraBytes } if err := tx.Verify(ctx); err != nil { return err @@ -237,3 +243,70 @@ func (j *JSONRPCServer) Execute( } return nil } + +type SimulatActionsArgs struct { + Actions []codec.Bytes `json:"actions"` + Actor codec.Address `json:"actor"` +} + +type SimulateActionResult struct { + Output codec.Bytes `json:"output"` + StateKeys state.Keys `json:"stateKeys"` +} + +type SimulateActionsReply struct { + ActionResults []SimulateActionResult `json:"actionresults"` +} + +func (j *JSONRPCServer) SimulateActions( + req *http.Request, + args *SimulatActionsArgs, + reply *SimulateActionsReply, +) error { + ctx, span := j.vm.Tracer().Start(req.Context(), "JSONRPCServer.SimulateActions") + defer span.End() + + actionRegistry := j.vm.ActionRegistry() + var actions chain.Actions + for _, actionBytes := range args.Actions { + actionsReader := codec.NewReader(actionBytes, len(actionBytes)) + action, err := (*actionRegistry).Unmarshal(actionsReader) + if err != nil { + return err + } + if !actionsReader.Empty() { + return errTransactionExtraBytes + } + actions = append(actions, action) + } + if len(actions) == 0 { + return errSimulateZeroActions + } + currentState, err := j.vm.ImmutableState(ctx) + if err != nil { + return err + } + + currentTime := time.Now().UnixMilli() + for _, action := range actions { + recorder := state.NewRecorder(currentState) + actionOutput, err := action.Execute(ctx, j.vm.Rules(currentTime), recorder, currentTime, args.Actor, ids.Empty) + + var actionResult SimulateActionResult + if actionOutput == nil { + actionResult.Output = []byte{} + } else { + actionResult.Output, err = chain.MarshalTyped(actionOutput) + if err != nil { + return fmt.Errorf("failed to marshal output: %w", err) + } + } + if err != nil { + return err + } + actionResult.StateKeys = recorder.GetStateKeys() + reply.ActionResults = append(reply.ActionResults, actionResult) + currentState = recorder + } + return nil +} diff --git a/examples/vmwithcontracts/actions/call.go b/examples/vmwithcontracts/actions/call.go index 2aa36fd404..60271a94c5 100644 --- a/examples/vmwithcontracts/actions/call.go +++ b/examples/vmwithcontracts/actions/call.go @@ -21,7 +21,10 @@ import ( var _ chain.Action = (*Call)(nil) -const MaxCallDataSize = units.MiB +const ( + MaxCallDataSize = units.MiB + MaxResultSizeLimit = units.MiB +) type StateKeyPermission struct { Key string @@ -68,7 +71,7 @@ func (t *Call) Execute( actor codec.Address, _ ids.ID, ) (codec.Typed, error) { - resutBytes, err := t.r.CallContract(ctx, &runtime.CallInfo{ + callInfo := &runtime.CallInfo{ Contract: t.ContractAddress, Actor: actor, State: &storage.ContractStateManager{Mutable: mu}, @@ -77,11 +80,13 @@ func (t *Call) Execute( Timestamp: uint64(timestamp), Fuel: t.Fuel, Value: t.Value, - }) + } + resultBytes, err := t.r.CallContract(ctx, callInfo) if err != nil { return nil, err } - return &Result{Value: resutBytes}, nil + consumedFuel := t.Fuel - callInfo.RemainingFuel() + return &Result{Value: resultBytes, ConsumedFuel: consumedFuel}, nil } func (t *Call) ComputeUnits(chain.Rules) uint64 { @@ -134,7 +139,8 @@ func (*Call) ValidRange(chain.Rules) (int64, int64) { } type Result struct { - Value []byte `serialize:"true" json:"value"` + Value []byte `serialize:"true" json:"value"` + ConsumedFuel uint64 `serialize:"true" json:"consumedfuel"` } func (*Result) GetTypeID() uint8 { diff --git a/examples/vmwithcontracts/cmd/vmwithcontracts-cli/cmd/action.go b/examples/vmwithcontracts/cmd/vmwithcontracts-cli/cmd/action.go index 3c0f1ffe0a..da13e4255e 100644 --- a/examples/vmwithcontracts/cmd/vmwithcontracts-cli/cmd/action.go +++ b/examples/vmwithcontracts/cmd/vmwithcontracts-cli/cmd/action.go @@ -5,6 +5,8 @@ package cmd import ( "context" + "errors" + "fmt" "os" "github.com/near/borsh-go" @@ -15,9 +17,12 @@ import ( "github.com/ava-labs/hypersdk/cli/prompt" "github.com/ava-labs/hypersdk/codec" "github.com/ava-labs/hypersdk/examples/vmwithcontracts/actions" + "github.com/ava-labs/hypersdk/examples/vmwithcontracts/vm" "github.com/ava-labs/hypersdk/utils" ) +var errUnexpectedSimulateActionsOutput = errors.New("returned output from SimulateActions was not actions.Result") + var actionCmd = &cobra.Command{ Use: "action", RunE: func(*cobra.Command, []string) error { @@ -141,18 +146,34 @@ var callCmd = &cobra.Command{ ContractAddress: contractAddress, Value: amount, Function: function, + Fuel: uint64(1000000000), + } + + actionSimulationResults, err := cli.SimulateActions(ctx, chain.Actions{action}, priv.Address) + if err != nil { + return err + } + if len(actionSimulationResults) != 1 { + return fmt.Errorf("unexpected number of returned actions. One action expected, %d returned", len(actionSimulationResults)) } + actionSimulationResult := actionSimulationResults[0] - specifiedStateKeysSet, fuel, err := bcli.Simulate(ctx, *action, priv.Address) + rtx := codec.NewReader(actionSimulationResult.Output, len(actionSimulationResult.Output)) + + simulationResultOutput, err := (*vm.OutputParser).Unmarshal(rtx) if err != nil { return err } + simulationResult, ok := simulationResultOutput.(*actions.Result) + if !ok { + return errUnexpectedSimulateActionsOutput + } - action.SpecifiedStateKeys = make([]actions.StateKeyPermission, 0, len(specifiedStateKeysSet)) - for key, value := range specifiedStateKeysSet { + action.SpecifiedStateKeys = make([]actions.StateKeyPermission, 0, len(actionSimulationResult.StateKeys)) + for key, value := range actionSimulationResult.StateKeys { action.SpecifiedStateKeys = append(action.SpecifiedStateKeys, actions.StateKeyPermission{Key: key, Permission: value}) } - action.Fuel = fuel + action.Fuel = simulationResult.ConsumedFuel // Confirm action cont, err := prompt.Continue() diff --git a/examples/vmwithcontracts/storage/recorder.go b/examples/vmwithcontracts/storage/recorder.go deleted file mode 100644 index d544fca78f..0000000000 --- a/examples/vmwithcontracts/storage/recorder.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package storage - -import ( - "context" - "errors" - - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/utils/set" - - "github.com/ava-labs/hypersdk/state" -) - -type Recorder struct { - State state.Immutable - changedValues map[string][]byte - ReadState set.Set[string] - WriteState set.Set[string] -} - -func NewRecorder(db state.Immutable) *Recorder { - return &Recorder{State: db, changedValues: map[string][]byte{}} -} - -func (r *Recorder) Insert(_ context.Context, key []byte, value []byte) error { - stringKey := string(key) - r.WriteState.Add(stringKey) - r.changedValues[stringKey] = value - return nil -} - -func (r *Recorder) Remove(_ context.Context, key []byte) error { - stringKey := string(key) - r.WriteState.Add(stringKey) - r.changedValues[stringKey] = nil - return nil -} - -func (r *Recorder) GetValue(ctx context.Context, key []byte) (value []byte, err error) { - stringKey := string(key) - r.ReadState.Add(stringKey) - if value, ok := r.changedValues[stringKey]; ok { - if value == nil { - return nil, database.ErrNotFound - } - return value, nil - } - return r.State.GetValue(ctx, key) -} - -func (r *Recorder) GetStateKeys() state.Keys { - result := state.Keys{} - for key := range r.ReadState { - result.Add(key, state.Read) - } - for key := range r.WriteState { - if _, err := r.State.GetValue(context.Background(), []byte(key)); err != nil && errors.Is(err, database.ErrNotFound) { - if r.changedValues[key] == nil { - // not a real write since the key was not already present and is being deleted - continue - } - // wasn't found so needs to be allocated - result.Add(key, state.Allocate) - } - result.Add(key, state.Write) - } - return result -} diff --git a/examples/vmwithcontracts/vm/client.go b/examples/vmwithcontracts/vm/client.go index 0935584988..8c7b847c37 100644 --- a/examples/vmwithcontracts/vm/client.go +++ b/examples/vmwithcontracts/vm/client.go @@ -5,7 +5,6 @@ package vm import ( "context" - "encoding/hex" "encoding/json" "strings" "time" @@ -13,12 +12,10 @@ import ( "github.com/ava-labs/hypersdk/api/jsonrpc" "github.com/ava-labs/hypersdk/chain" "github.com/ava-labs/hypersdk/codec" - "github.com/ava-labs/hypersdk/examples/vmwithcontracts/actions" "github.com/ava-labs/hypersdk/examples/vmwithcontracts/consts" "github.com/ava-labs/hypersdk/examples/vmwithcontracts/storage" "github.com/ava-labs/hypersdk/genesis" "github.com/ava-labs/hypersdk/requester" - "github.com/ava-labs/hypersdk/state" "github.com/ava-labs/hypersdk/utils" ) @@ -137,26 +134,3 @@ func CreateParser(genesisBytes []byte) (chain.Parser, error) { } return NewParser(&genesis), nil } - -func (cli *JSONRPCClient) Simulate(ctx context.Context, callTx actions.Call, actor codec.Address) (state.Keys, uint64, error) { - resp := new(SimulateCallTxReply) - err := cli.requester.SendRequest( - ctx, - "simulateCallContractTx", - &SimulateCallTxArgs{CallTx: callTx, Actor: actor}, - resp, - ) - if err != nil { - return nil, 0, err - } - result := state.Keys{} - for _, entry := range resp.StateKeys { - hexBytes, err := hex.DecodeString(entry.HexKey) - if err != nil { - return nil, 0, err - } - - result.Add(string(hexBytes), state.Permissions(entry.Permissions)) - } - return result, resp.FuelConsumed, nil -} diff --git a/examples/vmwithcontracts/vm/server.go b/examples/vmwithcontracts/vm/server.go index 60f098d738..50adfd33ff 100644 --- a/examples/vmwithcontracts/vm/server.go +++ b/examples/vmwithcontracts/vm/server.go @@ -4,18 +4,13 @@ package vm import ( - "context" - "encoding/hex" "net/http" "github.com/ava-labs/hypersdk/api" "github.com/ava-labs/hypersdk/codec" - "github.com/ava-labs/hypersdk/examples/vmwithcontracts/actions" "github.com/ava-labs/hypersdk/examples/vmwithcontracts/consts" "github.com/ava-labs/hypersdk/examples/vmwithcontracts/storage" "github.com/ava-labs/hypersdk/genesis" - "github.com/ava-labs/hypersdk/state" - "github.com/ava-labs/hypersdk/x/contracts/runtime" ) const JSONRPCEndpoint = "/vmwithcontractsapi" @@ -68,50 +63,3 @@ func (j *JSONRPCServer) Balance(req *http.Request, args *BalanceArgs, reply *Bal reply.Amount = balance return err } - -type SimulateCallTxArgs struct { - CallTx actions.Call `json:"callTx"` - Actor codec.Address `json:"actor"` -} - -type SimulateStateKey struct { - HexKey string `json:"hex"` - Permissions byte `json:"perm"` -} -type SimulateCallTxReply struct { - StateKeys []SimulateStateKey `json:"stateKeys"` - FuelConsumed uint64 `json:"fuel"` -} - -func (j *JSONRPCServer) SimulateCallContractTx(req *http.Request, args *SimulateCallTxArgs, reply *SimulateCallTxReply) (err error) { - stateKeys, fuelConsumed, err := j.simulate(req.Context(), args.CallTx, args.Actor) - if err != nil { - return err - } - reply.StateKeys = make([]SimulateStateKey, 0, len(stateKeys)) - for key, permission := range stateKeys { - reply.StateKeys = append(reply.StateKeys, SimulateStateKey{HexKey: hex.EncodeToString([]byte(key)), Permissions: byte(permission)}) - } - reply.FuelConsumed = fuelConsumed - return nil -} - -func (j *JSONRPCServer) simulate(ctx context.Context, t actions.Call, actor codec.Address) (state.Keys, uint64, error) { - currentState, err := j.vm.ImmutableState(ctx) - if err != nil { - return nil, 0, err - } - recorder := storage.NewRecorder(currentState) - startFuel := uint64(1000000000) - callInfo := &runtime.CallInfo{ - Contract: t.ContractAddress, - Actor: actor, - State: &storage.ContractStateManager{Mutable: recorder}, - FunctionName: t.Function, - Params: t.CallData, - Fuel: startFuel, - Value: t.Value, - } - _, err = wasmRuntime.CallContract(ctx, callInfo) - return recorder.GetStateKeys(), startFuel - callInfo.RemainingFuel(), err -} diff --git a/state/keys.go b/state/keys.go index f57d345aa3..13cb8d934c 100644 --- a/state/keys.go +++ b/state/keys.go @@ -3,7 +3,13 @@ package state -import "github.com/ava-labs/hypersdk/keys" +import ( + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/ava-labs/hypersdk/keys" +) const ( Read Permissions = 1 @@ -50,6 +56,40 @@ func (k Keys) ChunkSizes() ([]uint16, bool) { return chunks, true } +type permsJSON []string + +type keysJSON struct { + Perms [8]permsJSON +} + +func (k Keys) MarshalJSON() ([]byte, error) { + var keysJSON keysJSON + for key, perm := range k { + keysJSON.Perms[perm] = append(keysJSON.Perms[perm], hex.EncodeToString([]byte(key))) + } + return json.Marshal(keysJSON) +} + +func (k *Keys) UnmarshalJSON(b []byte) error { + var keysJSON keysJSON + if err := json.Unmarshal(b, &keysJSON); err != nil { + return err + } + for perm, keyList := range keysJSON.Perms { + if perm < int(None) || perm > int(All) { + return fmt.Errorf("invalid permission encoded in json %d", perm) + } + for _, encodedKey := range keyList { + key, err := hex.DecodeString(encodedKey) + if err != nil { + return err + } + (*k)[string(key)] = Permissions(perm) + } + } + return nil +} + // Has returns true if [p] has all the permissions that are contained in require func (p Permissions) Has(require Permissions) bool { return require&^p == 0 diff --git a/state/keys_test.go b/state/keys_test.go index 27733dd777..8e819f836a 100644 --- a/state/keys_test.go +++ b/state/keys_test.go @@ -4,6 +4,9 @@ package state import ( + "crypto/sha256" + "encoding/binary" + "math/rand" "slices" "testing" @@ -123,3 +126,58 @@ func TestHasPermissions(t *testing.T) { } } } + +func TestKeysMarshalingSimple(t *testing.T) { + require := require.New(t) + + // test with read permission. + keys := Keys{} + require.True(keys.Add("key1", Read)) + bytes, err := keys.MarshalJSON() + require.NoError(err) + require.Equal([]byte{0x7b, 0x22, 0x50, 0x65, 0x72, 0x6d, 0x73, 0x22, 0x3a, 0x5b, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x5b, 0x22, 0x36, 0x62, 0x36, 0x35, 0x37, 0x39, 0x33, 0x31, 0x22, 0x5d, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x5d, 0x7d}, bytes) + keys = Keys{} + require.NoError(keys.UnmarshalJSON(bytes)) + require.Len(keys, 1) + require.Equal(Read, keys["key1"]) + + // test with read+write permission. + keys = Keys{} + require.True(keys.Add("key2", Read|Write)) + bytes, err = keys.MarshalJSON() + require.NoError(err) + require.Equal([]byte{0x7b, 0x22, 0x50, 0x65, 0x72, 0x6d, 0x73, 0x22, 0x3a, 0x5b, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x5b, 0x22, 0x36, 0x62, 0x36, 0x35, 0x37, 0x39, 0x33, 0x32, 0x22, 0x5d, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x5d, 0x7d}, bytes) + keys = Keys{} + require.NoError(keys.UnmarshalJSON(bytes)) + require.Len(keys, 1) + require.Equal(Read|Write, keys["key2"]) +} + +func (k Keys) compare(k2 Keys) bool { + if len(k) != len(k2) { + return false + } + for k1, v1 := range k { + if v2, has := k2[k1]; !has || v1 != v2 { + return false + } + } + return true +} + +func TestKeysMarshalingFuzz(t *testing.T) { + require := require.New(t) + rand := rand.New(rand.NewSource(0)) //nolint:gosec + for fuzzIteration := 0; fuzzIteration < 1000; fuzzIteration++ { + keys := Keys{} + for keyIdx := 0; keyIdx < rand.Int()%32; keyIdx++ { + key := sha256.Sum256(binary.BigEndian.AppendUint64(nil, uint64(keyIdx))) + keys.Add(string(key[:]), Permissions(rand.Int()%(int(All)+1))) + } + bytes, err := keys.MarshalJSON() + require.NoError(err) + decodedKeys := Keys{} + require.NoError(decodedKeys.UnmarshalJSON(bytes)) + require.True(keys.compare(decodedKeys)) + } +} diff --git a/state/recorder.go b/state/recorder.go new file mode 100644 index 0000000000..fba5cda6c6 --- /dev/null +++ b/state/recorder.go @@ -0,0 +1,97 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "context" + "errors" + + "github.com/ava-labs/avalanchego/database" +) + +// The Recorder wraps an [Immutable] state and records what keys are accessed +// and what permissions are required. +// Maintains same definition of required permissions as TStateView +type Recorder struct { + // State is the underlying [Immutable] object + state Immutable + stateKeys map[string][]byte + + changedValues map[string][]byte + keys Keys +} + +func NewRecorder(db Immutable) *Recorder { + return &Recorder{state: db, changedValues: map[string][]byte{}, stateKeys: map[string][]byte{}, keys: Keys{}} +} + +func (r *Recorder) checkState(ctx context.Context, key []byte) ([]byte, error) { + if val, has := r.stateKeys[string(key)]; has { + return val, nil + } + value, err := r.state.GetValue(ctx, key) + if err == nil { + // no error, key found. + r.stateKeys[string(key)] = value + return value, nil + } + + if errors.Is(err, database.ErrNotFound) { + r.stateKeys[string(key)] = nil + err = nil + } + return nil, err +} + +func (r *Recorder) Insert(ctx context.Context, key []byte, value []byte) error { + stringKey := string(key) + + stateKeyVal, err := r.checkState(ctx, key) + if err != nil { + return err + } + + if stateKeyVal != nil { + // underlying storage already has that key. + r.keys[stringKey] |= Write + } else { + // underlying storage doesn't have that key. + r.keys[stringKey] |= Allocate | Write + } + + // save the updated value. + r.changedValues[stringKey] = value + return nil +} + +func (r *Recorder) Remove(_ context.Context, key []byte) error { + stringKey := string(key) + r.keys[stringKey] |= Write + r.changedValues[stringKey] = nil + return nil +} + +func (r *Recorder) GetValue(ctx context.Context, key []byte) (value []byte, err error) { + stringKey := string(key) + + stateKeyVal, err := r.checkState(ctx, key) + if err != nil { + return nil, err + } + r.keys[stringKey] |= Read + if value, ok := r.changedValues[stringKey]; ok { + if value == nil { // value was removed. + return nil, database.ErrNotFound + } + return value, nil + } + if stateKeyVal == nil { // no such key exist. + return nil, database.ErrNotFound + } + return stateKeyVal, nil +} + +func (r *Recorder) GetStateKeys() Keys { + return r.keys +} diff --git a/state/recorder_test.go b/state/recorder_test.go new file mode 100644 index 0000000000..bfcf3b21cd --- /dev/null +++ b/state/recorder_test.go @@ -0,0 +1,247 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state_test + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/ava-labs/avalanchego/database" + "github.com/stretchr/testify/require" + + "github.com/ava-labs/hypersdk/keys" + "github.com/ava-labs/hypersdk/state" + "github.com/ava-labs/hypersdk/state/tstate" +) + +func randomNewKey() []byte { + randNewKey := make([]byte, 30, 32) + _, err := rand.Read(randNewKey) + if err != nil { + panic(err) + } + return keys.EncodeChunks(randNewKey, 1) +} + +func randomizeView(tstate *tstate.TState, keyCount int) (*tstate.TStateView, [][]byte, map[string]state.Permissions, map[string][]byte) { + keys := make([][]byte, keyCount) + values := make([][32]byte, keyCount) + storage := map[string][]byte{} + scope := map[string]state.Permissions{} + for i := 0; i < keyCount; i++ { + keys[i] = randomNewKey() + _, err := rand.Read(values[i][:]) + if err != nil { + panic(err) + } + storage[string(keys[i])] = values[i][:] + scope[string(keys[i])] = state.All + } + // create new view + return tstate.NewView(scope, storage), keys, scope, storage +} + +func TestRecorderInnerFuzz(t *testing.T) { + tstateObj := tstate.New(1000) + require := require.New(t) + + var ( + stateView *tstate.TStateView + keys [][]byte + scope map[string]state.Permissions + removedKeys map[string]bool + ) + + pickExistingKeyAtRandom := func() []byte { + randKey := make([]byte, 1) + _, err := rand.Read(randKey) + require.NoError(err) + randKey[0] %= byte(len(keys)) + for removedKeys[string(keys[randKey[0]])] { + _, err := rand.Read(randKey) + randKey[0] %= byte(len(keys)) + require.NoError(err) + } + return keys[randKey[0]] + } + for i := 0; i < 10000; i++ { + stateView, keys, scope, _ = randomizeView(tstateObj, 32) + removedKeys = map[string]bool{} + // wrap with recorder. + recorder := state.NewRecorder(stateView) + for j := 0; j <= 32; j++ { + op := make([]byte, 1) + _, err := rand.Read(op) + require.NoError(err) + switch op[0] % 6 { + case 0: // insert into existing entry + randKey := pickExistingKeyAtRandom() + err := recorder.Insert(context.Background(), randKey, []byte{1, 2, 3, 4}) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + case 1: // insert into new entry + randNewKey := randomNewKey() + // add the new key to the scope + scope[string(randNewKey)] = state.Allocate | state.Write + err := recorder.Insert(context.Background(), randNewKey, []byte{1, 2, 3, 4}) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randNewKey)].Has(state.Allocate | state.Write)) + keys = append(keys, randNewKey) + case 2: // remove existing entry + randKey := pickExistingKeyAtRandom() + err := recorder.Remove(context.Background(), randKey) + require.NoError(err) + removedKeys[string(randKey)] = true + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + case 3: // remove non existing entry + randKey := randomNewKey() + err := recorder.Remove(context.Background(), randKey) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + case 4: // get value of existing entry + randKey := pickExistingKeyAtRandom() + val, err := recorder.GetValue(context.Background(), randKey) + require.NoError(err) + require.NotEmpty(val) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Read)) + case 5: // get value of non existing entry + randKey := randomNewKey() + // add the new key to the scope + scope[string(randKey)] = state.Read + value, err := recorder.GetValue(context.Background(), randKey) + require.ErrorIs(err, database.ErrNotFound) + require.Empty(value) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Read)) + } + } + } +} + +type testingReadonlyDatasource struct { + storage map[string][]byte +} + +func (c *testingReadonlyDatasource) GetValue(_ context.Context, key []byte) (value []byte, err error) { + if v, has := c.storage[string(key)]; has { + return v, nil + } + return nil, database.ErrNotFound +} + +func TestRecorderSideBySideFuzz(t *testing.T) { + tstateObj := tstate.New(1000) + require := require.New(t) + + var ( + stateView *tstate.TStateView + keys [][]byte + scope map[string]state.Permissions + removedKeys map[string]bool + storage map[string][]byte + ) + + pickExistingKeyAtRandom := func() []byte { + randKey := make([]byte, 1) + _, err := rand.Read(randKey) + require.NoError(err) + randKey[0] %= byte(len(keys)) + for removedKeys[string(keys[randKey[0]])] { + _, err := rand.Read(randKey) + randKey[0] %= byte(len(keys)) + require.NoError(err) + } + return keys[randKey[0]] + } + randomValue := func() []byte { + randVal := make([]byte, 32) + _, err := rand.Read(randVal) + require.NoError(err) + return randVal + } + + for i := 0; i < 10000; i++ { + stateView, keys, scope, storage = randomizeView(tstateObj, 32) + removedKeys = map[string]bool{} + // wrap with recorder. + recorder := state.NewRecorder(&testingReadonlyDatasource{storage}) + for j := 0; j <= 32; j++ { + op := make([]byte, 1) + _, err := rand.Read(op) + require.NoError(err) + switch op[0] % 6 { + case 0: // insert into existing entry + randKey := pickExistingKeyAtRandom() + randVal := randomValue() + + err := recorder.Insert(context.Background(), randKey, randVal) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + + err = stateView.Insert(context.Background(), randKey, randVal) + require.NoError(err) + case 1: // insert into new entry + randNewKey := randomNewKey() + randVal := randomValue() + + err := recorder.Insert(context.Background(), randNewKey, randVal) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randNewKey)].Has(state.Allocate | state.Write)) + + // add the new key to the scope + scope[string(randNewKey)] = state.Write | state.Allocate + err = stateView.Insert(context.Background(), randNewKey, randVal) + require.NoError(err) + + keys = append(keys, randNewKey) + case 2: // remove existing entry + randKey := pickExistingKeyAtRandom() + + err := recorder.Remove(context.Background(), randKey) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + + err = stateView.Remove(context.Background(), randKey) + require.NoError(err) + + removedKeys[string(randKey)] = true + case 3: // remove non existing entry + randKey := randomNewKey() + + err := recorder.Remove(context.Background(), randKey) + require.NoError(err) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Write)) + + // add the new key to the scope + scope[string(randKey)] = state.Write + err = stateView.Remove(context.Background(), randKey) + require.NoError(err) + case 4: // get value of existing entry + randKey := pickExistingKeyAtRandom() + + val, err := recorder.GetValue(context.Background(), randKey) + require.NoError(err) + require.NotEmpty(val) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Read)) + + val, err = stateView.GetValue(context.Background(), randKey) + require.NoError(err) + require.NotEmpty(val) + case 5: // get value of non existing entry + randKey := randomNewKey() + + value, err := recorder.GetValue(context.Background(), randKey) + require.ErrorIs(err, database.ErrNotFound) + require.Empty(value) + require.True(recorder.GetStateKeys()[string(randKey)].Has(state.Read)) + + // add the new key to the scope + scope[string(randKey)] = state.Read + value, err = stateView.GetValue(context.Background(), randKey) + require.ErrorIs(err, database.ErrNotFound) + require.Empty(value) + } + } + } +}