diff --git a/x/programs/examples/counter_test.go b/x/programs/examples/counter_test.go deleted file mode 100644 index c6ea953d84..0000000000 --- a/x/programs/examples/counter_test.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package examples - -import ( - "context" - "os" - "testing" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/stretchr/testify/require" - "go.uber.org/zap" - - "github.com/ava-labs/hypersdk/x/programs/engine" - "github.com/ava-labs/hypersdk/x/programs/examples/imports/pstate" - "github.com/ava-labs/hypersdk/x/programs/examples/storage" - "github.com/ava-labs/hypersdk/x/programs/host" - "github.com/ava-labs/hypersdk/x/programs/program" - "github.com/ava-labs/hypersdk/x/programs/runtime" - "github.com/ava-labs/hypersdk/x/programs/tests" - - iprogram "github.com/ava-labs/hypersdk/x/programs/examples/imports/program" -) - -// go test -v -timeout 30s -run ^TestCounterProgram$ github.com/ava-labs/hypersdk/x/programs/examples -func TestCounterProgram(t *testing.T) { - require := require.New(t) - db := newTestDB() - maxUnits := uint64(10000000) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - programID := ids.GenerateTestID() - callContext := program.Context{ProgramID: programID} - cfg := runtime.NewConfig() - log := logging.NewLogger( - "", - logging.NewWrappedCore( - logging.Info, - os.Stderr, - logging.Plain.ConsoleEncoder(), - )) - - eng := engine.New(engine.NewConfig()) - // define supported imports - importsBuilder := host.NewImportsBuilder() - importsBuilder.Register("state", func() host.Import { - return pstate.New(log, db) - }) - - importsBuilder.Register("program", func() host.Import { - return iprogram.New(log, eng, db, cfg, &callContext) - }) - imports := importsBuilder.Build() - - wasmBytes := tests.ReadFixture(t, "../tests/fixture/counter.wasm") - rt := runtime.New(log, eng, imports, cfg) - - err := rt.Initialize(ctx, callContext, wasmBytes, maxUnits) - require.NoError(err) - - balance, err := rt.Meter().GetBalance() - require.NoError(err) - require.Equal(maxUnits, balance) - - // simulate create program transaction - - err = storage.SetProgram(ctx, db, programID, wasmBytes) - require.NoError(err) - - mem, err := rt.Memory() - require.NoError(err) - - // generate alice keys - alicePublicKey, err := newKey() - require.NoError(err) - - // write alice's key to stack and get pointer - alicePtr, err := writeToMem(alicePublicKey, mem) - require.NoError(err) - - // create counter for alice on program 1 - result, err := rt.Call(ctx, "initialize_address", callContext, alicePtr) - require.NoError(err) - require.Equal(int64(1), result[0]) - - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - - // validate counter at 0 - result, err = rt.Call(ctx, "get_value", callContext, alicePtr) - require.NoError(err) - require.Equal(int64(0), result[0]) - - // initialize second runtime to create second counter program with an empty - // meter. - rt2 := runtime.New(log, eng, imports, cfg) - err = rt2.Initialize(ctx, callContext, wasmBytes, engine.NoUnits) - - require.NoError(err) - - // define max units to transfer to second runtime - unitsTransfer := uint64(2000000) - - // transfer the units from the original runtime to the new runtime before - // any calls are made. - _, err = rt.Meter().TransferUnitsTo(rt2.Meter(), unitsTransfer) - require.NoError(err) - - // simulate creating second program transaction - programID2 := ids.GenerateTestID() - err = storage.SetProgram(ctx, db, programID2, wasmBytes) - require.NoError(err) - - mem2, err := rt2.Memory() - require.NoError(err) - - // write alice's key to stack and get pointer - alicePtr2, err := writeToMem(alicePublicKey, mem2) - require.NoError(err) - - callContext1 := program.Context{ProgramID: programID} - callContext2 := program.Context{ProgramID: programID2} - - // initialize counter for alice on runtime 2 - result, err = rt2.Call(ctx, "initialize_address", callContext2, alicePtr2) - require.NoError(err) - require.Equal(int64(1), result[0]) - - // increment alice's counter on program 2 by 10 - incAmount := int64(10) - incAmountPtr, err := writeToMem(incAmount, mem2) - require.NoError(err) - - alicePtr2, err = writeToMem(alicePublicKey, mem2) - - require.NoError(err) - result, err = rt2.Call(ctx, "inc", callContext2, alicePtr2, incAmountPtr) - require.NoError(err) - require.Equal(int64(1), result[0]) - - alicePtr2, err = writeToMem(alicePublicKey, mem2) - require.NoError(err) - - result, err = rt2.Call(ctx, "get_value", callContext2, alicePtr2) - require.NoError(err) - require.Equal(incAmount, result[0]) - - balance, err = rt2.Meter().GetBalance() - require.NoError(err) - - // transfer balance back to original runtime - _, err = rt2.Meter().TransferUnitsTo(rt.Meter(), balance) - require.NoError(err) - - // increment alice's counter on program 1 - onePtr, err := writeToMem(int64(1), mem) - require.NoError(err) - - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - - result, err = rt.Call(ctx, "inc", callContext1, alicePtr, onePtr) - require.NoError(err) - require.Equal(int64(1), result[0]) - - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - - result, err = rt.Call(ctx, "get_value", callContext1, alicePtr) - require.NoError(err) - - log.Debug("count program 1", - zap.Int64("alice", result[0]), - ) - - // write program id 2 to stack of program 1 - target, err := writeToMem(programID2, mem) - require.NoError(err) - - maxUnitsProgramToProgram := int64(1000000) - maxUnitsProgramToProgramPtr, err := writeToMem(maxUnitsProgramToProgram, mem) - require.NoError(err) - - // increment alice's counter on program 2 - fivePtr, err := writeToMem(int64(5), mem) - require.NoError(err) - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - result, err = rt.Call(ctx, "inc_external", callContext1, target, maxUnitsProgramToProgramPtr, alicePtr, fivePtr) - require.NoError(err) - require.Equal(int64(1), result[0]) - - target, err = writeToMem(programID2, mem) - require.NoError(err) - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - maxUnitsProgramToProgramPtr, err = writeToMem(maxUnitsProgramToProgram, mem) - require.NoError(err) - // expect alice's counter on program 2 to be 15 - result, err = rt.Call(ctx, "get_value_external", callContext1, target, maxUnitsProgramToProgramPtr, alicePtr) - require.NoError(err) - require.Equal(int64(15), result[0]) - balance, err = rt.Meter().GetBalance() - require.NoError(err) - require.Greater(balance, uint64(0)) -} diff --git a/x/programs/examples/imports/program/program.go b/x/programs/examples/imports/program/program.go index 81433cb920..85b07941a3 100644 --- a/x/programs/examples/imports/program/program.go +++ b/x/programs/examples/imports/program/program.go @@ -11,6 +11,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" "github.com/bytecodealliance/wasmtime-go/v14" + "github.com/near/borsh-go" "go.uber.org/zap" "github.com/ava-labs/hypersdk/consts" @@ -58,17 +59,19 @@ func (i *Import) Register(link *host.Link, callContext program.Context) error { return link.RegisterImportFn(Name, "call_program", i.callProgramFn(callContext)) } +type callProgramFnArgs struct { + ProgramID []byte + Function []byte + Args []byte + MaxUnits int64 +} + // callProgramFn makes a call to an entry function of a program in the context of another program's ID. -func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Caller, int32, int32, int32, int32, int32, int32, int64) int64 { +func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Caller, int32, int32) int64 { return func( wasmCaller *wasmtime.Caller, - programPtr int32, - programLen int32, - functionPtr int32, - functionLen int32, - argsPtr int32, - argsLen int32, - maxUnits int64, + memOffset int32, + size int32, ) int64 { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -82,25 +85,25 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle return -1 } - // get the entry function for invoke to call. - functionBytes, err := memory.Range(uint32(functionPtr), uint32(functionLen)) + bytes, err := memory.Range(uint32(memOffset), uint32(size)) if err != nil { - i.log.Error("failed to read function name from memory", + i.log.Error("failed to read call arguments from memory", zap.Error(err), ) return -1 } - programIDBytes, err := memory.Range(uint32(programPtr), uint32(programLen)) - if err != nil { - i.log.Error("failed to read id from memory", + args := callProgramFnArgs{} + if err := borsh.Deserialize(&args, bytes); err != nil { + i.log.Error("failed to unmarshal call arguments", + zap.Error(err), ) return -1 } // get the program bytes from storage - programWasmBytes, err := getProgramWasmBytes(i.log, i.mu, programIDBytes) + programWasmBytes, err := getProgramWasmBytes(i.log, i.mu, args.ProgramID) if err != nil { i.log.Error("failed to get program bytes from storage", zap.Error(err), @@ -119,11 +122,11 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle } // transfer the units from the caller to the new runtime before any calls are made. - balance, err := i.meter.TransferUnitsTo(rt.Meter(), uint64(maxUnits)) + balance, err := i.meter.TransferUnitsTo(rt.Meter(), uint64(args.MaxUnits)) if err != nil { i.log.Error("failed to transfer units", zap.Uint64("balance", balance), - zap.Int64("required", maxUnits), + zap.Int64("required", args.MaxUnits), zap.Error(err), ) return -1 @@ -146,14 +149,6 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle } }() - argsBytes, err := memory.Range(uint32(argsPtr), uint32(argsLen)) - if err != nil { - i.log.Error("failed to read program args from memory", - zap.Error(err), - ) - return -1 - } - rtMemory, err := rt.Memory() if err != nil { i.log.Error("failed to get memory from runtime", @@ -163,7 +158,7 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle } // sync args to new runtime and return arguments to the invoke call - params, err := getCallArgs(ctx, rtMemory, argsBytes) + params, err := getCallArgs(ctx, rtMemory, args.Args) if err != nil { i.log.Error("failed to unmarshal call arguments", zap.Error(err), @@ -171,9 +166,9 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle return -1 } - functionName := string(functionBytes) + functionName := string(args.Function) res, err := rt.Call(ctx, functionName, program.Context{ - ProgramID: ids.ID(programIDBytes), + ProgramID: ids.ID(args.ProgramID), // Actor: callContext.ProgramID, // OriginatingActor: callContext.OriginatingActor, }, params...) diff --git a/x/programs/examples/imports/pstate/pstate.go b/x/programs/examples/imports/pstate/pstate.go index bf0baf333c..efca9cbeb5 100644 --- a/x/programs/examples/imports/pstate/pstate.go +++ b/x/programs/examples/imports/pstate/pstate.go @@ -9,6 +9,7 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/near/borsh-go" "go.uber.org/zap" "github.com/ava-labs/hypersdk/state" @@ -43,38 +44,44 @@ func (*Import) Name() string { func (i *Import) Register(link *host.Link, _ program.Context) error { i.meter = link.Meter() wrap := wrap.New(link) - if err := wrap.RegisterAnyParamFn(Name, "put", 6, i.putFnVariadic); err != nil { + if err := wrap.RegisterAnyParamFn(Name, "put", 2, i.putFnVariadic); err != nil { return err } - if err := wrap.RegisterAnyParamFn(Name, "get", 4, i.getFnVariadic); err != nil { + if err := wrap.RegisterAnyParamFn(Name, "get", 2, i.getFnVariadic); err != nil { return err } - return wrap.RegisterAnyParamFn(Name, "delete", 4, i.deleteFnVariadic) + return wrap.RegisterAnyParamFn(Name, "delete", 2, i.deleteFnVariadic) } func (i *Import) putFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { - if len(args) != 6 { - return nil, errors.New("expected 6 arguments") + if len(args) != 2 { + return nil, errors.New("expected 2 arguments") } - return i.putFn(caller, args[0], args[1], args[2], args[3], args[4], args[5]) + return i.putFn(caller, args[0], args[1]) } func (i *Import) getFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { - if len(args) != 4 { - return nil, errors.New("expected 4 arguments") + if len(args) != 2 { + return nil, errors.New("expected 2 arguments") } - return i.getFn(caller, args[0], args[1], args[2], args[3]) + return i.getFn(caller, args[0], args[1]) } func (i *Import) deleteFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { - if len(args) != 4 { - return nil, errors.New("expected 4 arguments") + if len(args) != 2 { + return nil, errors.New("expected 2 arguments") } - return i.deleteFn(caller, args[0], args[1], args[2], args[3]) + return i.deleteFn(caller, args[0], args[1]) } -func (i *Import) putFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32, valuePtr int32, valueLen int32) (*types.Val, error) { +type putArgs struct { + ProgramID [32]byte + Key []byte + Value []byte +} + +func (i *Import) putFn(caller *program.Caller, memOffset int32, size int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -83,32 +90,25 @@ func (i *Import) putFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr return nil, err } - programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) + bytes, err := memory.Range(uint32(memOffset), uint32(size)) if err != nil { - i.log.Error("failed to read program id from memory", + i.log.Error("failed to read args from program memory", zap.Error(err), ) return nil, err } - keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) + args := putArgs{} + err = borsh.Deserialize(&args, bytes) if err != nil { - i.log.Error("failed to read key from memory", + i.log.Error("failed to deserialize args", zap.Error(err), ) return nil, err } - valueBytes, err := memory.Range(uint32(valuePtr), uint32(valueLen)) - if err != nil { - i.log.Error("failed to read value from memory", - zap.Error(err), - ) - return nil, err - } - - k := storage.ProgramPrefixKey(programIDBytes, keyBytes) - err = i.mu.Insert(context.Background(), k, valueBytes) + k := storage.ProgramPrefixKey(args.ProgramID[:], args.Key) + err = i.mu.Insert(context.Background(), k, args.Value) if err != nil { i.log.Error("failed to insert into storage", zap.Error(err), @@ -119,7 +119,12 @@ func (i *Import) putFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr return types.ValI32(0), nil } -func (i *Import) getFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32) (*types.Val, error) { +type getAndDeleteArgs struct { + ProgramID [32]byte + Key []byte +} + +func (i *Import) getFn(caller *program.Caller, memOffset int32, size int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -128,22 +133,24 @@ func (i *Import) getFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr return nil, err } - programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) + bytes, err := memory.Range(uint32(memOffset), uint32(size)) if err != nil { - i.log.Error("failed to read program id from memory", + i.log.Error("failed to read args from program memory", zap.Error(err), ) return nil, err } - keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) + args := getAndDeleteArgs{} + err = borsh.Deserialize(&args, bytes) if err != nil { - i.log.Error("failed to read key from memory", + i.log.Error("failed to deserialize args", zap.Error(err), ) return nil, err } - k := storage.ProgramPrefixKey(programIDBytes, keyBytes) + + k := storage.ProgramPrefixKey(args.ProgramID[:], args.Key) val, err := i.mu.GetValue(context.Background(), k) if err != nil { if errors.Is(err, database.ErrNotFound) { @@ -183,7 +190,7 @@ func (i *Import) getFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr return types.ValI32(int32(valPtr)), nil } -func (i *Import) deleteFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32) (*types.Val, error) { +func (i *Import) deleteFn(caller *program.Caller, memOffset int32, size int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -192,23 +199,24 @@ func (i *Import) deleteFn(caller *program.Caller, idPtr int32, idLen int32, keyP return nil, err } - programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) + bytes, err := memory.Range(uint32(memOffset), uint32(size)) if err != nil { - i.log.Error("failed to read program id from memory", + i.log.Error("failed to read args from program memory", zap.Error(err), ) return nil, err } - keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) + args := getAndDeleteArgs{} + err = borsh.Deserialize(&args, bytes) if err != nil { - i.log.Error("failed to read key from memory", + i.log.Error("failed to deserialize args", zap.Error(err), ) return nil, err } - k := storage.ProgramPrefixKey(programIDBytes, keyBytes) + k := storage.ProgramPrefixKey(args.ProgramID[:], args.Key) if err := i.mu.Remove(context.Background(), k); err != nil { i.log.Error("failed to remove from storage", zap.Error(err)) return types.ValI32(-1), nil diff --git a/x/programs/examples/token_test.go b/x/programs/examples/token_test.go deleted file mode 100644 index 4735c60a2c..0000000000 --- a/x/programs/examples/token_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package examples - -import ( - "context" - "os" - "testing" - - "github.com/ava-labs/avalanchego/database" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/stretchr/testify/require" - - "github.com/ava-labs/hypersdk/x/programs/engine" - "github.com/ava-labs/hypersdk/x/programs/examples/imports/pstate" - "github.com/ava-labs/hypersdk/x/programs/examples/storage" - "github.com/ava-labs/hypersdk/x/programs/host" - "github.com/ava-labs/hypersdk/x/programs/runtime" - "github.com/ava-labs/hypersdk/x/programs/tests" -) - -// go test -v -timeout 30s -run ^TestTokenProgram$ github.com/ava-labs/hypersdk/x/programs/examples -memprofile benchvset.mem -cpuprofile benchvset.cpu -func TestTokenProgram(t *testing.T) { - t.Run("BurnUserTokens", func(t *testing.T) { - wasmBytes := tests.ReadFixture(t, "../tests/fixture/token.wasm") - require := require.New(t) - maxUnits := uint64(200000) - eng := engine.New(engine.NewConfig()) - program := newTokenProgram(maxUnits, eng, runtime.NewConfig(), wasmBytes) - require.NoError(program.Run(context.Background())) - - rt := runtime.New(program.log, program.engine, program.imports, program.cfg) - ctx := context.Background() - callContext := program.Context() - err := rt.Initialize(ctx, callContext, program.programBytes, program.maxUnits) - require.NoError(err) - - // simulate create program transaction - programID := program.ProgramID() - err = storage.SetProgram(ctx, program.db, programID, program.programBytes) - require.NoError(err) - - mem, err := rt.Memory() - require.NoError(err) - - // initialize program - _, err = rt.Call(ctx, "init", callContext) - require.NoError(err, "failed to initialize program") - - // generate alice keys - alicePublicKey, err := newKey() - require.NoError(err) - - // write alice's key to stack and get pointer - alicePtr, err := writeToMem(alicePublicKey, mem) - require.NoError(err) - - // mint 100 tokens to alice - mintAlice := int64(1000) - mintAlicePtr, err := writeToMem(mintAlice, mem) - require.NoError(err) - - _, err = rt.Call(ctx, "mint_to", callContext, alicePtr, mintAlicePtr) - require.NoError(err) - - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - - // check balance of alice - result, err := rt.Call(ctx, "get_balance", callContext, alicePtr) - require.NoError(err) - require.Equal(int64(1000), result[0]) - - // read alice balance from state db - aliceBalance, err := program.GetUserBalanceFromState(ctx, programID, alicePublicKey) - require.NoError(err) - require.Equal(uint32(1000), aliceBalance) - - alicePtr, err = writeToMem(alicePublicKey, mem) - require.NoError(err) - - // burn alice tokens - _, err = rt.Call(ctx, "burn_from", callContext, alicePtr) - require.NoError(err) - - // check balance of alice from state db - _, err = program.GetUserBalanceFromState(ctx, programID, alicePublicKey) - require.ErrorIs(err, database.ErrNotFound) - }) - - wasmBytes := tests.ReadFixture(t, "../tests/fixture/token.wasm") - require := require.New(t) - maxUnits := uint64(200000) - eng := engine.New(engine.NewConfig()) - program := newTokenProgram(maxUnits, eng, runtime.NewConfig(), wasmBytes) - require.NoError(program.Run(context.Background())) -} - -// go test -v -benchmem -run=^$ -bench ^BenchmarkTokenProgram$ github.com/ava-labs/hypersdk/x/programs/examples -memprofile benchvset.mem -cpuprofile benchvset.cpu -func BenchmarkTokenProgram(b *testing.B) { - wasmBytes := tests.ReadFixture(b, "../tests/fixture/token.wasm") - maxUnits := uint64(80000) - - cfg := runtime.NewConfig(). - SetCompileStrategy(engine.CompileWasm) - - ecfg, err := engine.NewConfigBuilder(). - WithDefaultCache(true). - Build() - require.NoError(b, err) - eng := engine.New(ecfg) - - b.Run("benchmark_token_program_compile_and_cache", func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StopTimer() - program := newTokenProgram(maxUnits, eng, cfg, wasmBytes) - b.StartTimer() - require.NoError(b, program.Run(context.Background())) - } - }) - - b.Run("benchmark_token_program_compile_and_cache_short", func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StopTimer() - program := newTokenProgram(maxUnits, eng, cfg, wasmBytes) - b.StartTimer() - require.NoError(b, program.RunShort(context.Background())) - } - }) - - cfg = runtime.NewConfig(). - SetCompileStrategy(engine.PrecompiledWasm) - ecfg, err = engine.NewConfigBuilder(). - WithDefaultCache(true). - Build() - eng = engine.New(ecfg) - require.NoError(b, err) - preCompiledTokenProgramBytes, err := engine.PreCompileWasmBytes(eng, wasmBytes, cfg.LimitMaxMemory) - require.NoError(b, err) - - b.ResetTimer() - b.Run("benchmark_token_program_precompile", func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StopTimer() - program := newTokenProgram(maxUnits, eng, cfg, preCompiledTokenProgramBytes) - b.StartTimer() - require.NoError(b, program.Run(context.Background())) - } - }) - - b.Run("benchmark_token_program_precompile_short", func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StopTimer() - program := newTokenProgram(maxUnits, eng, cfg, preCompiledTokenProgramBytes) - b.StartTimer() - require.NoError(b, program.RunShort(context.Background())) - } - }) -} - -func newTokenProgram(maxUnits uint64, engine *engine.Engine, cfg *runtime.Config, programBytes []byte) *Token { - db := newTestDB() - - log := logging.NewLogger( - "", - logging.NewWrappedCore( - logging.Info, - os.Stderr, - logging.Plain.ConsoleEncoder(), - )) - - // define imports - importsBuilder := host.NewImportsBuilder() - importsBuilder.Register("state", func() host.Import { - return pstate.New(log, db) - }) - - id := ids.GenerateTestID() - - return NewToken(id, log, engine, programBytes, db, cfg, importsBuilder.Build(), maxUnits) -} diff --git a/x/programs/examples/utils.go b/x/programs/examples/utils.go index e33df72d4f..24a7c9914b 100644 --- a/x/programs/examples/utils.go +++ b/x/programs/examples/utils.go @@ -43,12 +43,6 @@ type testDB struct { db *memdb.Database } -func newTestDB() *testDB { - return &testDB{ - db: memdb.New(), - } -} - func (c *testDB) GetValue(_ context.Context, key []byte) ([]byte, error) { return c.db.Get(key) } diff --git a/x/programs/rust/examples/counter/src/lib.rs b/x/programs/rust/examples/counter/src/lib.rs index b572af8777..9937b1ba05 100644 --- a/x/programs/rust/examples/counter/src/lib.rs +++ b/x/programs/rust/examples/counter/src/lib.rs @@ -45,7 +45,7 @@ pub fn inc(context: Context, to: Address, amount: i64) -> bool { #[public] pub fn inc_external(_: Context, target: Program, max_units: i64, of: Address, amount: i64) -> i64 { let params = params!(&of, &amount).unwrap(); - target.call_function("inc", params, max_units).unwrap() + target.call_function("inc", ¶ms, max_units).unwrap() } /// Gets the count at the address. @@ -63,7 +63,7 @@ pub fn get_value(context: Context, of: Address) -> i64 { pub fn get_value_external(_: Context, target: Program, max_units: i64, of: Address) -> i64 { let params = params!(&of).unwrap(); target - .call_function("get_value", params, max_units) + .call_function("get_value", ¶ms, max_units) .unwrap() } diff --git a/x/programs/rust/wasmlanche-sdk/src/memory.rs b/x/programs/rust/wasmlanche-sdk/src/memory.rs index c800d2ce4a..1f230eb1e0 100644 --- a/x/programs/rust/wasmlanche-sdk/src/memory.rs +++ b/x/programs/rust/wasmlanche-sdk/src/memory.rs @@ -16,24 +16,6 @@ thread_local! { static GLOBAL_STORE: RefCell> = RefCell::new(HashMap::new()); } -/// Converts a pointer to a i64 with the first 4 bytes of the pointer -/// representing the length of the memory block. -/// # Errors -/// Returns an [`StateError`] if the pointer or length of `args` exceeds -/// the maximum size of a u32. -#[allow(clippy::cast_possible_truncation)] -pub fn to_ffi_ptr(arg: &[u8]) -> Result { - let ptr = arg.as_ptr(); - let len = arg.len(); - - // Make sure the pointer and length fit into u32 - if ptr as usize > u32::MAX as usize || len > u32::MAX as usize { - return Err(StateError::IntegerConversion); - } - - Ok(CPointer(ptr, len)) -} - /// Converts a raw pointer to a deserialized value. /// Expects the first 4 bytes of the pointer to represent the `length` of the serialized value, /// with the subsequent `length` bytes comprising the serialized data. diff --git a/x/programs/rust/wasmlanche-sdk/src/params.rs b/x/programs/rust/wasmlanche-sdk/src/params.rs index 37ab16bc18..a4c8794de0 100644 --- a/x/programs/rust/wasmlanche-sdk/src/params.rs +++ b/x/programs/rust/wasmlanche-sdk/src/params.rs @@ -1,9 +1,6 @@ -use crate::{ - memory::{to_ffi_ptr, CPointer}, - state::Error as StateError, - Error, -}; +use crate::Error; use borsh::BorshSerialize; +use std::ops::Deref; #[macro_export] macro_rules! params { @@ -22,9 +19,11 @@ pub struct Param(Vec); /// A collection of [borsh] serialized parameters. pub struct Params(Vec); -impl Params { - pub(crate) fn into_ffi_ptr(self) -> Result { - to_ffi_ptr(&self.0) +impl Deref for Params { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 } } diff --git a/x/programs/rust/wasmlanche-sdk/src/program.rs b/x/programs/rust/wasmlanche-sdk/src/program.rs index 8e67d78de2..70b4006b9e 100644 --- a/x/programs/rust/wasmlanche-sdk/src/program.rs +++ b/x/programs/rust/wasmlanche-sdk/src/program.rs @@ -1,5 +1,4 @@ use crate::{ - memory::{to_ffi_ptr, CPointer}, state::{Error as StateError, Key, State}, Params, }; @@ -45,25 +44,32 @@ impl Program { pub fn call_function( &self, function_name: &str, - args: Params, + args: &Params, max_units: i64, ) -> Result { - // flatten the args into a single byte vector - let target = to_ffi_ptr(self.id())?; - let function = to_ffi_ptr(function_name.as_bytes())?; - let args = args.into_ffi_ptr()?; + #[link(wasm_import_module = "program")] + extern "C" { + #[link_name = "call_program"] + fn ffi(ptr: *const u8, len: usize) -> i64; + } - Ok(unsafe { _call_program(target, function, args, max_units) }) + let args = CallProgramArgs { + target_id: self.id(), + function: function_name.as_bytes(), + args_ptr: args, + max_units, + }; + + let args_bytes = borsh::to_vec(&args).map_err(|_| StateError::Serialization)?; + + Ok(unsafe { ffi(args_bytes.as_ptr(), args_bytes.len()) }) } } -#[link(wasm_import_module = "program")] -extern "C" { - #[link_name = "call_program"] - fn _call_program( - target_id: CPointer, - function: CPointer, - args_ptr: CPointer, - max_units: i64, - ) -> i64; +#[derive(BorshSerialize)] +struct CallProgramArgs<'a> { + target_id: &'a [u8], + function: &'a [u8], + args_ptr: &'a [u8], + max_units: i64, } diff --git a/x/programs/rust/wasmlanche-sdk/src/state.rs b/x/programs/rust/wasmlanche-sdk/src/state.rs index 9bd98a08c0..78bea4bc87 100644 --- a/x/programs/rust/wasmlanche-sdk/src/state.rs +++ b/x/programs/rust/wasmlanche-sdk/src/state.rs @@ -102,15 +102,25 @@ where let val_bytes = if let Some(val) = self.cache.get(&key) { val } else { - let val = unsafe { host::get_bytes(&self.program, &key.clone().into())? }; - let val_ptr = val as *const u8; - // TODO write a test for that - if val_ptr.is_null() { - return Err(Error::Read); - } - - // TODO Wrap in OK for now, change from_raw_ptr to return Result - let bytes = into_bytes(val_ptr).ok_or(Error::InvalidPointer)?; + let args = GetAndDeleteArgs { + caller: self.program, + // TODO: shouldn't have to clone here + key: key.clone().into().0, + }; + + let args_bytes = borsh::to_vec(&args).map_err(|_| StateError::Serialization)?; + + let ptr = host::get_bytes(&args_bytes)?; + + let bytes = into_bytes(ptr).ok_or(Error::InvalidPointer)?; + + // TODO: + // should be able to do something like the following + // `let key = Key(args.key);` + // to avoid cloning. The problem is we convert into a Key without knowing + // that we can convert back into a K. + // Either we need the key to actually be `Key` instead of `Into` + // or we put the bound `K: From` as well. self.cache.entry(key).or_insert(bytes) }; @@ -124,15 +134,31 @@ where pub fn delete(&mut self, key: K) -> Result<(), Error> { self.cache.remove(&key); - unsafe { host::delete_bytes(&self.program, &key.into()) } + let args = GetAndDeleteArgs { + caller: self.program, + key: key.into().0, + }; + + let args_bytes = borsh::to_vec(&args).map_err(|_| StateError::Serialization)?; + + host::delete_bytes(&args_bytes) } /// Apply all pending operations to storage and mark the cache as flushed fn flush(&mut self) -> Result<(), Error> { - for (key, value) in self.cache.drain() { - unsafe { - host::put_bytes(&self.program, &key.into(), &value)?; - } + let args_iter = self + .cache + .drain() + .map(|(key, val)| (key.into(), val)) + .map(|(key, val)| PutArgs { + caller: self.program, + key: key.0, + bytes: val, + }) + .map(|args| borsh::to_vec(&args).map_err(|_| StateError::Serialization)); + + for args in args_iter { + host::put_bytes(&args?)?; } Ok(()) @@ -159,87 +185,65 @@ impl Key { } } -macro_rules! ffi_linker { - ($mod:literal, $link:literal, $caller:ident, $key:ident) => { - #[link(wasm_import_module = $mod)] - extern "C" { - #[link_name = $link] - fn ffi(caller: CPointer, key: CPointer) -> i32; - } - - let $caller = to_ffi_ptr($caller.id())?; - let $key = to_ffi_ptr($key)?; - }; - ($mod:literal, $link:literal, $caller:ident, $key:ident, $value:ident) => { - #[link(wasm_import_module = $mod)] - extern "C" { - #[link_name = $link] - fn ffi(caller: CPointer, key: CPointer, value: CPointer) -> i32; - } - - let $caller = to_ffi_ptr($caller.id())?; - let $key = to_ffi_ptr($key)?; - let $value = to_ffi_ptr($value)?; - }; +#[derive(BorshSerialize)] +struct PutArgs { + caller: Program, + key: Vec, + bytes: Vec, } -macro_rules! call_host_fn { - ( - wasm_import_module = $mod:literal - link_name = $link:literal - args = ($caller:ident, $key:ident) - ) => {{ - ffi_linker!($mod, $link, $caller, $key); - - unsafe { ffi($caller, $key) } - }}; - - ( - wasm_import_module = $mod:literal - link_name = $link:literal - args = ($caller:ident, $key:ident, $value:ident) - ) => {{ - ffi_linker!($mod, $link, $caller, $key, $value); - - unsafe { ffi($caller, $key, $value) } - }}; +#[derive(BorshSerialize)] +struct GetAndDeleteArgs { + caller: Program, + key: Vec, } mod host { - use super::{Key, Program}; - use crate::{ - memory::{to_ffi_ptr, CPointer}, - state::Error, - }; + use crate::state::Error; /// Persists the bytes at key on the host storage. - pub(super) unsafe fn put_bytes(caller: &Program, key: &Key, bytes: &[u8]) -> Result<(), Error> { - match call_host_fn! { - wasm_import_module = "state" - link_name = "put" - args = (caller, key, bytes) - } { + pub(super) fn put_bytes(bytes: &[u8]) -> Result<(), Error> { + #[link(wasm_import_module = "state")] + extern "C" { + #[link_name = "put"] + fn ffi(ptr: *const u8, len: usize) -> usize; + } + + let result = unsafe { ffi(bytes.as_ptr(), bytes.len()) }; + + match result { 0 => Ok(()), _ => Err(Error::Write), } } /// Gets the bytes associated with the key from the host. - pub(super) unsafe fn get_bytes(caller: &Program, key: &Key) -> Result { - Ok(call_host_fn! { - wasm_import_module = "state" - link_name = "get" - args = (caller, key) - }) + pub(super) fn get_bytes(bytes: &[u8]) -> Result<*const u8, Error> { + #[link(wasm_import_module = "state")] + extern "C" { + #[link_name = "get"] + fn ffi(ptr: *const u8, len: usize) -> *const u8; + } + + let result = unsafe { ffi(bytes.as_ptr(), bytes.len()) }; + + if result.is_null() { + Err(Error::Read) + } else { + Ok(result) + } } /// Deletes the bytes at key ptr from the host storage - pub(super) unsafe fn delete_bytes(caller: &Program, key: &Key) -> Result<(), Error> { - match call_host_fn! { - wasm_import_module = "state" - link_name = "delete" - args = (caller, key) - } { + pub(super) fn delete_bytes(bytes: &[u8]) -> Result<(), Error> { + #[link(wasm_import_module = "state")] + extern "C" { + #[link_name = "delete"] + fn ffi(ptr: *const u8, len: usize) -> i32; + } + + let result = unsafe { ffi(bytes.as_ptr(), bytes.len()) }; + match result { 0 => Ok(()), _ => Err(Error::Delete), } diff --git a/x/programs/tests/fixture/counter.wasm b/x/programs/tests/fixture/counter.wasm deleted file mode 100755 index 243cae1aff..0000000000 Binary files a/x/programs/tests/fixture/counter.wasm and /dev/null differ diff --git a/x/programs/tests/fixture/token.wasm b/x/programs/tests/fixture/token.wasm deleted file mode 100755 index 185ceec548..0000000000 Binary files a/x/programs/tests/fixture/token.wasm and /dev/null differ