diff --git a/README.md b/README.md index 936648b8..ed770ef8 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,12 @@ Implemented commands: - ZSCORE - ZUNIONSTORE - ZSCAN + - Scripting (all) + - EVAL + - EVALSHA + - SCRIPT LOAD + - SCRIPT EXISTS + - SCRIPT FLUSH Since miniredis is intended to be used in unittests TTLs don't decrease @@ -244,9 +250,8 @@ Commands which will probably not be implemented: - ~~SUBSCRIBE~~ - ~~UNSUBSCRIBE~~ - Scripting (all) - - ~~EVAL~~ - - ~~EVALSHA~~ - - ~~SCRIPT *~~ + - ~~SCRIPT DEBUG~~ + - ~~SCRIPT KILL~~ - Server - ~~BGSAVE~~ - ~~BGWRITEAOF~~ diff --git a/cmd_scripting.go b/cmd_scripting.go new file mode 100644 index 00000000..27e0b2a1 --- /dev/null +++ b/cmd_scripting.go @@ -0,0 +1,308 @@ +package miniredis + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "reflect" + "strconv" + "strings" + + "github.com/garyburd/redigo/redis" + lua "github.com/yuin/gopher-lua" + + "github.com/alicebob/miniredis/server" +) + +func commandsScripting(m *Miniredis) { + m.srv.Register("EVAL", m.cmdEval) + m.srv.Register("EVALSHA", m.cmdEvalsha) + m.srv.Register("SCRIPT", m.cmdScript) +} + +var scriptmap = map[string]string{} + +func byteToString(bs []uint8) string { + b := make([]byte, len(bs)) + for i, v := range bs { + b[i] = byte(v) + } + return string(b) +} + +func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) error { + L := lua.NewState() + defer L.Close() + + // create a redis client for redis.call + conn, err := redis.Dial("tcp", m.srv.Addr().String()) + if err != nil { + return err + } + defer conn.Close() + + // set global variable KEYS + keysTable := L.NewTable() + keysLen, err := strconv.Atoi(args[1]) + if err != nil { + c.WriteError(err.Error()) + return err + } + for i := 0; i < keysLen; i++ { + L.RawSet(keysTable, lua.LNumber(i+1), lua.LString(args[i+2])) + } + L.SetGlobal("KEYS", keysTable) + + // set global variable ARGV + argvTable := L.NewTable() + argvLen := len(args) - 2 - keysLen + for i := 0; i < argvLen; i++ { + L.RawSet(argvTable, lua.LNumber(i+1), lua.LString(args[i+2+keysLen])) + } + L.SetGlobal("ARGV", argvTable) + + // Register call function to lua VM + redisFuncs := map[string]lua.LGFunction{ + "call": func(L *lua.LState) int { + top := L.GetTop() + + cmd := lua.LVAsString(L.Get(1)) + args := make([]interface{}, top-1) + for i := 2; i <= top; i++ { + arg := L.Get(i) + + dataType := arg.Type() + switch dataType { + case lua.LTBool: + args[i-2] = lua.LVAsBool(arg) + case lua.LTNumber: + value, _ := strconv.ParseFloat(lua.LVAsString(arg), 64) + args[i-2] = value + case lua.LTString: + args[i-2] = lua.LVAsString(arg) + case lua.LTNil: + case lua.LTFunction: + case lua.LTUserData: + case lua.LTThread: + case lua.LTTable: + case lua.LTChannel: + default: + args[i-2] = nil + } + } + res, err := conn.Do(cmd, args...) + if err != nil { + L.Push(lua.LNil) + return 1 + } + + pushCount := 0 + resType := reflect.TypeOf(res) + + if resType == nil { + L.Push(lua.LNil) + pushCount++ + } else { + if resType.String() == "int64" { + L.Push(lua.LNumber(res.(int64))) + pushCount++ + } else if resType.String() == "[]uint8" { + L.Push(lua.LString(byteToString(res.([]uint8)))) + pushCount++ + } else if resType.String() == "[]interface {}" { + L.Push(m.redisToLua(L, res)) + pushCount++ + } else { + L.Push(lua.LString(res.(string))) + pushCount++ + } + } + + return pushCount // Notify that we pushed one value to the stack + }, + } + + redisFuncs["pcall"] = redisFuncs["call"] + + // Register command handlers + L.Push(L.NewFunction(func(L *lua.LState) int { + mod := L.RegisterModule("redis", redisFuncs).(*lua.LTable) + L.Push(mod) + return 1 + })) + + L.Push(lua.LString("redis")) + L.Call(1, 0) + + if err := L.DoString(script); err != nil { + c.WriteError(err.Error()) + return err + } + + if L.GetTop() > 0 { + m.luaToRedis(L, c, L.Get(1)) + } else { + c.WriteNull() + } + + return nil +} + +func (m *Miniredis) redisToLua(L *lua.LState, res interface{}) *lua.LTable { + rettb := L.NewTable() + for _, e := range res.([]interface{}) { + if e == nil { + L.RawSet(rettb, lua.LNumber(rettb.Len()+1), lua.LValue(nil)) + continue + } + + if reflect.TypeOf(e).String() == "int64" { + L.RawSet(rettb, lua.LNumber(rettb.Len()+1), lua.LNumber(e.(int64))) + } else if reflect.TypeOf(e).String() == "[]uint8" { + L.RawSet(rettb, lua.LNumber(rettb.Len()+1), lua.LString(byteToString(e.([]uint8)))) + } else if reflect.TypeOf(e).String() == "[]interface {}" { + L.RawSet(rettb, lua.LNumber(rettb.Len()+1), m.redisToLua(L, e)) + } else { + L.RawSet(rettb, lua.LNumber(rettb.Len()+1), lua.LString(e.(string))) + } + } + + return rettb +} + +func (m *Miniredis) luaToRedis(L *lua.LState, c *server.Peer, value lua.LValue) { + if value == nil { + c.WriteNull() + return + } + + switch value.Type() { + case lua.LTNil: + c.WriteNull() + case lua.LTBool: + if lua.LVAsBool(value) { + c.WriteInt(1) + } else { + c.WriteInt(0) + } + case lua.LTNumber: + c.WriteInt(int(lua.LVAsNumber(value))) + case lua.LTString: + c.WriteInline(lua.LVAsString(value)) + case lua.LTTable: + result := []lua.LValue{} + for j := 1; true; j++ { + val := L.GetTable(value, lua.LNumber(j)) + if val == nil { + result = append(result, val) + continue + } + + if val.Type() == lua.LTNil { + break + } + + result = append(result, val) + } + + c.WriteLen(len(result)) + for _, r := range result { + m.luaToRedis(L, c, r) + } + default: + c.WriteInline(lua.LVAsString(value)) + } +} + +func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + err := m.runLuaScript(c, args[0], args) + if err != nil { + c.WriteError(err.Error()) + } +} + +func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + if script, ok := scriptmap[args[0]]; ok { + err := m.runLuaScript(c, script, args) + if err != nil { + c.WriteError(err.Error()) + } + } else { + c.WriteError(fmt.Sprintf("Invalid SHA %v", args[0])) + } +} + +func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + switch strings.Trim(strings.ToLower(args[0]), " \t") { + case "load": + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + shaList := []string{} + for i := 1; i < len(args); i++ { + h := sha1.New() + io.WriteString(h, args[i]) + hash := hex.EncodeToString(h.Sum(nil)) + scriptmap[hash] = args[i] + shaList = append(shaList, hash) + } + + c.WriteLen(len(shaList)) + for _, sha := range shaList { + c.WriteBulk(sha) + } + case "exists": + if len(args) < 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + c.WriteLen(len(args) - 1) + for i := 1; i < len(args); i++ { + if _, ok := scriptmap[args[i]]; ok { + c.WriteInt(1) + } else { + c.WriteInt(0) + } + } + case "flush": + for k := range scriptmap { + delete(scriptmap, k) + } + c.WriteOK() + default: + c.WriteError("Not implemented yet") + } +} diff --git a/cmd_scripting_test.go b/cmd_scripting_test.go new file mode 100644 index 00000000..3859691a --- /dev/null +++ b/cmd_scripting_test.go @@ -0,0 +1,318 @@ +package miniredis + +import ( + "testing" + + "github.com/garyburd/redigo/redis" +) + +func TestCmdEvalReplyConversion(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + cases := map[string]struct { + script string + args []interface{} + expected interface{} + }{ + "Return nil": { + script: "", + args: []interface{}{ + 0, + }, + }, + "Return boolean true": { + script: "return true", + args: []interface{}{ + 0, + }, + expected: int64(1), + }, + "Return boolean false": { + script: "return true", + args: []interface{}{ + 0, + }, + expected: int64(1), + }, + "Return single number": { + script: "return 10", + args: []interface{}{ + 0, + }, + expected: int64(10), + }, + "Return single float": { + script: "return 12.345", + args: []interface{}{ + 0, + }, + expected: int64(12), + }, + "Return multiple number": { + script: "return 10, 20", + args: []interface{}{ + 0, + }, + expected: int64(10), + }, + "Return single string": { + script: "return 'test'", + args: []interface{}{ + 0, + }, + expected: "test", + }, + "Return multiple string": { + script: "return 'test1', 'test2'", + args: []interface{}{ + 0, + }, + expected: "test1", + }, + "Return single table multiple integer": { + script: "return {10, 20}", + args: []interface{}{ + 0, + }, + expected: []interface{}{ + int64(10), + int64(20), + }, + }, + "Return single table multiple string": { + script: "return {'test1', 'test2'}", + args: []interface{}{ + 0, + }, + expected: []interface{}{ + "test1", + "test2", + }, + }, + "Return nested table": { + script: "return {10, 20, {30, 40}}", + args: []interface{}{ + 0, + }, + expected: []interface{}{ + int64(10), + int64(20), + []interface{}{ + int64(30), + int64(40), + }, + }, + }, + "Return combination table": { + script: "return {10, 20, {30, 'test', true, 40}, false}", + args: []interface{}{ + 0, + }, + expected: []interface{}{ + int64(10), + int64(20), + []interface{}{ + int64(30), + "test", + int64(1), + int64(40), + }, + int64(0), + }, + }, + "KEYS and ARGV": { + script: "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", + args: []interface{}{ + 2, + "key1", + "key2", + "first", + "second", + }, + expected: []interface{}{ + "key1", + "key2", + "first", + "second", + }, + }, + } + + for id, tc := range cases { + args := make([]interface{}, len(tc.args)+1) + args[0] = tc.script + for index, arg := range tc.args { + args[index+1] = arg + } + + reply, err := c.Do("EVAL", args...) + if err != nil { + t.Errorf("%v: Unexpected error: %v", id, err) + } + + equals(t, tc.expected, reply) + } +} + +func TestCmdEvalResponse(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + defer c.Close() + + { + v, err := c.Do("EVAL", "return redis.call('set','foo','bar')", 0) + ok(t, err) + equals(t, "OK", v) + } + + { + v, err := c.Do("EVAL", "return redis.call('get','foo')", 0) + ok(t, err) + equals(t, "bar", v) + } + + { + v, err := c.Do("EVAL", "return redis.call('HMSET', 'mkey', 'foo','bar','foo1','bar1')", 0) + ok(t, err) + equals(t, "OK", v) + } + + { + v, err := c.Do("EVAL", "return redis.call('HGETALL','mkey')", 0) + ok(t, err) + equals(t, []interface{}{"foo", "bar", "foo1", "bar1"}, v) + } + + { + v, err := c.Do("EVAL", "return redis.call('HMGET','mkey', 'foo1')", 0) + ok(t, err) + equals(t, []interface{}{"bar1"}, v) + } + + { + v, err := c.Do("EVAL", "return redis.call('HMGET','mkey', 'foo')", 0) + ok(t, err) + equals(t, []interface{}{"bar"}, v) + } + + { + v, err := c.Do("EVAL", "return redis.call('HMGET','mkey', 'bad', 'key')", 0) + ok(t, err) + equals(t, []interface{}{nil, nil}, v) + } +} + +func TestCmdScript(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + defer c.Close() + + // SCRIPT LOAD + { + v, err := redis.Strings(c.Do("SCRIPT", "LOAD", "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}", "return redis.call('set','foo','bar')")) + ok(t, err) + equals(t, []string{"a42059b356c875f0717db19a51f6aaca9ae659ea", "2fa2b029f72572e803ff55a09b1282699aecae6a"}, v) + } + + // SCRIPT EXISTS + { + v, err := redis.Int64s(c.Do("SCRIPT", "exists", "a42059b356c875f0717db19a51f6aaca9ae659ea", "2fa2b029f72572e803ff55a09b1282699aecae6a", "invalid sha")) + ok(t, err) + equals(t, []int64{1, 1, 0}, v) + } + + // SCRIPT FLUSH + { + v, err := redis.String(c.Do("SCRIPT", "flush")) + ok(t, err) + equals(t, "OK", v) + } +} + +func TestCmdScriptAndEvalsha(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + defer c.Close() + + // SCRIPT LOAD + { + v, err := redis.Strings(c.Do("SCRIPT", "LOAD", "redis.call('set', KEYS[1], ARGV[1])\n return redis.call('get', KEYS[1]) ")) + ok(t, err) + equals(t, []string{"054a13c20b748da2922a5f37f144342de21b8650"}, v) + } + + // TEST EVALSHA + { + v, err := c.Do("EVALSHA", "054a13c20b748da2922a5f37f144342de21b8650", 1, "test_key", "test_value") + ok(t, err) + equals(t, "test_value", v) + } + +} + +func TestCmdScriptAndEvalshaErrorRedisCall(t *testing.T) { + c, err := redis.Dial("tcp", "127.0.0.1:6379") + ok(t, err) + defer c.Close() + + // SCRIPT LOAD + { + v, err := redis.String(c.Do("EVAL", "return redis.call('invalid', 'key', 'value') ", 0)) + ok(t, err) + equals(t, "6a5ccb5fcaf42edce7f9bcb529e58d0f5c2d97c4", v) + } +} +func TestCmdScriptAndEvalshaErrorRedisPCall(t *testing.T) { + c, err := redis.Dial("tcp", "127.0.0.1:6379") + ok(t, err) + defer c.Close() + + // SCRIPT LOAD + { + v, err := redis.String(c.Do("EVAL", "return redis.pcall('invalid', 'key', 'value') ", 0)) + ok(t, err) + equals(t, "6a5ccb5fcaf42edce7f9bcb529e58d0f5c2d97c4", v) + } +} + +func TestCmdScriptAndEvalshaError(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + defer c.Close() + + // SCRIPT LOAD + { + v, err := redis.String(c.Do("EVAL", "return redis.call('invalid', 'key', 'value') ", 0)) + ok(t, err) + equals(t, "6a5ccb5fcaf42edce7f9bcb529e58d0f5c2d97c4", v) + } + + // SCRIPT LOAD + { + v, err := redis.String(c.Do("EVAL", "return redis.pcall('invalid', 'key', 'value') ", 0)) + ok(t, err) + equals(t, "6a5ccb5fcaf42edce7f9bcb529e58d0f5c2d97c4", v) + } + +} diff --git a/miniredis.go b/miniredis.go index 2fe52a35..faea68e2 100644 --- a/miniredis.go +++ b/miniredis.go @@ -135,6 +135,7 @@ func (m *Miniredis) start(s *server.Server) error { commandsSet(m) commandsSortedSet(m) commandsTransaction(m) + commandsScripting(m) return nil }