diff --git a/README.md b/README.md index 54e674f1..f0dd2655 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,17 @@ `Godis` is a golang implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang. -Godis implemented most features of redis, including 5 data structures, ttl, publish/subscribe, geo and AOF persistence. - -Godis can run as a server side cluster which is transparent to client. You can connect to any node in the cluster to -access all data in the cluster. - -Godis has a concurrent core, so you don't have to worry about your commands blocking the server too much. +Key Features: + +- support string, list, hash, set, sorted set +- ttl +- publish/suscribe +- geo +- aof and aof rewrite +- Transaction. The `multi` command is Atomic and Isolated. If any errors are encountered during execution, godis will rollback the executed commands +- server side cluster which is transparent to client. You can connect to any node in the cluster to + access all data in the cluster. +- a concurrent core, so you don't have to worry about your commands blocking the server too much. If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html). diff --git a/README_CN.md b/README_CN.md index 3da3a975..07a49b9a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -9,11 +9,14 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试使用 Go 语言开发高并发中间件的朋友提供一些参考。 -Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅、地理位置以及 AOF 持久化。 - -Godis 支持集群模式,集群对客户端是透明的只要连接上集群中任意一个节点就可以访问集群中所有数据。 - -Godis 是并行工作的, 无需担心您的操作会阻塞整个服务器. +关键功能: +- 支持 string, list, hash, set, sorted set 数据结构 +- 自动过期功能(TTL) +- 地理位置 +- AOF 持久化及AOF重写 +- 事务. Multi 命令开启的事务具有`原子性`和`隔离性`. 若在执行过程中遇到错误, godis 会回滚已执行的命令 +- 内置集群模式. 集群对客户端是透明的, 您可以像使用单机版 redis 一样使用 godis 集群 +- 并行引擎, 无需担心您的操作会阻塞整个服务器. 可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。 diff --git a/aof.go b/aof.go index 7b1b3385..d410a861 100644 --- a/aof.go +++ b/aof.go @@ -3,10 +3,7 @@ package godis import ( "github.com/hdt3213/godis/config" "github.com/hdt3213/godis/datastruct/dict" - List "github.com/hdt3213/godis/datastruct/list" "github.com/hdt3213/godis/datastruct/lock" - "github.com/hdt3213/godis/datastruct/set" - SortedSet "github.com/hdt3213/godis/datastruct/sortedset" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/parser" @@ -148,100 +145,6 @@ func (db *DB) aofRewrite() { db.finishRewrite(file) } -var setCmd = []byte("SET") - -func stringToCmd(key string, bytes []byte) *reply.MultiBulkReply { - args := make([][]byte, 3) - args[0] = setCmd - args[1] = []byte(key) - args[2] = bytes - return reply.MakeMultiBulkReply(args) -} - -var rPushAllCmd = []byte("RPUSH") - -func listToCmd(key string, list *List.LinkedList) *reply.MultiBulkReply { - args := make([][]byte, 2+list.Len()) - args[0] = rPushAllCmd - args[1] = []byte(key) - list.ForEach(func(i int, val interface{}) bool { - bytes, _ := val.([]byte) - args[2+i] = bytes - return true - }) - return reply.MakeMultiBulkReply(args) -} - -var sAddCmd = []byte("SADD") - -func setToCmd(key string, set *set.Set) *reply.MultiBulkReply { - args := make([][]byte, 2+set.Len()) - args[0] = sAddCmd - args[1] = []byte(key) - i := 0 - set.ForEach(func(val string) bool { - args[2+i] = []byte(val) - i++ - return true - }) - return reply.MakeMultiBulkReply(args) -} - -var hMSetCmd = []byte("HMSET") - -func hashToCmd(key string, hash dict.Dict) *reply.MultiBulkReply { - args := make([][]byte, 2+hash.Len()*2) - args[0] = hMSetCmd - args[1] = []byte(key) - i := 0 - hash.ForEach(func(field string, val interface{}) bool { - bytes, _ := val.([]byte) - args[2+i*2] = []byte(field) - args[3+i*2] = bytes - i++ - return true - }) - return reply.MakeMultiBulkReply(args) -} - -var zAddCmd = []byte("ZADD") - -func zSetToCmd(key string, zset *SortedSet.SortedSet) *reply.MultiBulkReply { - args := make([][]byte, 2+zset.Len()*2) - args[0] = zAddCmd - args[1] = []byte(key) - i := 0 - zset.ForEach(int64(0), int64(zset.Len()), true, func(element *SortedSet.Element) bool { - value := strconv.FormatFloat(element.Score, 'f', -1, 64) - args[2+i*2] = []byte(value) - args[3+i*2] = []byte(element.Member) - i++ - return true - }) - return reply.MakeMultiBulkReply(args) -} - -// EntityToCmd serialize data entity to redis command -func EntityToCmd(key string, entity *DataEntity) *reply.MultiBulkReply { - if entity == nil { - return nil - } - var cmd *reply.MultiBulkReply - switch val := entity.Data.(type) { - case []byte: - cmd = stringToCmd(key, val) - case *List.LinkedList: - cmd = listToCmd(key, val) - case *set.Set: - cmd = setToCmd(key, val) - case dict.Dict: - cmd = hashToCmd(key, val) - case *SortedSet.SortedSet: - cmd = zSetToCmd(key, val) - } - return cmd -} - func (db *DB) startRewrite() (*os.File, int64, error) { db.pausingAof.Lock() // pausing aof defer db.pausingAof.Unlock() diff --git a/aof_test.go b/aof_test.go index 84eefb7f..83705835 100644 --- a/aof_test.go +++ b/aof_test.go @@ -2,8 +2,7 @@ package godis import ( "github.com/hdt3213/godis/config" - "github.com/hdt3213/godis/datastruct/utils" - utils2 "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "io/ioutil" "os" @@ -34,31 +33,31 @@ func TestAof(t *testing.T) { for i := 0; i < size; i++ { key := strconv.Itoa(cursor) cursor++ - execSet(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8), "EX", "10000")) + execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), "EX", "10000")) keys = append(keys, key) } for i := 0; i < size; i++ { key := strconv.Itoa(cursor) cursor++ - execRPush(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) + execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) keys = append(keys, key) } for i := 0; i < size; i++ { key := strconv.Itoa(cursor) cursor++ - execHSet(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8), utils2.RandString(8))) + execHSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), utils.RandString(8))) keys = append(keys, key) } for i := 0; i < size; i++ { key := strconv.Itoa(cursor) cursor++ - execSAdd(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) + execSAdd(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) keys = append(keys, key) } for i := 0; i < size; i++ { key := strconv.Itoa(cursor) cursor++ - execZAdd(aofWriteDB, utils2.ToBytesList(key, "10", utils2.RandString(8))) + execZAdd(aofWriteDB, utils.ToCmdLine(key, "10", utils.RandString(8))) keys = append(keys, key) } aofWriteDB.Close() // wait for aof finished @@ -105,44 +104,44 @@ func TestRewriteAOF(t *testing.T) { for i := 0; i < size; i++ { key := "str" + strconv.Itoa(cursor) cursor++ - execSet(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) - execSet(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) + execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) + execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) keys = append(keys, key) } // test ttl for i := 0; i < size; i++ { key := "str" + strconv.Itoa(cursor) cursor++ - execSet(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8), "EX", "1000")) + execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), "EX", "1000")) ttlKeys = append(ttlKeys, key) } for i := 0; i < size; i++ { key := "list" + strconv.Itoa(cursor) cursor++ - execRPush(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) - execRPush(aofWriteDB, utils2.ToBytesList(key, utils2.RandString(8))) + execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) + execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8))) keys = append(keys, key) } for i := 0; i < size; i++ { key := "hash" + strconv.Itoa(cursor) cursor++ - field := utils2.RandString(8) - execHSet(aofWriteDB, utils2.ToBytesList(key, field, utils2.RandString(8))) - execHSet(aofWriteDB, utils2.ToBytesList(key, field, utils2.RandString(8))) + field := utils.RandString(8) + execHSet(aofWriteDB, utils.ToCmdLine(key, field, utils.RandString(8))) + execHSet(aofWriteDB, utils.ToCmdLine(key, field, utils.RandString(8))) keys = append(keys, key) } for i := 0; i < size; i++ { key := "set" + strconv.Itoa(cursor) cursor++ - member := utils2.RandString(8) - execSAdd(aofWriteDB, utils2.ToBytesList(key, member)) - execSAdd(aofWriteDB, utils2.ToBytesList(key, member)) + member := utils.RandString(8) + execSAdd(aofWriteDB, utils.ToCmdLine(key, member)) + execSAdd(aofWriteDB, utils.ToCmdLine(key, member)) keys = append(keys, key) } for i := 0; i < size; i++ { key := "zset" + strconv.Itoa(cursor) cursor++ - execZAdd(aofWriteDB, utils2.ToBytesList(key, "10", utils2.RandString(8))) + execZAdd(aofWriteDB, utils.ToCmdLine(key, "10", utils.RandString(8))) keys = append(keys, key) } time.Sleep(time.Second) // wait for async goroutine finish its job @@ -167,7 +166,7 @@ func TestRewriteAOF(t *testing.T) { } } for _, key := range ttlKeys { - ret := execTTL(aofReadDB, utils2.ToBytesList(key)) + ret := execTTL(aofReadDB, utils.ToCmdLine(key)) intResult, ok := ret.(*reply.IntReply) if !ok { t.Errorf("expected int reply, actually %s", ret.ToBytes()) diff --git a/cluster/client_pool.go b/cluster/client_pool.go index ec2b7548..826c2b3c 100644 --- a/cluster/client_pool.go +++ b/cluster/client_pool.go @@ -21,7 +21,7 @@ func (f *connectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, c.Start() // all peers of cluster should use the same password if config.Properties.RequirePass != "" { - c.Send(utils.ToBytesList("AUTH", config.Properties.RequirePass)) + c.Send(utils.ToCmdLine("AUTH", config.Properties.RequirePass)) } return pool.NewPooledObject(c), nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index 42167ca9..f0e89e33 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -133,7 +133,7 @@ func makeArgs(cmd string, args ...string) [][]byte { return result } -// return peer -> keys +// return peer -> writeKeys func (cluster *Cluster) groupBy(keys []string) map[string][]string { result := make(map[string][]string) for _, key := range keys { diff --git a/cluster/com.go b/cluster/com.go index 70aeb257..3d3a4ac5 100644 --- a/cluster/com.go +++ b/cluster/com.go @@ -33,7 +33,7 @@ func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client) } // relay relays command to peer -// cannot call Prepare, Commit, Rollback of self node +// cannot call Prepare, Commit, execRollback of self node func (cluster *Cluster) relay(peer string, c redis.Connection, args [][]byte) redis.Reply { if peer == cluster.self { // to self db diff --git a/cluster/del.go b/cluster/del.go index c0ed6d45..15884362 100644 --- a/cluster/del.go +++ b/cluster/del.go @@ -6,8 +6,8 @@ import ( "strconv" ) -// Del atomically removes given keys from cluster, keys can be distributed on any node -// if the given keys are distributed on different node, Del will use try-commit-catch to remove them +// Del atomically removes given writeKeys from cluster, writeKeys can be distributed on any node +// if the given writeKeys are distributed on different node, Del will use try-commit-catch to remove them func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { if len(args) < 2 { return reply.MakeErrReply("ERR wrong number of arguments for 'del' command") @@ -18,7 +18,7 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { } groupMap := cluster.groupBy(keys) if len(groupMap) == 1 && allowFastTransaction { // do fast - for peer, group := range groupMap { // only one group + for peer, group := range groupMap { // only one peerKeys return cluster.relay(peer, c, makeArgs("DEL", group...)) } } @@ -27,14 +27,14 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { txID := cluster.idGenerator.NextID() txIDStr := strconv.FormatInt(txID, 10) rollback := false - for peer, group := range groupMap { - args := []string{txIDStr} - args = append(args, group...) + for peer, peerKeys := range groupMap { + peerArgs := []string{txIDStr, "DEL"} + peerArgs = append(peerArgs, peerKeys...) var resp redis.Reply if peer == cluster.self { - resp = prepareDel(cluster, c, makeArgs("PrepareDel", args...)) + resp = execPrepare(cluster, c, makeArgs("Prepare", peerArgs...)) } else { - resp = cluster.relay(peer, c, makeArgs("PrepareDel", args...)) + resp = cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) } if reply.IsErrorReply(resp) { errReply = resp @@ -63,39 +63,3 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { } return errReply } - -// args: PrepareDel id keys... -func prepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command") - } - txID := string(args[1]) - keys := make([]string, 0, len(args)-2) - for i := 2; i < len(args); i++ { - arg := args[i] - keys = append(keys, string(arg)) - } - txArgs := makeArgs("DEL", keys...) // actual args for cluster.db - tx := NewTransaction(cluster, c, txID, txArgs, keys) - cluster.transactions.Put(txID, tx) - err := tx.prepare() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return &reply.OkReply{} -} - -// invoker should provide lock -func commitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { - keys := make([]string, len(tx.args)) - for i, v := range tx.args { - keys[i] = string(v) - } - keys = keys[1:] - - deleted := cluster.db.Removes(keys...) - if deleted > 0 { - cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) - } - return reply.MakeIntReply(int64(deleted)) -} diff --git a/cluster/mset.go b/cluster/mset.go index e739e473..39ac5c58 100644 --- a/cluster/mset.go +++ b/cluster/mset.go @@ -2,13 +2,12 @@ package cluster import ( "fmt" - "github.com/hdt3213/godis" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/redis/reply" "strconv" ) -// MGet atomically get multi key-value from cluster, keys can be distributed on any node +// MGet atomically get multi key-value from cluster, writeKeys can be distributed on any node func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { if len(args) < 2 { return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") @@ -39,49 +38,7 @@ func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return reply.MakeMultiBulkReply(result) } -// args: PrepareMSet id keys... -func prepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command") - } - txID := string(args[1]) - size := (len(args) - 2) / 2 - keys := make([]string, size) - for i := 0; i < size; i++ { - keys[i] = string(args[2*i+2]) - } - - txArgs := [][]byte{ - []byte("MSet"), - } // actual args for cluster.db - txArgs = append(txArgs, args[2:]...) - tx := NewTransaction(cluster, c, txID, txArgs, keys) - cluster.transactions.Put(txID, tx) - err := tx.prepare() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return &reply.OkReply{} -} - -// invoker should provide lock -func commitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { - size := len(tx.args) / 2 - keys := make([]string, size) - values := make([][]byte, size) - for i := 0; i < size; i++ { - keys[i] = string(tx.args[2*i+1]) - values[i] = tx.args[2*i+2] - } - for i, key := range keys { - value := values[i] - cluster.db.PutEntity(key, &godis.DataEntity{Data: value}) - } - cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) - return &reply.OkReply{} -} - -// MSet atomically sets multi key-value in cluster, keys can be distributed on any node +// MSet atomically sets multi key-value in cluster, writeKeys can be distributed on any node func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { argCount := len(args) - 1 if argCount%2 != 0 || argCount < 1 { @@ -109,15 +66,15 @@ func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { txIDStr := strconv.FormatInt(txID, 10) rollback := false for peer, group := range groupMap { - peerArgs := []string{txIDStr} + peerArgs := []string{txIDStr, "MSET"} for _, k := range group { peerArgs = append(peerArgs, k, valueMap[k]) } var resp redis.Reply if peer == cluster.self { - resp = prepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...)) + resp = execPrepare(cluster, c, makeArgs("Prepare", peerArgs...)) } else { - resp = cluster.relay(peer, c, makeArgs("PrepareMSet", peerArgs...)) + resp = cluster.relay(peer, c, makeArgs("Prepare", peerArgs...)) } if reply.IsErrorReply(resp) { errReply = resp @@ -139,7 +96,7 @@ func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { } -// MSetNX sets multi key-value in database, only if none of the given keys exist and all given keys are on the same node +// MSetNX sets multi key-value in database, only if none of the given writeKeys exist and all given writeKeys are on the same node func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { argCount := len(args) - 1 if argCount%2 != 0 || argCount < 1 { diff --git a/cluster/pubsub_test.go b/cluster/pubsub_test.go index 5bafba34..98b27140 100644 --- a/cluster/pubsub_test.go +++ b/cluster/pubsub_test.go @@ -12,9 +12,9 @@ func TestPublish(t *testing.T) { channel := utils.RandString(5) msg := utils.RandString(5) conn := &connection.FakeConn{} - Subscribe(testCluster, conn, utils.ToBytesList("SUBSCRIBE", channel)) + Subscribe(testCluster, conn, utils.ToCmdLine("SUBSCRIBE", channel)) conn.Clean() // clean subscribe success - Publish(testCluster, conn, utils.ToBytesList("PUBLISH", channel, msg)) + Publish(testCluster, conn, utils.ToCmdLine("PUBLISH", channel, msg)) data := conn.Bytes() ret, err := parser.ParseOne(data) if err != nil { @@ -28,19 +28,19 @@ func TestPublish(t *testing.T) { }) // unsubscribe - UnSubscribe(testCluster, conn, utils.ToBytesList("UNSUBSCRIBE", channel)) + UnSubscribe(testCluster, conn, utils.ToCmdLine("UNSUBSCRIBE", channel)) conn.Clean() - Publish(testCluster, conn, utils.ToBytesList("PUBLISH", channel, msg)) + Publish(testCluster, conn, utils.ToCmdLine("PUBLISH", channel, msg)) data = conn.Bytes() if len(data) > 0 { t.Error("expect no msg") } // unsubscribe all - Subscribe(testCluster, conn, utils.ToBytesList("SUBSCRIBE", channel)) - UnSubscribe(testCluster, conn, utils.ToBytesList("UNSUBSCRIBE")) + Subscribe(testCluster, conn, utils.ToCmdLine("SUBSCRIBE", channel)) + UnSubscribe(testCluster, conn, utils.ToCmdLine("UNSUBSCRIBE")) conn.Clean() - Publish(testCluster, conn, utils.ToBytesList("PUBLISH", channel, msg)) + Publish(testCluster, conn, utils.ToCmdLine("PUBLISH", channel, msg)) data = conn.Bytes() if len(data) > 0 { t.Error("expect no msg") diff --git a/cluster/rename_test.go b/cluster/rename_test.go index 5c1fcc8e..41a39138 100644 --- a/cluster/rename_test.go +++ b/cluster/rename_test.go @@ -10,22 +10,22 @@ import ( func TestRename(t *testing.T) { testDB := testCluster.db - testDB.Exec(nil, utils.ToBytesList("FlushALL")) + testDB.Exec(nil, utils.ToCmdLine("FlushALL")) key := utils.RandString(10) value := utils.RandString(10) newKey := key + utils.RandString(2) - testDB.Exec(nil, utils.ToBytesList("SET", key, value, "ex", "1000")) - result := Rename(testCluster, nil, utils.ToBytesList("RENAME", key, newKey)) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "ex", "1000")) + result := Rename(testCluster, nil, utils.ToCmdLine("RENAME", key, newKey)) if _, ok := result.(*reply.OkReply); !ok { t.Error("expect ok") return } - result = testDB.Exec(nil, utils.ToBytesList("EXISTS", key)) + result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", key)) asserts.AssertIntReply(t, result, 0) - result = testDB.Exec(nil, utils.ToBytesList("EXISTS", newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", newKey)) asserts.AssertIntReply(t, result, 1) // check ttl - result = testDB.Exec(nil, utils.ToBytesList("TTL", newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("TTL", newKey)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -39,19 +39,19 @@ func TestRename(t *testing.T) { func TestRenameNx(t *testing.T) { testDB := testCluster.db - testDB.Exec(nil, utils.ToBytesList("FlushALL")) + testDB.Exec(nil, utils.ToCmdLine("FlushALL")) key := utils.RandString(10) value := utils.RandString(10) newKey := key + utils.RandString(2) - testCluster.db.Exec(nil, utils.ToBytesList("SET", key, value, "ex", "1000")) - result := RenameNx(testCluster, nil, utils.ToBytesList("RENAMENX", key, newKey)) + testCluster.db.Exec(nil, utils.ToCmdLine("SET", key, value, "ex", "1000")) + result := RenameNx(testCluster, nil, utils.ToCmdLine("RENAMENX", key, newKey)) asserts.AssertIntReply(t, result, 1) - result = testDB.Exec(nil, utils.ToBytesList("EXISTS", key)) + result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", key)) asserts.AssertIntReply(t, result, 0) - result = testDB.Exec(nil, utils.ToBytesList("EXISTS", newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", newKey)) asserts.AssertIntReply(t, result, 1) - result = testDB.Exec(nil, utils.ToBytesList("TTL", newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("TTL", newKey)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) diff --git a/cluster/router.go b/cluster/router.go index 7d06d8d2..249c6cf3 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -2,15 +2,16 @@ package cluster import "github.com/hdt3213/godis/interface/redis" +type CmdLine = [][]byte + func makeRouter() map[string]CmdFunc { routerMap := make(map[string]CmdFunc) routerMap["ping"] = ping - routerMap["commit"] = commit - routerMap["rollback"] = Rollback + routerMap["prepare"] = execPrepare + routerMap["commit"] = execCommit + routerMap["rollback"] = execRollback routerMap["del"] = Del - routerMap["preparedel"] = prepareDel - routerMap["preparemset"] = prepareMSet routerMap["expire"] = defaultFunc routerMap["expireat"] = defaultFunc @@ -108,7 +109,7 @@ func makeRouter() map[string]CmdFunc { routerMap["flushdb"] = FlushDB routerMap["flushall"] = FlushAll - //routerMap["keys"] = Keys + //routerMap["writeKeys"] = Keys return routerMap } diff --git a/cluster/transaction.go b/cluster/transaction.go index 1ebff293..7c08d33e 100644 --- a/cluster/transaction.go +++ b/cluster/transaction.go @@ -2,13 +2,11 @@ package cluster import ( "fmt" - "github.com/hdt3213/godis" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/redis/reply" "strconv" - "strings" "sync" "time" ) @@ -16,13 +14,14 @@ import ( // Transaction stores state and data for a try-commit-catch distributed transaction type Transaction struct { id string // transaction id - args [][]byte // cmd args + cmdLine [][]byte // cmd cmdLine cluster *Cluster conn redis.Connection - keys []string // related keys - lockedKeys bool - undoLog map[string][][]byte // store data for undoLog + writeKeys []string + readKeys []string + keysLocked bool + undoLog []CmdLine status int8 mu *sync.Mutex @@ -43,13 +42,12 @@ func genTaskKey(txID string) string { } // NewTransaction creates a try-commit-catch distributed transaction -func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction { +func NewTransaction(cluster *Cluster, c redis.Connection, id string, cmdLine [][]byte) *Transaction { return &Transaction{ id: id, - args: args, + cmdLine: cmdLine, cluster: cluster, conn: c, - keys: keys, status: createdStatus, mu: new(sync.Mutex), } @@ -58,16 +56,16 @@ func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]by // Reentrant // invoker should hold tx.mu func (tx *Transaction) lockKeys() { - if !tx.lockedKeys { - tx.cluster.db.Locks(tx.keys...) - tx.lockedKeys = true + if !tx.keysLocked { + tx.cluster.db.RWLocks(tx.writeKeys, tx.readKeys) + tx.keysLocked = true } } func (tx *Transaction) unLockKeys() { - if tx.lockedKeys { - tx.cluster.db.UnLocks(tx.keys...) - tx.lockedKeys = false + if tx.keysLocked { + tx.cluster.db.RWUnLocks(tx.writeKeys, tx.readKeys) + tx.keysLocked = false } } @@ -75,20 +73,13 @@ func (tx *Transaction) unLockKeys() { func (tx *Transaction) prepare() error { tx.mu.Lock() defer tx.mu.Unlock() - // lock keys + + tx.writeKeys, tx.readKeys = tx.cluster.db.GetRelatedKeys(tx.cmdLine) + // lock writeKeys tx.lockKeys() // build undoLog - tx.undoLog = make(map[string][][]byte) - for _, key := range tx.keys { - entity, ok := tx.cluster.db.GetEntity(key) - if ok { - blob := godis.EntityToCmd(key, entity) - tx.undoLog[key] = blob.Args - } else { - tx.undoLog[key] = nil // entity was nil, should be removed while rollback - } - } + tx.undoLog = tx.cluster.db.GetUndoLogs(tx.cmdLine) tx.status = preparedStatus taskKey := genTaskKey(tx.id) timewheel.Delay(maxLockTime, taskKey, func() { @@ -112,25 +103,35 @@ func (tx *Transaction) rollback() error { return nil } tx.lockKeys() - for key, blob := range tx.undoLog { - if len(blob) > 0 { - tx.cluster.db.Remove(key) - tx.cluster.db.Exec(nil, blob) - } else { - tx.cluster.db.Remove(key) - } + for _, cmdLine := range tx.undoLog { + tx.cluster.db.ExecWithLock(cmdLine) } tx.unLockKeys() tx.status = rolledBackStatus return nil } -// Rollback rollbacks local transaction -func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { +// cmdLine: Prepare id cmdName args... +func execPrepare(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command") + } + txID := string(cmdLine[1]) + tx := NewTransaction(cluster, c, txID, cmdLine[2:]) + cluster.transactions.Put(txID, tx) + err := tx.prepare() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return &reply.OkReply{} +} + +// execRollback rollbacks local transaction +func execRollback(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) != 2 { return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") } - txID := string(args[1]) + txID := string(cmdLine[1]) raw, ok := cluster.transactions.Get(txID) if !ok { return reply.MakeIntReply(0) @@ -147,12 +148,12 @@ func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return reply.MakeIntReply(1) } -// commit commits local transaction as a worker when receive commit command from coordinator -func commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { +// execCommit commits local transaction as a worker when receive execCommit command from coordinator +func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply { + if len(cmdLine) != 2 { return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") } - txID := string(args[1]) + txID := string(cmdLine[1]) raw, ok := cluster.transactions.Get(txID) if !ok { return reply.MakeIntReply(0) @@ -162,13 +163,7 @@ func commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { tx.mu.Lock() defer tx.mu.Unlock() - cmd := strings.ToLower(string(tx.args[0])) - var result redis.Reply - if cmd == "del" { - result = commitDel(cluster, c, tx) - } else if cmd == "mset" { - result = commitMSet(cluster, c, tx) - } + result := cluster.db.ExecWithLock(tx.cmdLine) if reply.IsErrorReply(result) { // failed @@ -186,7 +181,7 @@ func commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { return result } -// requestCommit commands all node commit transaction as coordinator +// requestCommit commands all node to commit transaction as coordinator func requestCommit(cluster *Cluster, c redis.Connection, txID int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) { var errReply reply.ErrorReply txIDStr := strconv.FormatInt(txID, 10) @@ -194,7 +189,7 @@ func requestCommit(cluster *Cluster, c redis.Connection, txID int64, peers map[s for peer := range peers { var resp redis.Reply if peer == cluster.self { - resp = commit(cluster, c, makeArgs("commit", txIDStr)) + resp = execCommit(cluster, c, makeArgs("commit", txIDStr)) } else { resp = cluster.relay(peer, c, makeArgs("commit", txIDStr)) } @@ -216,7 +211,7 @@ func requestRollback(cluster *Cluster, c redis.Connection, txID int64, peers map txIDStr := strconv.FormatInt(txID, 10) for peer := range peers { if peer == cluster.self { - Rollback(cluster, c, makeArgs("rollback", txIDStr)) + execRollback(cluster, c, makeArgs("rollback", txIDStr)) } else { cluster.relay(peer, c, makeArgs("rollback", txIDStr)) } diff --git a/cluster/transaction_test.go b/cluster/transaction_test.go index 10fc5fbc..304ca845 100644 --- a/cluster/transaction_test.go +++ b/cluster/transaction_test.go @@ -14,10 +14,10 @@ func TestRollback(t *testing.T) { txIDStr := strconv.FormatInt(txID, 10) keys := []string{"a", "b"} groupMap := testCluster.groupBy(keys) - args := []string{txIDStr} + args := []string{txIDStr, "DEL"} args = append(args, keys...) testCluster.Exec(nil, toArgs("SET", "a", "a")) - ret := prepareDel(testCluster, nil, makeArgs("PrepareDel", args...)) + ret := execPrepare(testCluster, nil, makeArgs("Prepare", args...)) asserts.AssertNotError(t, ret) requestRollback(testCluster, nil, txID, groupMap) ret = testCluster.Exec(nil, toArgs("GET", "a")) @@ -27,10 +27,10 @@ func TestRollback(t *testing.T) { FlushAll(testCluster, nil, toArgs("FLUSHALL")) txID = rand.Int63() txIDStr = strconv.FormatInt(txID, 10) - args = []string{txIDStr} + args = []string{txIDStr, "DEL"} args = append(args, keys...) testCluster.Exec(nil, toArgs("SET", "a", "a")) - ret = prepareDel(testCluster, nil, makeArgs("PrepareDel", args...)) + ret = execPrepare(testCluster, nil, makeArgs("Prepare", args...)) asserts.AssertNotError(t, ret) _, err := requestCommit(testCluster, nil, txID, groupMap) if err != nil { diff --git a/cmd/banner.txt b/cmd/banner.txt deleted file mode 100644 index 70e838e8..00000000 --- a/cmd/banner.txt +++ /dev/null @@ -1,6 +0,0 @@ - ______ ___ - / ____/___ ____/ (_)____ - / / __/ __ \/ __ / / ___/ -/ /_/ / /_/ / /_/ / (__ ) -\____/\____/\__,_/_/____/ - diff --git a/datastruct/list/linked.go b/datastruct/list/linked.go index 41fb57d0..120545bd 100644 --- a/datastruct/list/linked.go +++ b/datastruct/list/linked.go @@ -1,6 +1,6 @@ package list -import "github.com/hdt3213/godis/datastruct/utils" +import "github.com/hdt3213/godis/lib/utils" // LinkedList is doubly linked list type LinkedList struct { diff --git a/db.go b/db.go index 77638ca9..0f88be6e 100644 --- a/db.go +++ b/db.go @@ -49,18 +49,25 @@ type DB struct { pausingAof sync.RWMutex } -// PreFunc analyses command line when queued command to `multi` -// returns related keys and undo commands -type PreFunc func(args [][]byte) ([]string, [][][]byte) +// DataEntity stores data bound to a key, including a string, list, hash, set and so on +type DataEntity struct { + Data interface{} +} // ExecFunc is interface for command executor // args don't include cmd line type ExecFunc func(db *DB, args [][]byte) redis.Reply -// DataEntity stores data bound to a key, including a string, list, hash, set and so on -type DataEntity struct { - Data interface{} -} +// PreFunc analyses command line when queued command to `multi` +// returns related write keys and read keys +type PreFunc func(args [][]byte) ([]string, []string) + +// CmdLine is alias for [][]byte, represents a command line +type CmdLine = [][]byte + +// UndoFunc returns undo logs for the given command line +// execute from head to tail when undo +type UndoFunc func(db *DB, args [][]byte) []CmdLine // MakeDB create DB instance and start it func MakeDB() *DB { @@ -150,6 +157,8 @@ func (db *DB) Remove(key string) { db.stopWorld.Wait() db.data.Remove(key) db.ttlMap.Remove(key) + taskKey := genExpireTask(key) + timewheel.Cancel(taskKey) } // Removes the given keys from db @@ -159,8 +168,7 @@ func (db *DB) Removes(keys ...string) (deleted int) { for _, key := range keys { _, exists := db.data.Get(key) if exists { - db.data.Remove(key) - db.ttlMap.Remove(key) + db.Remove(key) deleted++ } } @@ -180,46 +188,6 @@ func (db *DB) Flush() { /* ---- Lock Function ----- */ -// Lock locks key for writing (exclusive lock) -func (db *DB) Lock(key string) { - db.locker.Lock(key) -} - -// RLock locks key for read (shared lock) -func (db *DB) RLock(key string) { - db.locker.RLock(key) -} - -// UnLock release exclusive lock -func (db *DB) UnLock(key string) { - db.locker.UnLock(key) -} - -// RUnLock release shared lock -func (db *DB) RUnLock(key string) { - db.locker.RUnLock(key) -} - -// Locks lock keys for writing (exclusive lock) -func (db *DB) Locks(keys ...string) { - db.locker.Locks(keys...) -} - -// RLocks lock keys for read (shared lock) -func (db *DB) RLocks(keys ...string) { - db.locker.RLocks(keys...) -} - -// UnLocks release exclusive locks -func (db *DB) UnLocks(keys ...string) { - db.locker.UnLocks(keys...) -} - -// RUnLocks release shared locks -func (db *DB) RUnLocks(keys ...string) { - db.locker.RUnLocks(keys...) -} - // RWLocks lock keys for writing and reading func (db *DB) RWLocks(writeKeys []string, readKeys []string) { db.locker.RWLocks(writeKeys, readKeys) @@ -242,8 +210,20 @@ func (db *DB) Expire(key string, expireTime time.Time) { db.ttlMap.Put(key, expireTime) taskKey := genExpireTask(key) timewheel.At(expireTime, taskKey, func() { + keys := []string{key} + db.RWLocks(keys, nil) + defer db.RWUnLocks(keys, nil) + // check-lock-check, ttl may be updated during waiting lock logger.Info("expire " + key) - db.Remove(key) + rawExpireTime, ok := db.ttlMap.Get(key) + if !ok { + return + } + expireTime, _ := rawExpireTime.(time.Time) + expired := time.Now().After(expireTime) + if expired { + db.Remove(key) + } }) } diff --git a/exec.go b/exec.go index 6fa0a10f..7cc1d158 100644 --- a/exec.go +++ b/exec.go @@ -11,8 +11,8 @@ import ( ) // Exec executes command -// parameter `cmdArgs` contains command and its arguments, for example: "set key value" -func (db *DB) Exec(c redis.Connection, cmdArgs [][]byte) (result redis.Reply) { +// parameter `cmdLine` contains command and its arguments, for example: "set key value" +func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Reply) { defer func() { if err := recover(); err != nil { logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) @@ -20,44 +20,57 @@ func (db *DB) Exec(c redis.Connection, cmdArgs [][]byte) (result redis.Reply) { } }() - cmdName := strings.ToLower(string(cmdArgs[0])) + cmdName := strings.ToLower(string(cmdLine[0])) // authenticate if cmdName == "auth" { - return Auth(db, c, cmdArgs[1:]) + return Auth(db, c, cmdLine[1:]) } if !isAuthenticated(c) { return reply.MakeErrReply("NOAUTH Authentication required") } // special commands + done := false + result, done = execSpecialCmd(c, cmdLine, cmdName, db) + if done { + return result + } + if c != nil && c.InMultiState() { + return enqueueCmd(db, c, cmdLine) + } + + // normal commands + return execNormalCommand(db, cmdLine) +} + +func execSpecialCmd(c redis.Connection, cmdLine [][]byte, cmdName string, db *DB) (redis.Reply, bool) { if cmdName == "subscribe" { - if len(cmdArgs) < 2 { - return reply.MakeArgNumErrReply("subscribe") + if len(cmdLine) < 2 { + return reply.MakeArgNumErrReply("subscribe"), true } - return pubsub.Subscribe(db.hub, c, cmdArgs[1:]) + return pubsub.Subscribe(db.hub, c, cmdLine[1:]), true } else if cmdName == "publish" { - return pubsub.Publish(db.hub, cmdArgs[1:]) + return pubsub.Publish(db.hub, cmdLine[1:]), true } else if cmdName == "unsubscribe" { - return pubsub.UnSubscribe(db.hub, c, cmdArgs[1:]) + return pubsub.UnSubscribe(db.hub, c, cmdLine[1:]), true } else if cmdName == "bgrewriteaof" { // aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go - return BGRewriteAOF(db, cmdArgs[1:]) - } - - // normal commands - cmd, ok := cmdTable[cmdName] - if !ok { - return reply.MakeErrReply("ERR unknown command '" + cmdName + "'") - } - if !validateArity(cmd.arity, cmdArgs) { - return reply.MakeArgNumErrReply(cmdName) - } - - fun := cmd.executor - if len(cmdArgs) > 1 { - result = fun(db, cmdArgs[1:]) - } else { - result = fun(db, [][]byte{}) + return BGRewriteAOF(db, cmdLine[1:]), true + } else if cmdName == "multi" { + if len(cmdLine) != 1 { + return reply.MakeArgNumErrReply(cmdName), true + } + return startMulti(db, c), true + } else if cmdName == "discard" { + if len(cmdLine) != 1 { + return reply.MakeArgNumErrReply(cmdName), true + } + return discardMulti(db, c), true + } else if cmdName == "exec" { + if len(cmdLine) != 1 { + return reply.MakeArgNumErrReply(cmdName), true + } + return execMulti(db, c), true } - return + return nil, false } diff --git a/exec_helper.go b/exec_helper.go new file mode 100644 index 00000000..e092ae24 --- /dev/null +++ b/exec_helper.go @@ -0,0 +1,67 @@ +package godis + +import ( + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/redis/reply" + "strings" +) + +func execNormalCommand(db *DB, cmdArgs [][]byte) redis.Reply { + cmdName := strings.ToLower(string(cmdArgs[0])) + cmd, ok := cmdTable[cmdName] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmdName + "'") + } + if !validateArity(cmd.arity, cmdArgs) { + return reply.MakeArgNumErrReply(cmdName) + } + + prepare := cmd.prepare + write, read := prepare(cmdArgs[1:]) + db.RWLocks(write, read) + defer db.RWUnLocks(write, read) + fun := cmd.executor + return fun(db, cmdArgs[1:]) +} + +// ExecWithLock executes normal commands, invoker should provide locks +func (db *DB) ExecWithLock(cmdLine [][]byte) redis.Reply { + cmdName := strings.ToLower(string(cmdLine[0])) + cmd, ok := cmdTable[cmdName] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmdName + "'") + } + if !validateArity(cmd.arity, cmdLine) { + return reply.MakeArgNumErrReply(cmdName) + } + fun := cmd.executor + return fun(db, cmdLine[1:]) +} + +// GetRelatedKeys analysis related keys +func (db *DB) GetRelatedKeys(cmdLine [][]byte) ([]string, []string) { + cmdName := strings.ToLower(string(cmdLine[0])) + cmd, ok := cmdTable[cmdName] + if !ok { + return nil, nil + } + prepare := cmd.prepare + if prepare == nil { + return nil, nil + } + return prepare(cmdLine[1:]) +} + +// GetUndoLogs return rollback commands +func (db *DB) GetUndoLogs(cmdLine [][]byte) []CmdLine { + cmdName := strings.ToLower(string(cmdLine[0])) + cmd, ok := cmdTable[cmdName] + if !ok { + return nil + } + undo := cmd.undo + if undo == nil { + return nil + } + return undo(db, cmdLine[1:]) +} diff --git a/geo.go b/geo.go index e062eb8d..62a8f68e 100644 --- a/geo.go +++ b/geo.go @@ -39,10 +39,6 @@ func execGeoAdd(db *DB, args [][]byte) redis.Reply { } } - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity sortedSet, _, errReply := db.getOrInitSortedSet(key) if errReply != nil { @@ -61,6 +57,16 @@ func execGeoAdd(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(i)) } +func undoGeoAdd(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + size := (len(args) - 1) / 3 + fields := make([]string, size) + for i := 0; i < size; i++ { + fields[i] = string(args[3*i+3]) + } + return rollbackZSetFields(db, key, fields...) +} + // execGeoPos returns location of a member func execGeoPos(db *DB, args [][]byte) redis.Reply { // parse args @@ -257,10 +263,10 @@ func geoRadius0(sortedSet *sortedset.SortedSet, lat float64, lng float64, radius } func init() { - RegisterCommand("GeoAdd", execGeoAdd, nil, -5) - RegisterCommand("GeoPos", execGeoPos, nil, -2) - RegisterCommand("GeoDist", execGeoDist, nil, -4) - RegisterCommand("GeoHash", execGeoHash, nil, -2) - RegisterCommand("GeoRadius", execGeoRadius, nil, -6) - RegisterCommand("GeoRadiusByMember", execGeoRadiusByMember, nil, -5) + RegisterCommand("GeoAdd", execGeoAdd, writeFirstKey, undoGeoAdd, -5) + RegisterCommand("GeoPos", execGeoPos, readFirstKey, nil, -2) + RegisterCommand("GeoDist", execGeoDist, readFirstKey, nil, -4) + RegisterCommand("GeoHash", execGeoHash, readFirstKey, nil, -2) + RegisterCommand("GeoRadius", execGeoRadius, readFirstKey, nil, -6) + RegisterCommand("GeoRadiusByMember", execGeoRadiusByMember, readFirstKey, nil, -5) } diff --git a/geo_test.go b/geo_test.go index 1e73abf5..232d69c6 100644 --- a/geo_test.go +++ b/geo_test.go @@ -10,52 +10,52 @@ import ( ) func TestGeoHash(t *testing.T) { - execFlushDB(testDB, utils.ToBytesList()) + execFlushDB(testDB, utils.ToCmdLine()) key := utils.RandString(10) pos := utils.RandString(10) - result := execGeoAdd(testDB, utils.ToBytesList(key, "13.361389", "38.115556", pos)) + result := execGeoAdd(testDB, utils.ToCmdLine(key, "13.361389", "38.115556", pos)) asserts.AssertIntReply(t, result, 1) - result = execGeoHash(testDB, utils.ToBytesList(key, pos)) + result = execGeoHash(testDB, utils.ToCmdLine(key, pos)) asserts.AssertMultiBulkReply(t, result, []string{"sqc8b49rnys00"}) } func TestGeoRadius(t *testing.T) { - execFlushDB(testDB, utils.ToBytesList()) + execFlushDB(testDB, utils.ToCmdLine()) key := utils.RandString(10) pos1 := utils.RandString(10) pos2 := utils.RandString(10) - execGeoAdd(testDB, utils.ToBytesList(key, + execGeoAdd(testDB, utils.ToCmdLine(key, "13.361389", "38.115556", pos1, "15.087269", "37.502669", pos2, )) - result := execGeoRadius(testDB, utils.ToBytesList(key, "15", "37", "200", "km")) + result := execGeoRadius(testDB, utils.ToCmdLine(key, "15", "37", "200", "km")) asserts.AssertMultiBulkReplySize(t, result, 2) } func TestGeoRadiusByMember(t *testing.T) { - execFlushDB(testDB, utils.ToBytesList()) + execFlushDB(testDB, utils.ToCmdLine()) key := utils.RandString(10) pos1 := utils.RandString(10) pos2 := utils.RandString(10) pivot := utils.RandString(10) - execGeoAdd(testDB, utils.ToBytesList(key, + execGeoAdd(testDB, utils.ToCmdLine(key, "13.361389", "38.115556", pos1, "17.087269", "38.502669", pos2, "13.583333", "37.316667", pivot, )) - result := execGeoRadiusByMember(testDB, utils.ToBytesList(key, pivot, "100", "km")) + result := execGeoRadiusByMember(testDB, utils.ToCmdLine(key, pivot, "100", "km")) asserts.AssertMultiBulkReplySize(t, result, 2) } func TestGeoPos(t *testing.T) { - execFlushDB(testDB, utils.ToBytesList()) + execFlushDB(testDB, utils.ToCmdLine()) key := utils.RandString(10) pos1 := utils.RandString(10) pos2 := utils.RandString(10) - execGeoAdd(testDB, utils.ToBytesList(key, + execGeoAdd(testDB, utils.ToCmdLine(key, "13.361389", "38.115556", pos1, )) - result := execGeoPos(testDB, utils.ToBytesList(key, pos1, pos2)) + result := execGeoPos(testDB, utils.ToCmdLine(key, pos1, pos2)) expected := "*2\r\n*2\r\n$18\r\n13.361386698670685\r\n$17\r\n38.11555536696687\r\n*0\r\n" if string(result.ToBytes()) != expected { t.Error("test failed") @@ -63,15 +63,15 @@ func TestGeoPos(t *testing.T) { } func TestGeoDist(t *testing.T) { - execFlushDB(testDB, utils.ToBytesList()) + execFlushDB(testDB, utils.ToCmdLine()) key := utils.RandString(10) pos1 := utils.RandString(10) pos2 := utils.RandString(10) - execGeoAdd(testDB, utils.ToBytesList(key, + execGeoAdd(testDB, utils.ToCmdLine(key, "13.361389", "38.115556", pos1, "15.087269", "37.502669", pos2, )) - result := execGeoDist(testDB, utils.ToBytesList(key, pos1, pos2, "km")) + result := execGeoDist(testDB, utils.ToCmdLine(key, pos1, pos2, "km")) bulkReply, ok := result.(*reply.BulkReply) if !ok { t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) @@ -86,7 +86,7 @@ func TestGeoDist(t *testing.T) { t.Errorf("expected 166.274, actual: %f", dist) } - result = execGeoDist(testDB, utils.ToBytesList(key, pos1, pos2, "m")) + result = execGeoDist(testDB, utils.ToCmdLine(key, pos1, pos2, "m")) bulkReply, ok = result.(*reply.BulkReply) if !ok { t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) diff --git a/hash.go b/hash.go index bf4a03ce..e4696797 100644 --- a/hash.go +++ b/hash.go @@ -43,10 +43,6 @@ func execHSet(db *DB, args [][]byte) redis.Reply { field := string(args[1]) value := args[2] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity dict, _, errReply := db.getOrInitDict(key) if errReply != nil { @@ -58,6 +54,12 @@ func execHSet(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(result)) } +func undoHSet(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + field := string(args[1]) + return rollbackHashFields(db, key, field) +} + // execHSetNX sets field in hash table only if field not exists func execHSetNX(db *DB, args [][]byte) redis.Reply { // parse args @@ -65,9 +67,6 @@ func execHSetNX(db *DB, args [][]byte) redis.Reply { field := string(args[1]) value := args[2] - db.Lock(key) - defer db.UnLock(key) - dict, _, errReply := db.getOrInitDict(key) if errReply != nil { return errReply @@ -87,9 +86,6 @@ func execHGet(db *DB, args [][]byte) redis.Reply { key := string(args[0]) field := string(args[1]) - db.RLock(key) - defer db.RUnLock(key) - // get entity dict, errReply := db.getAsDict(key) if errReply != nil { @@ -113,9 +109,6 @@ func execHExists(db *DB, args [][]byte) redis.Reply { key := string(args[0]) field := string(args[1]) - db.RLock(key) - defer db.RUnLock(key) - // get entity dict, errReply := db.getAsDict(key) if errReply != nil { @@ -142,9 +135,6 @@ func execHDel(db *DB, args [][]byte) redis.Reply { fields[i] = string(v) } - db.Lock(key) - defer db.UnLock(key) - // get entity dict, errReply := db.getAsDict(key) if errReply != nil { @@ -169,14 +159,21 @@ func execHDel(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(deleted)) } +func undoHDel(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + fields := make([]string, len(args)-1) + fieldArgs := args[1:] + for i, v := range fieldArgs { + fields[i] = string(v) + } + return rollbackHashFields(db, key, fields...) +} + // execHLen gets number of fields in hash table func execHLen(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - dict, errReply := db.getAsDict(key) if errReply != nil { return errReply @@ -202,10 +199,6 @@ func execHMSet(db *DB, args [][]byte) redis.Reply { values[i] = args[2*i+2] } - // lock key - db.Lock(key) - defer db.UnLock(key) - // get or init entity dict, _, errReply := db.getOrInitDict(key) if errReply != nil { @@ -221,8 +214,18 @@ func execHMSet(db *DB, args [][]byte) redis.Reply { return &reply.OkReply{} } -// HMGet gets multi fields in hash table -func HMGet(db *DB, args [][]byte) redis.Reply { +func undoHMSet(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + size := (len(args) - 1) / 2 + fields := make([]string, size) + for i := 0; i < size; i++ { + fields[i] = string(args[2*i+1]) + } + return rollbackHashFields(db, key, fields...) +} + +// execHMGet gets multi fields in hash table +func execHMGet(db *DB, args [][]byte) redis.Reply { key := string(args[0]) size := len(args) - 1 fields := make([]string, size) @@ -230,9 +233,6 @@ func HMGet(db *DB, args [][]byte) redis.Reply { fields[i] = string(args[i+1]) } - db.RLock(key) - defer db.RUnLock(key) - // get entity result := make([][]byte, size) dict, errReply := db.getAsDict(key) @@ -259,9 +259,6 @@ func HMGet(db *DB, args [][]byte) redis.Reply { func execHKeys(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - dict, errReply := db.getAsDict(key) if errReply != nil { return errReply @@ -284,9 +281,6 @@ func execHKeys(db *DB, args [][]byte) redis.Reply { func execHVals(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - // get entity dict, errReply := db.getAsDict(key) if errReply != nil { @@ -310,9 +304,6 @@ func execHVals(db *DB, args [][]byte) redis.Reply { func execHGetAll(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - // get entity dict, errReply := db.getAsDict(key) if errReply != nil { @@ -345,9 +336,6 @@ func execHIncrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not an integer or out of range") } - db.Lock(key) - defer db.UnLock(key) - dict, _, errReply := db.getOrInitDict(key) if errReply != nil { return errReply @@ -370,6 +358,12 @@ func execHIncrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeBulkReply(bytes) } +func undoHIncr(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + field := string(args[1]) + return rollbackHashFields(db, key, field) +} + // execHIncrByFloat increments the float value of a hash field by the given number func execHIncrByFloat(db *DB, args [][]byte) redis.Reply { key := string(args[0]) @@ -380,9 +374,6 @@ func execHIncrByFloat(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not a valid float") } - db.Lock(key) - defer db.UnLock(key) - // get or init entity dict, _, errReply := db.getOrInitDict(key) if errReply != nil { @@ -406,17 +397,18 @@ func execHIncrByFloat(db *DB, args [][]byte) redis.Reply { } func init() { - RegisterCommand("HSet", execHSet, nil, 4) - RegisterCommand("HSetNX", execHSetNX, nil, 4) - RegisterCommand("HGet", execHGet, nil, 3) - RegisterCommand("HExists", execHExists, nil, 3) - RegisterCommand("HDel", execHDel, nil, -3) - RegisterCommand("HLen", execHLen, nil, 2) - RegisterCommand("HMSet", execHMSet, nil, -4) - RegisterCommand("HGet", execHGet, nil, -3) - RegisterCommand("HKeys", execHKeys, nil, 2) - RegisterCommand("HVals", execHVals, nil, 2) - RegisterCommand("HGetAll", execHGetAll, nil, 2) - RegisterCommand("HIncrBy", execHIncrBy, nil, 4) - RegisterCommand("HIncrByFloat", execHIncrByFloat, nil, 4) + RegisterCommand("HSet", execHSet, writeFirstKey, undoHSet, 4) + RegisterCommand("HSetNX", execHSetNX, writeFirstKey, undoHSet, 4) + RegisterCommand("HGet", execHGet, readFirstKey, nil, 3) + RegisterCommand("HExists", execHExists, readFirstKey, nil, 3) + RegisterCommand("HDel", execHDel, writeFirstKey, undoHDel, -3) + RegisterCommand("HLen", execHLen, readFirstKey, nil, 2) + RegisterCommand("HMSet", execHMSet, writeFirstKey, undoHMSet, -4) + RegisterCommand("HMGet", execHMGet, readFirstKey, nil, -3) + RegisterCommand("HGet", execHGet, readFirstKey, nil, -3) + RegisterCommand("HKeys", execHKeys, readFirstKey, nil, 2) + RegisterCommand("HVals", execHVals, readFirstKey, nil, 2) + RegisterCommand("HGetAll", execHGetAll, readFirstKey, nil, 2) + RegisterCommand("HIncrBy", execHIncrBy, writeFirstKey, undoHIncr, 4) + RegisterCommand("HIncrByFloat", execHIncrByFloat, writeFirstKey, undoHIncr, 4) } diff --git a/hash_test.go b/hash_test.go index ed2ec129..26e6d32f 100644 --- a/hash_test.go +++ b/hash_test.go @@ -2,8 +2,7 @@ package godis import ( "fmt" - "github.com/hdt3213/godis/datastruct/utils" - utils2 "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply/asserts" "strconv" @@ -11,17 +10,17 @@ import ( ) func TestHSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // test hset - key := utils2.RandString(10) + key := utils.RandString(10) values := make(map[string][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) + value := utils.RandString(10) field := strconv.Itoa(i) values[field] = []byte(value) - result := execHSet(testDB, utils2.ToBytesList(key, field, value)) + result := testDB.Exec(nil, utils.ToCmdLine("hset", key, field, value)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) { t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) } @@ -29,67 +28,67 @@ func TestHSet(t *testing.T) { // test hget and hexists for field, v := range values { - actual := execHGet(testDB, utils2.ToBytesList(key, field)) + actual := testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) expected := reply.MakeBulkReply(v) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) } - actual = execHExists(testDB, utils2.ToBytesList(key, field)) + actual = testDB.Exec(nil, utils.ToCmdLine("hexists", key, field)) if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) { t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) } } // test hlen - actual := execHLen(testDB, utils2.ToBytesList(key)) + actual := testDB.Exec(nil, utils.ToCmdLine("hlen", key)) if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) { t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code)) } } func TestHDel(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // set values - key := utils2.RandString(10) + key := utils.RandString(10) fields := make([]string, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) + value := utils.RandString(10) field := strconv.Itoa(i) fields[i] = field - execHSet(testDB, utils2.ToBytesList(key, field, value)) + testDB.Exec(nil, utils.ToCmdLine("hset", key, field, value)) } // test HDel args := []string{key} args = append(args, fields...) - actual := execHDel(testDB, utils2.ToBytesList(args...)) + actual := testDB.Exec(nil, utils.ToCmdLine2("hdel", args...)) if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) { t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code)) } - actual = execHLen(testDB, utils2.ToBytesList(key)) + actual = testDB.Exec(nil, utils.ToCmdLine("hlen", key)) if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) { t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code)) } } func TestHMSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // test hset - key := utils2.RandString(10) + key := utils.RandString(10) fields := make([]string, size) values := make([]string, size) setArgs := []string{key} for i := 0; i < size; i++ { - fields[i] = utils2.RandString(10) - values[i] = utils2.RandString(10) + fields[i] = utils.RandString(10) + values[i] = utils.RandString(10) setArgs = append(setArgs, fields[i], values[i]) } - result := execHMSet(testDB, utils2.ToBytesList(setArgs...)) + result := testDB.Exec(nil, utils.ToCmdLine2("hmset", setArgs...)) if _, ok := result.(*reply.OkReply); !ok { t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes()))) } @@ -97,32 +96,32 @@ func TestHMSet(t *testing.T) { // test HMGet getArgs := []string{key} getArgs = append(getArgs, fields...) - actual := HMGet(testDB, utils2.ToBytesList(getArgs...)) - expected := reply.MakeMultiBulkReply(utils2.ToBytesList(values...)) + actual := testDB.Exec(nil, utils.ToCmdLine2("hmget", getArgs...)) + expected := reply.MakeMultiBulkReply(utils.ToCmdLine(values...)) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) } } func TestHGetAll(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 - key := utils2.RandString(10) + key := utils.RandString(10) fields := make([]string, size) valueSet := make(map[string]bool, size) valueMap := make(map[string]string) all := make([]string, 0) for i := 0; i < size; i++ { - fields[i] = utils2.RandString(10) - value := utils2.RandString(10) + fields[i] = utils.RandString(10) + value := utils.RandString(10) all = append(all, fields[i], value) valueMap[fields[i]] = value valueSet[value] = true - execHSet(testDB, utils2.ToBytesList(key, fields[i], value)) + execHSet(testDB, utils.ToCmdLine(key, fields[i], value)) } // test HGetAll - result := execHGetAll(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("hgetall", key)) multiBulk, ok := result.(*reply.MultiBulkReply) if !ok { t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) @@ -144,7 +143,7 @@ func TestHGetAll(t *testing.T) { } // test HKeys - result = execHKeys(testDB, utils2.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("hkeys", key)) multiBulk, ok = result.(*reply.MultiBulkReply) if !ok { t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) @@ -160,7 +159,7 @@ func TestHGetAll(t *testing.T) { } // test HVals - result = execHVals(testDB, utils2.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("hvals", key)) multiBulk, ok = result.(*reply.MultiBulkReply) if !ok { t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) @@ -178,39 +177,110 @@ func TestHGetAll(t *testing.T) { } func TestHIncrBy(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() - key := utils2.RandString(10) - result := execHIncrBy(testDB, utils2.ToBytesList(key, "a", "1")) + key := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("hincrby", key, "a", "1")) if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) } - result = execHIncrBy(testDB, utils2.ToBytesList(key, "a", "1")) + result = testDB.Exec(nil, utils.ToCmdLine("hincrby", key, "a", "1")) if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" { t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg))) } - result = execHIncrByFloat(testDB, utils2.ToBytesList(key, "b", "1.2")) + result = testDB.Exec(nil, utils.ToCmdLine("hincrbyfloat", key, "b", "1.2")) if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" { t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg))) } - result = execHIncrByFloat(testDB, utils2.ToBytesList(key, "b", "1.2")) + result = testDB.Exec(nil, utils.ToCmdLine("hincrbyfloat", key, "b", "1.2")) if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" { t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg))) } } func TestHSetNX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - field := utils2.RandString(10) - value := utils2.RandString(10) - result := execHSetNX(testDB, utils2.ToBytesList(key, field, value)) + testDB.Flush() + key := utils.RandString(10) + field := utils.RandString(10) + value := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("hsetnx", key, field, value)) asserts.AssertIntReply(t, result, 1) - value2 := utils2.RandString(10) - result = execHSetNX(testDB, utils2.ToBytesList(key, field, value2)) + value2 := utils.RandString(10) + result = testDB.Exec(nil, utils.ToCmdLine("hsetnx", key, field, value2)) asserts.AssertIntReply(t, result, 0) - result = execHGet(testDB, utils2.ToBytesList(key, field)) + result = testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) asserts.AssertBulkReply(t, result, value) +} + +func TestUndoHDel(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + field := utils.RandString(10) + value := utils.RandString(10) + + testDB.Exec(nil, utils.ToCmdLine("hset", key, field, value)) + cmdLine := utils.ToCmdLine("hdel", key, field) + undoCmdLines := undoHDel(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) + asserts.AssertBulkReply(t, result, value) +} + +func TestUndoHSet(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + field := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("hset", key, field, value)) + cmdLine := utils.ToCmdLine("hset", key, field, value2) + undoCmdLines := undoHSet(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) + asserts.AssertBulkReply(t, result, value) +} + +func TestUndoHMSet(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + field1 := utils.RandString(10) + field2 := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + + testDB.Exec(nil, utils.ToCmdLine("hmset", key, field1, value, field2, value)) + cmdLine := utils.ToCmdLine("hmset", key, field1, value2, field2, value2) + undoCmdLines := undoHMSet(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field1)) + asserts.AssertBulkReply(t, result, value) + result = testDB.Exec(nil, utils.ToCmdLine("hget", key, field2)) + asserts.AssertBulkReply(t, result, value) +} + +func TestUndoHIncr(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + field := utils.RandString(10) + + testDB.Exec(nil, utils.ToCmdLine("hset", key, field, "1")) + cmdLine := utils.ToCmdLine("hinctby", key, field, "2") + undoCmdLines := undoHIncr(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) + asserts.AssertBulkReply(t, result, "1") } diff --git a/interface/redis/client.go b/interface/redis/client.go index 75fe87e1..c91f23f9 100644 --- a/interface/redis/client.go +++ b/interface/redis/client.go @@ -5,9 +5,17 @@ type Connection interface { Write([]byte) error SetPassword(string) GetPassword() string + // client should keep its subscribing channels Subscribe(channel string) UnSubscribe(channel string) SubsCount() int GetChannels() []string + + // used for `Multi` command + InMultiState() bool + SetMultiState(bool) + GetQueuedCmdLine() [][][]byte + EnqueueCmd([][]byte) + ClearQueuedCmds() } diff --git a/keys.go b/keys.go index d5be7e27..0bf7a73b 100644 --- a/keys.go +++ b/keys.go @@ -19,9 +19,6 @@ func execDel(db *DB, args [][]byte) redis.Reply { keys[i] = string(v) } - db.Locks(keys...) - defer db.UnLocks(keys...) - deleted := db.Removes(keys...) if deleted > 0 { db.AddAof(makeAofCmd("del", args)) @@ -29,6 +26,14 @@ func execDel(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(deleted)) } +func undoDel(db *DB, args [][]byte) []CmdLine { + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } + return rollbackGivenKeys(db, keys...) +} + // execExists checks if a is existed in db func execExists(db *DB, args [][]byte) redis.Reply { result := int64(0) @@ -78,6 +83,12 @@ func execType(db *DB, args [][]byte) redis.Reply { return &reply.UnknownErrReply{} } +func prepareRename(args [][]byte) ([]string, []string) { + src := string(args[0]) + dest := string(args[1]) + return []string{dest}, []string{src} +} + // execRename a key func execRename(db *DB, args [][]byte) redis.Reply { if len(args) != 2 { @@ -86,9 +97,6 @@ func execRename(db *DB, args [][]byte) redis.Reply { src := string(args[0]) dest := string(args[1]) - db.Locks(src, dest) - defer db.UnLocks(src, dest) - entity, ok := db.GetEntity(src) if !ok { return reply.MakeErrReply("no such key") @@ -106,14 +114,17 @@ func execRename(db *DB, args [][]byte) redis.Reply { return &reply.OkReply{} } +func undoRename(db *DB, args [][]byte) []CmdLine { + src := string(args[0]) + dest := string(args[1]) + return rollbackGivenKeys(db, src, dest) +} + // execRenameNx a key, only if the new key does not exist func execRenameNx(db *DB, args [][]byte) redis.Reply { src := string(args[0]) dest := string(args[1]) - db.Locks(src, dest) - defer db.UnLocks(src, dest) - _, ok := db.GetEntity(dest) if ok { return reply.MakeIntReply(0) @@ -290,20 +301,27 @@ func execKeys(db *DB, args [][]byte) redis.Reply { return reply.MakeMultiBulkReply(result) } +func undoExpire(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + return []CmdLine{ + toTTLCmd(db, key).Args, + } +} + func init() { - RegisterCommand("Del", execDel, nil, -2) - RegisterCommand("Expire", execExpire, nil, 3) - RegisterCommand("ExpireAt", execExpireAt, nil, 3) - RegisterCommand("PExpire", execPExpire, nil, 3) - RegisterCommand("PExpireAt", execPExpireAt, nil, 3) - RegisterCommand("TTL", execTTL, nil, 2) - RegisterCommand("PTTL", execPTTL, nil, 2) - RegisterCommand("Persist", execPersist, nil, 2) - RegisterCommand("Exists", execExists, nil, -2) - RegisterCommand("Type", execType, nil, 2) - RegisterCommand("Rename", execRename, nil, 3) - RegisterCommand("RenameNx", execRenameNx, nil, 3) - RegisterCommand("FlushDB", execFlushDB, nil, -1) - RegisterCommand("FlushAll", execFlushAll, nil, -1) - RegisterCommand("Keys", execKeys, nil, 2) + RegisterCommand("Del", execDel, writeAllKeys, undoDel, -2) + RegisterCommand("Expire", execExpire, writeFirstKey, undoExpire, 3) + RegisterCommand("ExpireAt", execExpireAt, writeFirstKey, undoExpire, 3) + RegisterCommand("PExpire", execPExpire, writeFirstKey, undoExpire, 3) + RegisterCommand("PExpireAt", execPExpireAt, writeFirstKey, undoExpire, 3) + RegisterCommand("TTL", execTTL, readFirstKey, nil, 2) + RegisterCommand("PTTL", execPTTL, readFirstKey, nil, 2) + RegisterCommand("Persist", execPersist, writeFirstKey, undoExpire, 2) + RegisterCommand("Exists", execExists, readAllKeys, nil, -2) + RegisterCommand("Type", execType, readFirstKey, nil, 2) + RegisterCommand("Rename", execRename, prepareRename, undoRename, 3) + RegisterCommand("RenameNx", execRenameNx, prepareRename, undoRename, 3) + RegisterCommand("FlushDB", execFlushDB, noPrepare, nil, -1) + RegisterCommand("FlushAll", execFlushAll, noPrepare, nil, -1) + RegisterCommand("Keys", execKeys, noPrepare, nil, 2) } diff --git a/keys_test.go b/keys_test.go index 16a1e261..95f15ecf 100644 --- a/keys_test.go +++ b/keys_test.go @@ -11,65 +11,65 @@ import ( ) func TestExists(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) - execSet(testDB, utils.ToBytesList(key, value)) - result := execExists(testDB, utils.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value)) + result := testDB.Exec(nil, utils.ToCmdLine("exists", key)) asserts.AssertIntReply(t, result, 1) key = utils.RandString(10) - result = execExists(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("exists", key)) asserts.AssertIntReply(t, result, 0) } func TestType(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) - execSet(testDB, utils.ToBytesList(key, value)) - result := execType(testDB, utils.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value)) + result := testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "string") testDB.Remove(key) - result = execType(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "none") - execRPush(testDB, utils.ToBytesList(key, value)) - result = execType(testDB, utils.ToBytesList(key)) + execRPush(testDB, utils.ToCmdLine(key, value)) + result = testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "list") testDB.Remove(key) - execHSet(testDB, utils.ToBytesList(key, key, value)) - result = execType(testDB, utils.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("hset", key, key, value)) + result = testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "hash") testDB.Remove(key) - execSAdd(testDB, utils.ToBytesList(key, value)) - result = execType(testDB, utils.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, value)) + result = testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "set") testDB.Remove(key) - execZAdd(testDB, utils.ToBytesList(key, "1", value)) - result = execType(testDB, utils.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("zadd", key, "1", value)) + result = testDB.Exec(nil, utils.ToCmdLine("type", key)) asserts.AssertStatusReply(t, result, "zset") } func TestRename(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) newKey := key + utils.RandString(2) - execSet(testDB, utils.ToBytesList(key, value, "ex", "1000")) - result := execRename(testDB, utils.ToBytesList(key, newKey)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value, "ex", "1000")) + result := testDB.Exec(nil, utils.ToCmdLine("rename", key, newKey)) if _, ok := result.(*reply.OkReply); !ok { t.Error("expect ok") return } - result = execExists(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("exists", key)) asserts.AssertIntReply(t, result, 0) - result = execExists(testDB, utils.ToBytesList(newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("exists", newKey)) asserts.AssertIntReply(t, result, 1) // check ttl - result = execTTL(testDB, utils.ToBytesList(newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", newKey)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -82,18 +82,18 @@ func TestRename(t *testing.T) { } func TestRenameNx(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) newKey := key + utils.RandString(2) - execSet(testDB, utils.ToBytesList(key, value, "ex", "1000")) - result := execRenameNx(testDB, utils.ToBytesList(key, newKey)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value, "ex", "1000")) + result := testDB.Exec(nil, utils.ToCmdLine("RenameNx", key, newKey)) asserts.AssertIntReply(t, result, 1) - result = execExists(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("exists", key)) asserts.AssertIntReply(t, result, 0) - result = execExists(testDB, utils.ToBytesList(newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("exists", newKey)) asserts.AssertIntReply(t, result, 1) - result = execTTL(testDB, utils.ToBytesList(newKey)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", newKey)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -106,14 +106,14 @@ func TestRenameNx(t *testing.T) { } func TestTTL(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) - execSet(testDB, utils.ToBytesList(key, value)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value)) - result := execExpire(testDB, utils.ToBytesList(key, "1000")) + result := testDB.Exec(nil, utils.ToCmdLine("expire", key, "1000")) asserts.AssertIntReply(t, result, 1) - result = execTTL(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", key)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -124,14 +124,14 @@ func TestTTL(t *testing.T) { return } - result = execPersist(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("persist", key)) asserts.AssertIntReply(t, result, 1) - result = execTTL(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", key)) asserts.AssertIntReply(t, result, -1) - result = execPExpire(testDB, utils.ToBytesList(key, "1000000")) + result = testDB.Exec(nil, utils.ToCmdLine("PExpire", key, "1000000")) asserts.AssertIntReply(t, result, 1) - result = execPTTL(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("PTTL", key)) intResult, ok = result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -143,16 +143,28 @@ func TestTTL(t *testing.T) { } } +func TestExpire(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value)) + testDB.Exec(nil, utils.ToCmdLine("PEXPIRE", key, "100")) + time.Sleep(2 * time.Second) + result := testDB.Exec(nil, utils.ToCmdLine("TTL", key)) + asserts.AssertIntReply(t, result, -2) + +} + func TestExpireAt(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) - execSet(testDB, utils.ToBytesList(key, value)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value)) expireAt := time.Now().Add(time.Minute).Unix() - result := execExpireAt(testDB, utils.ToBytesList(key, strconv.FormatInt(expireAt, 10))) + result := testDB.Exec(nil, utils.ToCmdLine("ExpireAt", key, strconv.FormatInt(expireAt, 10))) + asserts.AssertIntReply(t, result, 1) - result = execTTL(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", key)) intResult, ok := result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -164,9 +176,9 @@ func TestExpireAt(t *testing.T) { } expireAt = time.Now().Add(time.Minute).Unix() - result = execPExpireAt(testDB, utils.ToBytesList(key, strconv.FormatInt(expireAt*1000, 10))) + result = testDB.Exec(nil, utils.ToCmdLine("PExpireAt", key, strconv.FormatInt(expireAt*1000, 10))) asserts.AssertIntReply(t, result, 1) - result = execTTL(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("ttl", key)) intResult, ok = result.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) @@ -179,17 +191,17 @@ func TestExpireAt(t *testing.T) { } func TestKeys(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) value := utils.RandString(10) - execSet(testDB, utils.ToBytesList(key, value)) - execSet(testDB, utils.ToBytesList("a:"+key, value)) - execSet(testDB, utils.ToBytesList("b:"+key, value)) + testDB.Exec(nil, utils.ToCmdLine("set", key, value)) + testDB.Exec(nil, utils.ToCmdLine("set", "a:"+key, value)) + testDB.Exec(nil, utils.ToCmdLine("set", "b:"+key, value)) - result := execKeys(testDB, utils.ToBytesList("*")) + result := testDB.Exec(nil, utils.ToCmdLine("keys", "*")) asserts.AssertMultiBulkReplySize(t, result, 3) - result = execKeys(testDB, utils.ToBytesList("a:*")) + result = testDB.Exec(nil, utils.ToCmdLine("keys", "a:*")) asserts.AssertMultiBulkReplySize(t, result, 1) - result = execKeys(testDB, utils.ToBytesList("?:*")) + result = testDB.Exec(nil, utils.ToCmdLine("keys", "?:*")) asserts.AssertMultiBulkReplySize(t, result, 2) } diff --git a/lib/utils/convert.go b/lib/utils/convert.go deleted file mode 100644 index d27df7aa..00000000 --- a/lib/utils/convert.go +++ /dev/null @@ -1,10 +0,0 @@ -package utils - -// ToBytesList convert strings to [][]byte -func ToBytesList(cmd ...string) [][]byte { - args := make([][]byte, len(cmd)) - for i, s := range cmd { - args[i] = []byte(s) - } - return args -} diff --git a/datastruct/utils/utils.go b/lib/utils/utils.go similarity index 59% rename from datastruct/utils/utils.go rename to lib/utils/utils.go index 470eb35b..7cd8a27b 100644 --- a/datastruct/utils/utils.go +++ b/lib/utils/utils.go @@ -1,5 +1,23 @@ package utils +// ToCmdLine convert strings to [][]byte +func ToCmdLine(cmd ...string) [][]byte { + args := make([][]byte, len(cmd)) + for i, s := range cmd { + args[i] = []byte(s) + } + return args +} + +func ToCmdLine2(commandName string, args ...string) [][]byte { + result := make([][]byte, len(args)+1) + result[0] = []byte(commandName) + for i, s := range args { + result[i+1] = []byte(s) + } + return result +} + // Equals check whether the given value is equal func Equals(a interface{}, b interface{}) bool { sliceA, okA := a.([]byte) diff --git a/list.go b/list.go index 1b27ada6..2e49f746 100644 --- a/list.go +++ b/list.go @@ -3,6 +3,7 @@ package godis import ( List "github.com/hdt3213/godis/datastruct/list" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "strconv" ) @@ -45,9 +46,6 @@ func execLIndex(db *DB, args [][]byte) redis.Reply { } index := int(index64) - db.RLock(key) - defer db.RUnLock(key) - // get entity list, errReply := db.getAsList(key) if errReply != nil { @@ -75,9 +73,6 @@ func execLLen(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - list, errReply := db.getAsList(key) if errReply != nil { return errReply @@ -95,10 +90,6 @@ func execLPop(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) - // lock - db.Lock(key) - defer db.UnLock(key) - // get data list, errReply := db.getAsList(key) if errReply != nil { @@ -116,15 +107,32 @@ func execLPop(db *DB, args [][]byte) redis.Reply { return reply.MakeBulkReply(val) } +var lPushCmd = []byte("LPUSH") + +func undoLPop(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + list, errReply := db.getAsList(key) + if errReply != nil { + return nil + } + if list == nil || list.Len() == 0 { + return nil + } + element, _ := list.Get(0).([]byte) + return []CmdLine{ + { + lPushCmd, + args[0], + element, + }, + } +} + // execLPush inserts element at head of list func execLPush(db *DB, args [][]byte) redis.Reply { key := string(args[0]) values := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity list, _, errReply := db.getOrInitList(key) if errReply != nil { @@ -140,15 +148,21 @@ func execLPush(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(list.Len())) } +func undoLPush(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + count := len(args) - 1 + cmdLines := make([]CmdLine, 0, count) + for i := 0; i < count; i++ { + cmdLines = append(cmdLines, utils.ToCmdLine("LPOP", key)) + } + return cmdLines +} + // execLPushX inserts element at head of list, only if list exists func execLPushX(db *DB, args [][]byte) redis.Reply { key := string(args[0]) values := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity list, errReply := db.getAsList(key) if errReply != nil { @@ -181,10 +195,6 @@ func execLRange(db *DB, args [][]byte) redis.Reply { } stop := int(stop64) - // lock key - db.RLock(key) - defer db.RUnLock(key) - // get data list, errReply := db.getAsList(key) if errReply != nil { @@ -237,10 +247,6 @@ func execLRem(db *DB, args [][]byte) redis.Reply { count := int(count64) value := args[2] - // lock - db.Lock(key) - defer db.UnLock(key) - // get data entity list, errReply := db.getAsList(key) if errReply != nil { @@ -280,10 +286,6 @@ func execLSet(db *DB, args [][]byte) redis.Reply { index := int(index64) value := args[2] - // lock - db.Lock(key) - defer db.UnLock(key) - // get data list, errReply := db.getAsList(key) if errReply != nil { @@ -307,15 +309,44 @@ func execLSet(db *DB, args [][]byte) redis.Reply { return &reply.OkReply{} } +func undoLSet(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + index64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return nil + } + index := int(index64) + list, errReply := db.getAsList(key) + if errReply != nil { + return nil + } + if list == nil { + return nil + } + size := list.Len() // assert: size > 0 + if index < -1*size { + return nil + } else if index < 0 { + index = size + index + } else if index >= size { + return nil + } + value, _ := list.Get(index).([]byte) + return []CmdLine{ + { + []byte("LSET"), + args[0], + args[1], + value, + }, + } +} + // execRPop removes last element of list then return it func execRPop(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) - // lock - db.Lock(key) - defer db.UnLock(key) - // get data list, errReply := db.getAsList(key) if errReply != nil { @@ -333,15 +364,39 @@ func execRPop(db *DB, args [][]byte) redis.Reply { return reply.MakeBulkReply(val) } +var rPushCmd = []byte("RPUSH") + +func undoRPop(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + list, errReply := db.getAsList(key) + if errReply != nil { + return nil + } + if list == nil || list.Len() == 0 { + return nil + } + element, _ := list.Get(list.Len() - 1).([]byte) + return []CmdLine{ + { + rPushCmd, + args[0], + element, + }, + } +} + +func prepareRPopLPush(args [][]byte) ([]string, []string) { + return []string{ + string(args[0]), + string(args[1]), + }, nil +} + // execRPopLPush pops last element of list-A then insert it to the head of list-B func execRPopLPush(db *DB, args [][]byte) redis.Reply { sourceKey := string(args[0]) destKey := string(args[1]) - // lock - db.Locks(sourceKey, destKey) - defer db.UnLocks(sourceKey, destKey) - // get source entity sourceList, errReply := db.getAsList(sourceKey) if errReply != nil { @@ -369,16 +424,35 @@ func execRPopLPush(db *DB, args [][]byte) redis.Reply { return reply.MakeBulkReply(val) } +func undoRPopLPush(db *DB, args [][]byte) []CmdLine { + sourceKey := string(args[0]) + list, errReply := db.getAsList(sourceKey) + if errReply != nil { + return nil + } + if list == nil || list.Len() == 0 { + return nil + } + element, _ := list.Get(list.Len() - 1).([]byte) + return []CmdLine{ + { + rPushCmd, + args[0], + element, + }, + { + []byte("LPOP"), + args[1], + }, + } +} + // execRPush inserts element at last of list func execRPush(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) values := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity list, _, errReply := db.getOrInitList(key) if errReply != nil { @@ -393,6 +467,16 @@ func execRPush(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(list.Len())) } +func undoRPush(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + count := len(args) - 1 + cmdLines := make([]CmdLine, 0, count) + for i := 0; i < count; i++ { + cmdLines = append(cmdLines, utils.ToCmdLine("RPOP", key)) + } + return cmdLines +} + // execRPushX inserts element at last of list only if list exists func execRPushX(db *DB, args [][]byte) redis.Reply { if len(args) < 2 { @@ -401,10 +485,6 @@ func execRPushX(db *DB, args [][]byte) redis.Reply { key := string(args[0]) values := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity list, errReply := db.getAsList(key) if errReply != nil { @@ -424,16 +504,16 @@ func execRPushX(db *DB, args [][]byte) redis.Reply { } func init() { - RegisterCommand("LPush", execLPush, nil, -3) - RegisterCommand("LPushX", execLPushX, nil, -3) - RegisterCommand("RPush", execRPush, nil, -3) - RegisterCommand("RPushX", execRPushX, nil, -3) - RegisterCommand("LPop", execLPop, nil, 2) - RegisterCommand("RPop", execRPop, nil, 2) - RegisterCommand("RPopLPush", execRPopLPush, nil, 4) - RegisterCommand("LRem", execLRem, nil, 4) - RegisterCommand("LLen", execLLen, nil, 2) - RegisterCommand("LIndex", execLIndex, nil, 3) - RegisterCommand("LSet", execLSet, nil, 4) - RegisterCommand("LRange", execLRange, nil, 4) + RegisterCommand("LPush", execLPush, writeFirstKey, undoLPush, -3) + RegisterCommand("LPushX", execLPushX, writeFirstKey, undoLPush, -3) + RegisterCommand("RPush", execRPush, writeFirstKey, undoRPush, -3) + RegisterCommand("RPushX", execRPushX, writeFirstKey, undoRPush, -3) + RegisterCommand("LPop", execLPop, writeFirstKey, undoLPop, 2) + RegisterCommand("RPop", execRPop, writeFirstKey, undoRPop, 2) + RegisterCommand("RPopLPush", execRPopLPush, prepareRPopLPush, undoRPopLPush, 3) + RegisterCommand("LRem", execLRem, writeFirstKey, rollbackFirstKey, 4) + RegisterCommand("LLen", execLLen, readFirstKey, nil, 2) + RegisterCommand("LIndex", execLIndex, readFirstKey, nil, 3) + RegisterCommand("LSet", execLSet, writeFirstKey, undoLSet, 4) + RegisterCommand("LRange", execLRange, readFirstKey, nil, 4) } diff --git a/list_test.go b/list_test.go index 0e6eaba1..2b290377 100644 --- a/list_test.go +++ b/list_test.go @@ -2,29 +2,29 @@ package godis import ( "fmt" - "github.com/hdt3213/godis/datastruct/utils" - utils2 "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" + "github.com/hdt3213/godis/redis/reply/asserts" "strconv" "testing" ) func TestPush(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // rpush single - key := utils2.RandString(10) + key := utils.RandString(10) values := make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) + value := utils.RandString(10) values[i] = []byte(value) - result := execRPush(testDB, utils2.ToBytesList(key, value)) + result := testDB.Exec(nil, utils.ToCmdLine("rpush", key, value)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) } } - actual := execLRange(testDB, utils2.ToBytesList(key, "0", "-1")) + actual := testDB.Exec(nil, utils.ToCmdLine("lrange", key, "0", "-1")) expected := reply.MakeMultiBulkReply(values) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("push error") @@ -32,36 +32,38 @@ func TestPush(t *testing.T) { testDB.Remove(key) // rpush multi - key = utils2.RandString(10) - values = make([][]byte, size+1) - values[0] = []byte(key) + key = utils.RandString(10) + args := make([]string, size+1) + args[0] = key + values = make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) - values[i+1] = []byte(value) + value := utils.RandString(10) + values[i] = []byte(value) + args[i+1] = value } - result := execRPush(testDB, values) + result := testDB.Exec(nil, utils.ToCmdLine2("rpush", args...)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) } - actual = execLRange(testDB, utils2.ToBytesList(key, "0", "-1")) - expected = reply.MakeMultiBulkReply(values[1:]) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, "0", "-1")) + expected = reply.MakeMultiBulkReply(values) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("push error") } testDB.Remove(key) // left push single - key = utils2.RandString(10) + key = utils.RandString(10) values = make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) + value := utils.RandString(10) values[size-i-1] = []byte(value) - result = execLPush(testDB, utils2.ToBytesList(key, value)) + result = testDB.Exec(nil, utils.ToCmdLine("lpush", key, value)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(i+1) { t.Error(fmt.Sprintf("expected %d, actually %d", i+1, intResult.Code)) } } - actual = execLRange(testDB, utils2.ToBytesList(key, "0", "-1")) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, "0", "-1")) expected = reply.MakeMultiBulkReply(values) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("push error") @@ -69,20 +71,21 @@ func TestPush(t *testing.T) { testDB.Remove(key) // left push multi - key = utils2.RandString(10) - values = make([][]byte, size+1) - values[0] = []byte(key) + key = utils.RandString(10) + args = make([]string, size+1) + args[0] = key expectedValues := make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) - values[i+1] = []byte(value) + value := utils.RandString(10) + args[i+1] = value expectedValues[size-i-1] = []byte(value) } result = execLPush(testDB, values) + result = testDB.Exec(nil, utils.ToCmdLine2("lpush", args...)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) } - actual = execLRange(testDB, utils2.ToBytesList(key, "0", "-1")) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, "0", "-1")) expected = reply.MakeMultiBulkReply(expectedValues) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("push error") @@ -92,19 +95,19 @@ func TestPush(t *testing.T) { func TestLRange(t *testing.T) { // prepare list - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 - key := utils2.RandString(10) + key := utils.RandString(10) values := make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) - execRPush(testDB, utils2.ToBytesList(key, value)) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("rpush", key, value)) values[i] = []byte(value) } start := "0" end := "9" - actual := execLRange(testDB, utils2.ToBytesList(key, start, end)) + actual := testDB.Exec(nil, utils.ToCmdLine("lrange", key, start, end)) expected := reply.MakeMultiBulkReply(values[0:10]) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) @@ -112,7 +115,7 @@ func TestLRange(t *testing.T) { start = "0" end = "200" - actual = execLRange(testDB, utils2.ToBytesList(key, start, end)) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, start, end)) expected = reply.MakeMultiBulkReply(values) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) @@ -120,7 +123,7 @@ func TestLRange(t *testing.T) { start = "0" end = "-10" - actual = execLRange(testDB, utils2.ToBytesList(key, start, end)) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, start, end)) expected = reply.MakeMultiBulkReply(values[0 : size-10+1]) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) @@ -128,7 +131,7 @@ func TestLRange(t *testing.T) { start = "0" end = "-200" - actual = execLRange(testDB, utils2.ToBytesList(key, start, end)) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, start, end)) expected = reply.MakeMultiBulkReply(values[0:0]) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) @@ -136,7 +139,7 @@ func TestLRange(t *testing.T) { start = "-10" end = "-1" - actual = execLRange(testDB, utils2.ToBytesList(key, start, end)) + actual = testDB.Exec(nil, utils.ToCmdLine("lrange", key, start, end)) expected = reply.MakeMultiBulkReply(values[90:]) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("range error [%s, %s]", start, end)) @@ -145,23 +148,23 @@ func TestLRange(t *testing.T) { func TestLIndex(t *testing.T) { // prepare list - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 - key := utils2.RandString(10) + key := utils.RandString(10) values := make([][]byte, size) for i := 0; i < size; i++ { - value := utils2.RandString(10) - execRPush(testDB, utils2.ToBytesList(key, value)) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("rpush", key, value)) values[i] = []byte(value) } - result := execLLen(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(size) { t.Error(fmt.Sprintf("expected %d, actually %d", size, intResult.Code)) } for i := 0; i < size; i++ { - result = execLIndex(testDB, utils2.ToBytesList(key, strconv.Itoa(i))) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, strconv.Itoa(i))) expected := reply.MakeBulkReply(values[i]) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -169,7 +172,7 @@ func TestLIndex(t *testing.T) { } for i := 1; i <= size; i++ { - result = execLIndex(testDB, utils2.ToBytesList(key, strconv.Itoa(-i))) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, strconv.Itoa(-i))) expected := reply.MakeBulkReply(values[size-i]) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -179,55 +182,55 @@ func TestLIndex(t *testing.T) { func TestLRem(t *testing.T) { // prepare list - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) values := []string{key, "a", "b", "a", "a", "c", "a", "a"} - execRPush(testDB, utils2.ToBytesList(values...)) + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) - result := execLRem(testDB, utils2.ToBytesList(key, "1", "a")) + result := testDB.Exec(nil, utils.ToCmdLine("lrem", key, "1", "a")) if intResult, _ := result.(*reply.IntReply); intResult.Code != 1 { t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) } - result = execLLen(testDB, utils2.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("llen", key)) if intResult, _ := result.(*reply.IntReply); intResult.Code != 6 { t.Error(fmt.Sprintf("expected %d, actually %d", 6, intResult.Code)) } - result = execLRem(testDB, utils2.ToBytesList(key, "-2", "a")) + result = testDB.Exec(nil, utils.ToCmdLine("lrem", key, "-2", "a")) if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) } - result = execLLen(testDB, utils2.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("llen", key)) if intResult, _ := result.(*reply.IntReply); intResult.Code != 4 { t.Error(fmt.Sprintf("expected %d, actually %d", 4, intResult.Code)) } - result = execLRem(testDB, utils2.ToBytesList(key, "0", "a")) + result = testDB.Exec(nil, utils.ToCmdLine("lrem", key, "0", "a")) if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) } - result = execLLen(testDB, utils2.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("llen", key)) if intResult, _ := result.(*reply.IntReply); intResult.Code != 2 { t.Error(fmt.Sprintf("expected %d, actually %d", 2, intResult.Code)) } } func TestLSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) values := []string{key, "a", "b", "c", "d", "e", "f"} - execRPush(testDB, utils2.ToBytesList(values...)) + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) // test positive index size := len(values) - 1 for i := 0; i < size; i++ { indexStr := strconv.Itoa(i) - value := utils2.RandString(10) - result := execLSet(testDB, utils2.ToBytesList(key, indexStr, value)) + value := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("lset", key, indexStr, value)) if _, ok := result.(*reply.OkReply); !ok { t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) } - result = execLIndex(testDB, utils2.ToBytesList(key, indexStr)) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, indexStr)) expected := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -235,12 +238,12 @@ func TestLSet(t *testing.T) { } // test negative index for i := 1; i <= size; i++ { - value := utils2.RandString(10) - result := execLSet(testDB, utils2.ToBytesList(key, strconv.Itoa(-i), value)) + value := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("lset", key, strconv.Itoa(-i), value)) if _, ok := result.(*reply.OkReply); !ok { t.Error(fmt.Sprintf("expected OK, actually %s", string(result.ToBytes()))) } - result = execLIndex(testDB, utils2.ToBytesList(key, strconv.Itoa(len(values)-i-1))) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, strconv.Itoa(len(values)-i-1))) expected := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -248,17 +251,17 @@ func TestLSet(t *testing.T) { } // test illegal index - value := utils2.RandString(10) - result := execLSet(testDB, utils2.ToBytesList(key, strconv.Itoa(-len(values)-1), value)) + value := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("lset", key, strconv.Itoa(-len(values)-1), value)) expected := reply.MakeErrReply("ERR index out of range") if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - result = execLSet(testDB, utils2.ToBytesList(key, strconv.Itoa(len(values)), value)) + result = testDB.Exec(nil, utils.ToCmdLine("lset", key, strconv.Itoa(len(values)), value)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - result = execLSet(testDB, utils2.ToBytesList(key, "a", value)) + result = testDB.Exec(nil, utils.ToCmdLine("lset", key, "a", value)) expected = reply.MakeErrReply("ERR value is not an integer or out of range") if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -266,20 +269,20 @@ func TestLSet(t *testing.T) { } func TestLPop(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) values := []string{key, "a", "b", "c", "d", "e", "f"} - execRPush(testDB, utils2.ToBytesList(values...)) + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) size := len(values) - 1 for i := 0; i < size; i++ { - result := execLPop(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("lpop", key)) expected := reply.MakeBulkReply([]byte(values[i+1])) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } } - result := execRPop(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("rpop", key)) expected := &reply.NullBulkReply{} if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -287,20 +290,20 @@ func TestLPop(t *testing.T) { } func TestRPop(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) values := []string{key, "a", "b", "c", "d", "e", "f"} - execRPush(testDB, utils2.ToBytesList(values...)) + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) size := len(values) - 1 for i := 0; i < size; i++ { - result := execRPop(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("rpop", key)) expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } } - result := execRPop(testDB, utils2.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("rpop", key)) expected := &reply.NullBulkReply{} if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -308,25 +311,25 @@ func TestRPop(t *testing.T) { } func TestRPopLPush(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key1 := utils2.RandString(10) - key2 := utils2.RandString(10) + testDB.Flush() + key1 := utils.RandString(10) + key2 := utils.RandString(10) values := []string{key1, "a", "b", "c", "d", "e", "f"} - execRPush(testDB, utils2.ToBytesList(values...)) + testDB.Exec(nil, utils.ToCmdLine2("rpush", values...)) size := len(values) - 1 for i := 0; i < size; i++ { - result := execRPopLPush(testDB, utils2.ToBytesList(key1, key2)) + result := testDB.Exec(nil, utils.ToCmdLine("rpoplpush", key1, key2)) expected := reply.MakeBulkReply([]byte(values[len(values)-i-1])) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - result = execLIndex(testDB, utils2.ToBytesList(key2, "0")) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key2, "0")) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } } - result := execRPop(testDB, utils2.ToBytesList(key1)) + result := testDB.Exec(nil, utils.ToCmdLine("rpop", key1)) expected := &reply.NullBulkReply{} if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) @@ -334,23 +337,23 @@ func TestRPopLPush(t *testing.T) { } func TestRPushX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - result := execRPushX(testDB, utils2.ToBytesList(key, "1")) + testDB.Flush() + key := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("rpushx", key, "1")) expected := reply.MakeIntReply(int64(0)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - execRPush(testDB, utils2.ToBytesList(key, "1")) + testDB.Exec(nil, utils.ToCmdLine("rpush", key, "1")) for i := 0; i < 10; i++ { - value := utils2.RandString(10) - result := execRPushX(testDB, utils2.ToBytesList(key, value)) + value := utils.RandString(10) + result = testDB.Exec(nil, utils.ToCmdLine("rpushx", key, value)) expected := reply.MakeIntReply(int64(i + 2)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - result = execLIndex(testDB, utils2.ToBytesList(key, "-1")) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, "-1")) expected2 := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) @@ -359,27 +362,106 @@ func TestRPushX(t *testing.T) { } func TestLPushX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - result := execRPushX(testDB, utils2.ToBytesList(key, "1")) + testDB.Flush() + key := utils.RandString(10) + result := testDB.Exec(nil, utils.ToCmdLine("rpushx", key, "1")) expected := reply.MakeIntReply(int64(0)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - execLPush(testDB, utils2.ToBytesList(key, "1")) + testDB.Exec(nil, utils.ToCmdLine("lpush", key, "1")) for i := 0; i < 10; i++ { - value := utils2.RandString(10) - result := execLPushX(testDB, utils2.ToBytesList(key, value)) + value := utils.RandString(10) + result = testDB.Exec(nil, utils.ToCmdLine("lpushx", key, value)) expected := reply.MakeIntReply(int64(i + 2)) if !utils.BytesEquals(result.ToBytes(), expected.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(result.ToBytes()))) } - result = execLIndex(testDB, utils2.ToBytesList(key, "0")) + result = testDB.Exec(nil, utils.ToCmdLine("lindex", key, "0")) expected2 := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(result.ToBytes(), expected2.ToBytes()) { t.Error(fmt.Sprintf("expected %s, actually %s", string(expected2.ToBytes()), string(result.ToBytes()))) } } +} + +func TestUndoLPush(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + cmdLine := utils.ToCmdLine("lpush", key, value) + testDB.Exec(nil, cmdLine) + undoCmdLines := undoLPush(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) + asserts.AssertIntReply(t, result, 1) +} + +func TestUndoLPop(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("lpush", key, value, value)) + cmdLine := utils.ToCmdLine("lpop", key) + undoCmdLines := undoLPop(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) + asserts.AssertIntReply(t, result, 2) +} + +func TestUndoLSet(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("lpush", key, value, value)) + cmdLine := utils.ToCmdLine("lset", key, "1", value2) + undoCmdLines := undoLSet(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("lindex", key, "1")) + asserts.AssertBulkReply(t, result, value) +} + +func TestUndoRPop(t *testing.T) { + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("rpush", key, value, value)) + cmdLine := utils.ToCmdLine("rpop", key) + undoCmdLines := undoRPop(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("llen", key)) + asserts.AssertIntReply(t, result, 2) +} +func TestUndoRPopLPush(t *testing.T) { + testDB.Flush() + key1 := utils.RandString(10) + key2 := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("lpush", key1, value)) + + cmdLine := utils.ToCmdLine("rpoplpush", key1, key2) + undoCmdLines := undoRPopLPush(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + result := testDB.Exec(nil, utils.ToCmdLine("llen", key1)) + asserts.AssertIntReply(t, result, 1) + result = testDB.Exec(nil, utils.ToCmdLine("llen", key2)) + asserts.AssertIntReply(t, result, 0) } diff --git a/marshal.go b/marshal.go new file mode 100644 index 00000000..ee251c90 --- /dev/null +++ b/marshal.go @@ -0,0 +1,118 @@ +package godis + +import ( + "github.com/hdt3213/godis/datastruct/dict" + List "github.com/hdt3213/godis/datastruct/list" + "github.com/hdt3213/godis/datastruct/set" + SortedSet "github.com/hdt3213/godis/datastruct/sortedset" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/reply" + "strconv" + "time" +) + +// EntityToCmd serialize data entity to redis command +func EntityToCmd(key string, entity *DataEntity) *reply.MultiBulkReply { + if entity == nil { + return nil + } + var cmd *reply.MultiBulkReply + switch val := entity.Data.(type) { + case []byte: + cmd = stringToCmd(key, val) + case *List.LinkedList: + cmd = listToCmd(key, val) + case *set.Set: + cmd = setToCmd(key, val) + case dict.Dict: + cmd = hashToCmd(key, val) + case *SortedSet.SortedSet: + cmd = zSetToCmd(key, val) + } + return cmd +} + +// toTTLCmd serialize ttl config +func toTTLCmd(db *DB, key string) *reply.MultiBulkReply { + raw, exists := db.ttlMap.Get(key) + if !exists { + // 无 TTL + return reply.MakeMultiBulkReply(utils.ToCmdLine("PERSIST", key)) + } + expireTime, _ := raw.(time.Time) + timestamp := strconv.FormatInt(expireTime.UnixNano()/1000/1000, 10) + return reply.MakeMultiBulkReply(utils.ToCmdLine("PEXPIREAT", key, timestamp)) +} + +var setCmd = []byte("SET") + +func stringToCmd(key string, bytes []byte) *reply.MultiBulkReply { + args := make([][]byte, 3) + args[0] = setCmd + args[1] = []byte(key) + args[2] = bytes + return reply.MakeMultiBulkReply(args) +} + +var rPushAllCmd = []byte("RPUSH") + +func listToCmd(key string, list *List.LinkedList) *reply.MultiBulkReply { + args := make([][]byte, 2+list.Len()) + args[0] = rPushAllCmd + args[1] = []byte(key) + list.ForEach(func(i int, val interface{}) bool { + bytes, _ := val.([]byte) + args[2+i] = bytes + return true + }) + return reply.MakeMultiBulkReply(args) +} + +var sAddCmd = []byte("SADD") + +func setToCmd(key string, set *set.Set) *reply.MultiBulkReply { + args := make([][]byte, 2+set.Len()) + args[0] = sAddCmd + args[1] = []byte(key) + i := 0 + set.ForEach(func(val string) bool { + args[2+i] = []byte(val) + i++ + return true + }) + return reply.MakeMultiBulkReply(args) +} + +var hMSetCmd = []byte("HMSET") + +func hashToCmd(key string, hash dict.Dict) *reply.MultiBulkReply { + args := make([][]byte, 2+hash.Len()*2) + args[0] = hMSetCmd + args[1] = []byte(key) + i := 0 + hash.ForEach(func(field string, val interface{}) bool { + bytes, _ := val.([]byte) + args[2+i*2] = []byte(field) + args[3+i*2] = bytes + i++ + return true + }) + return reply.MakeMultiBulkReply(args) +} + +var zAddCmd = []byte("ZADD") + +func zSetToCmd(key string, zset *SortedSet.SortedSet) *reply.MultiBulkReply { + args := make([][]byte, 2+zset.Len()*2) + args[0] = zAddCmd + args[1] = []byte(key) + i := 0 + zset.ForEach(int64(0), int64(zset.Len()), true, func(element *SortedSet.Element) bool { + value := strconv.FormatFloat(element.Score, 'f', -1, 64) + args[2+i*2] = []byte(value) + args[3+i*2] = []byte(element.Member) + i++ + return true + }) + return reply.MakeMultiBulkReply(args) +} diff --git a/multi.go b/multi.go new file mode 100644 index 00000000..de477918 --- /dev/null +++ b/multi.go @@ -0,0 +1,94 @@ +package godis + +import ( + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/redis/reply" + "strings" +) + +func startMulti(db *DB, conn redis.Connection) redis.Reply { + if conn.InMultiState() { + return reply.MakeErrReply("ERR MULTI calls can not be nested") + } + conn.SetMultiState(true) + return reply.MakeOkReply() +} + +func enqueueCmd(db *DB, conn redis.Connection, cmdLine [][]byte) redis.Reply { + cmdName := strings.ToLower(string(cmdLine[0])) + cmd, ok := cmdTable[cmdName] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmdName + "'") + } + if cmd.prepare == nil { + return reply.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI") + } + if !validateArity(cmd.arity, cmdLine) { + // difference with redis: we won't enqueue command line with wrong arity + return reply.MakeArgNumErrReply(cmdName) + } + conn.EnqueueCmd(cmdLine) + return reply.MakeQueuedReply() +} + +func execMulti(db *DB, conn redis.Connection) redis.Reply { + if !conn.InMultiState() { + return reply.MakeErrReply("ERR EXEC without MULTI") + } + defer conn.SetMultiState(false) + cmdLines := conn.GetQueuedCmdLine() + + // prepare + writeKeys := make([]string, 0) // may contains duplicate + readKeys := make([]string, 0) + for _, cmdLine := range cmdLines { + cmdName := strings.ToLower(string(cmdLine[0])) + cmd := cmdTable[cmdName] + prepare := cmd.prepare + write, read := prepare(cmdLine[1:]) + writeKeys = append(writeKeys, write...) + readKeys = append(readKeys, read...) + } + db.RWLocks(writeKeys, readKeys) + defer db.RWUnLocks(writeKeys, readKeys) + + // execute + results := make([][]byte, 0, len(cmdLines)) + aborted := false + undoCmdLines := make([][]CmdLine, 0, len(cmdLines)) + for _, cmdLine := range cmdLines { + undoCmdLines = append(undoCmdLines, db.GetUndoLogs(cmdLine)) + result := db.ExecWithLock(cmdLine) + if reply.IsErrorReply(result) { + aborted = true + // don't rollback failed commands + undoCmdLines = undoCmdLines[:len(undoCmdLines)-1] + break + } + results = append(results, result.ToBytes()) + } + if !aborted { + return reply.MakeMultiRawReply(results) + } + // undo if aborted + size := len(undoCmdLines) + for i := size - 1; i >= 0; i-- { + curCmdLines := undoCmdLines[i] + if len(curCmdLines) == 0 { + continue + } + for _, cmdLine := range curCmdLines { + db.ExecWithLock(cmdLine) + } + } + return reply.MakeErrReply("EXECABORT Transaction discarded because of previous errors.") +} + +func discardMulti(db *DB, conn redis.Connection) redis.Reply { + if !conn.InMultiState() { + return reply.MakeErrReply("ERR DISCARD without MULTI") + } + conn.ClearQueuedCmds() + conn.SetMultiState(false) + return reply.MakeQueuedReply() +} diff --git a/multi_test.go b/multi_test.go new file mode 100644 index 00000000..14421d84 --- /dev/null +++ b/multi_test.go @@ -0,0 +1,59 @@ +package godis + +import ( + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/reply/asserts" + "testing" +) + +func TestMulti(t *testing.T) { + testDB.Flush() + conn := new(connection.FakeConn) + result := testDB.Exec(conn, utils.ToCmdLine("multi")) + asserts.AssertNotError(t, result) + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("set", key, value)) + key2 := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("rpush", key2, value)) + result = testDB.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertNotError(t, result) + result = testDB.Exec(conn, utils.ToCmdLine("get", key)) + asserts.AssertBulkReply(t, result, value) + result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) + asserts.AssertMultiBulkReply(t, result, []string{value}) +} + +func TestRollback(t *testing.T) { + testDB.Flush() + conn := new(connection.FakeConn) + result := testDB.Exec(conn, utils.ToCmdLine("multi")) + asserts.AssertNotError(t, result) + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("set", key, value)) + testDB.Exec(conn, utils.ToCmdLine("rpush", key, value)) + result = testDB.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.") + result = testDB.Exec(conn, utils.ToCmdLine("type", key)) + asserts.AssertStatusReply(t, result, "none") +} + +func TestDiscard(t *testing.T) { + testDB.Flush() + conn := new(connection.FakeConn) + result := testDB.Exec(conn, utils.ToCmdLine("multi")) + asserts.AssertNotError(t, result) + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("set", key, value)) + key2 := utils.RandString(10) + testDB.Exec(conn, utils.ToCmdLine("rpush", key2, value)) + result = testDB.Exec(conn, utils.ToCmdLine("discard")) + asserts.AssertNotError(t, result) + result = testDB.Exec(conn, utils.ToCmdLine("get", key)) + asserts.AssertNullBulk(t, result) + result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) + asserts.AssertMultiBulkReplySize(t, result, 0) +} diff --git a/multi_utils.go b/multi_utils.go new file mode 100644 index 00000000..b97aa3c9 --- /dev/null +++ b/multi_utils.go @@ -0,0 +1,173 @@ +package godis + +import ( + "github.com/hdt3213/godis/lib/utils" + "strconv" +) + +func readFirstKey(args [][]byte) ([]string, []string) { + // assert len(args) > 0 + key := string(args[0]) + return nil, []string{key} +} + +func writeFirstKey(args [][]byte) ([]string, []string) { + key := string(args[0]) + return []string{key}, nil +} + +func writeAllKeys(args [][]byte) ([]string, []string) { + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } + return keys, nil +} + +func readAllKeys(args [][]byte) ([]string, []string) { + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } + return nil, keys +} + +func noPrepare(args [][]byte) ([]string, []string) { + return nil, nil +} + +func rollbackFirstKey(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + return rollbackGivenKeys(db, key) +} + +func rollbackGivenKeys(db *DB, keys ...string) []CmdLine { + var undoCmdLines [][][]byte + for _, key := range keys { + entity, ok := db.GetEntity(key) + if !ok { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("DEL", key), + ) + } else { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("DEL", key), // clean existed first + EntityToCmd(key, entity).Args, + toTTLCmd(db, key).Args, + ) + } + } + return undoCmdLines +} + +func rollbackHashFields(db *DB, key string, fields ...string) []CmdLine { + var undoCmdLines [][][]byte + dict, errReply := db.getAsDict(key) + if errReply != nil { + return nil + } + if dict == nil { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("DEL", key), + ) + return undoCmdLines + } + for _, field := range fields { + entity, ok := dict.Get(field) + if !ok { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("HDEL", key, field), + ) + } else { + value, _ := entity.([]byte) + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("HSET", key, field, string(value)), + ) + } + } + return undoCmdLines +} + +func prepareSetCalculate(args [][]byte) ([]string, []string) { + keys := make([]string, len(args)) + for i, arg := range args { + keys[i] = string(arg) + } + return nil, keys +} + +func prepareSetCalculateStore(args [][]byte) ([]string, []string) { + dest := string(args[0]) + keys := make([]string, len(args)-1) + keyArgs := args[1:] + for i, arg := range keyArgs { + keys[i] = string(arg) + } + return []string{dest}, keys +} + +func rollbackSetMembers(db *DB, key string, members ...string) []CmdLine { + var undoCmdLines [][][]byte + set, errReply := db.getAsSet(key) + if errReply != nil { + return nil + } + if set == nil { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("DEL", key), + ) + return undoCmdLines + } + for _, member := range members { + ok := set.Has(member) + if !ok { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("SREM", key, member), + ) + } else { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("SADD", key, member), + ) + } + } + return undoCmdLines +} + +// undoSetChange rollbacks SADD and SREM command +func undoSetChange(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + memberArgs := args[1:] + members := make([]string, len(memberArgs)) + for i, mem := range memberArgs { + members[i] = string(mem) + } + return rollbackSetMembers(db, key, members...) +} + +func rollbackZSetFields(db *DB, key string, fields ...string) []CmdLine { + var undoCmdLines [][][]byte + zset, errReply := db.getAsSortedSet(key) + if errReply != nil { + return nil + } + if zset == nil { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("DEL", key), + ) + return undoCmdLines + } + for _, field := range fields { + elem, ok := zset.Get(field) + if !ok { + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("ZREM", key, field), + ) + } else { + score := strconv.FormatFloat(elem.Score, 'f', -1, 64) + undoCmdLines = append(undoCmdLines, + utils.ToCmdLine("ZADD", key, score, field), + ) + } + } + return undoCmdLines +} diff --git a/multi_utils_test.go b/multi_utils_test.go new file mode 100644 index 00000000..38338201 --- /dev/null +++ b/multi_utils_test.go @@ -0,0 +1,186 @@ +package godis + +import ( + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/reply/asserts" + "testing" + "time" +) + +func TestRollbackGivenKeys(t *testing.T) { + testDB.Flush() + + // rollback to string + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "EX", "200")) + undoCmdLines := rollbackGivenKeys(testDB, key) + rawExpire, _ := testDB.ttlMap.Get(key) + expireTime, _ := rawExpire.(time.Time) + // override given key + value2 := value + utils.RandString(5) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value2, "EX", "1000")) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) + asserts.AssertBulkReply(t, actual, value) + rawExpire2, _ := testDB.ttlMap.Get(key) + expireTime2, _ := rawExpire2.(time.Time) + timeDiff := expireTime.Sub(expireTime2) + if timeDiff < -time.Millisecond || timeDiff > time.Millisecond { + t.Error("rollback ttl failed") + } +} + +func TestRollbackToList(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Remove(key) + testDB.Exec(nil, utils.ToCmdLine("RPUSH", key, value, value2)) + undoCmdLines := rollbackGivenKeys(testDB, key) + testDB.Exec(nil, utils.ToCmdLine("LREM", key, value2)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("LRANGE", key, "0", "-1")) + asserts.AssertMultiBulkReply(t, actual, []string{value, value2}) +} + +func TestRollbackToSet(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Remove(key) + cmdLine := utils.ToCmdLine("SADD", key, value) + testDB.Exec(nil, cmdLine) + undoCmdLines := rollbackFirstKey(testDB, cmdLine[1:]) + testDB.Exec(nil, utils.ToCmdLine("SADD", key, value2)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("SMembers", key)) + asserts.AssertMultiBulkReply(t, actual, []string{value}) +} + +func TestRollbackSetMembers(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Remove(key) + + // undo srem + cmdLine := utils.ToCmdLine("SADD", key, value) + testDB.Exec(nil, cmdLine) + undoCmdLines := undoSetChange(testDB, cmdLine[1:]) + testDB.Exec(nil, utils.ToCmdLine("SREM", key, value)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("SIsMember", key, value)) + asserts.AssertIntReply(t, actual, 1) + + // undo sadd + testDB.Remove(key) + value2 := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SADD", key, value2)) + cmdLine = utils.ToCmdLine("SADD", key, value) + undoCmdLines = undoSetChange(testDB, cmdLine[1:]) + testDB.Exec(nil, cmdLine) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual = testDB.Exec(nil, utils.ToCmdLine("SIsMember", key, value)) + asserts.AssertIntReply(t, actual, 0) + + // undo sadd, only member + testDB.Remove(key) + undoCmdLines = rollbackSetMembers(testDB, key, value) + testDB.Exec(nil, utils.ToCmdLine("SAdd", key, value)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual = testDB.Exec(nil, utils.ToCmdLine("type", key)) + asserts.AssertStatusReply(t, actual, "none") +} + +func TestRollbackToHash(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Remove(key) + testDB.Exec(nil, utils.ToCmdLine("HSet", key, value, value)) + undoCmdLines := rollbackGivenKeys(testDB, key) + testDB.Exec(nil, utils.ToCmdLine("HSet", key, value, value2)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("HGET", key, value)) + asserts.AssertBulkReply(t, actual, value) +} + +func TestRollbackHashFields(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Remove(key) + testDB.Exec(nil, utils.ToCmdLine("HSet", key, value, value)) + undoCmdLines := rollbackHashFields(testDB, key, value, value2) + testDB.Exec(nil, utils.ToCmdLine("HSet", key, value, value2, value2, value2)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("HGET", key, value)) + asserts.AssertBulkReply(t, actual, value) + actual = testDB.Exec(nil, utils.ToCmdLine("HGET", key, value2)) + asserts.AssertNullBulk(t, actual) + + testDB.Remove(key) + undoCmdLines = rollbackHashFields(testDB, key, value) + testDB.Exec(nil, utils.ToCmdLine("HSet", key, value, value)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual = testDB.Exec(nil, utils.ToCmdLine("type", key)) + asserts.AssertStatusReply(t, actual, "none") +} + +func TestRollbackToZSet(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Remove(key) + testDB.Exec(nil, utils.ToCmdLine("ZADD", key, "1", value)) + undoCmdLines := rollbackGivenKeys(testDB, key) + testDB.Exec(nil, utils.ToCmdLine("ZADD", key, "2", value)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("ZSCORE", key, value)) + asserts.AssertBulkReply(t, actual, "1") +} + +func TestRollbackZSetFields(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + value2 := utils.RandString(10) + testDB.Remove(key) + testDB.Exec(nil, utils.ToCmdLine("ZADD", key, "1", value)) + undoCmdLines := rollbackZSetFields(testDB, key, value, value2) + testDB.Exec(nil, utils.ToCmdLine("ZADD", key, "2", value, "3", value2)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual := testDB.Exec(nil, utils.ToCmdLine("ZSCORE", key, value)) + asserts.AssertBulkReply(t, actual, "1") + actual = testDB.Exec(nil, utils.ToCmdLine("ZSCORE", key, value2)) + asserts.AssertNullBulk(t, actual) + + testDB.Remove(key) + undoCmdLines = rollbackZSetFields(testDB, key, value) + testDB.Exec(nil, utils.ToCmdLine("ZADD", key, "1", value)) + for _, cmdLine := range undoCmdLines { + testDB.Exec(nil, cmdLine) + } + actual = testDB.Exec(nil, utils.ToCmdLine("type", key)) + asserts.AssertStatusReply(t, actual, "none") +} diff --git a/prepare.go b/prepare.go deleted file mode 100644 index 4de41bc7..00000000 --- a/prepare.go +++ /dev/null @@ -1,5 +0,0 @@ -package godis - -func noPre(args [][]byte) ([]string, [][][]byte) { - return nil, nil -} diff --git a/redis/connection/conn.go b/redis/connection/conn.go index 023face2..3d101720 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -23,6 +23,10 @@ type Connection struct { // password may be changed by CONFIG command during runtime, so store the password password string + + // queued commands for `multi` + multiState bool + queue [][][]byte } // RemoteAddr returns the remote network address @@ -76,7 +80,7 @@ func (c *Connection) UnSubscribe(channel string) { c.mu.Lock() defer c.mu.Unlock() - if c.subs == nil { + if len(c.subs) == 0 { return } delete(c.subs, channel) @@ -84,9 +88,6 @@ func (c *Connection) UnSubscribe(channel string) { // SubsCount returns the number of subscribing channels func (c *Connection) SubsCount() int { - if c.subs == nil { - return 0 - } return len(c.subs) } @@ -114,6 +115,26 @@ func (c *Connection) GetPassword() string { return c.password } +func (c *Connection) InMultiState() bool { + return c.multiState +} + +func (c *Connection) SetMultiState(state bool) { + c.multiState = state +} + +func (c *Connection) GetQueuedCmdLine() [][][]byte { + return c.queue +} + +func (c *Connection) EnqueueCmd(cmdLine [][]byte) { + c.queue = append(c.queue, cmdLine) +} + +func (c *Connection) ClearQueuedCmds() { + c.queue = nil +} + // FakeConn implements redis.Connection for test type FakeConn struct { Connection diff --git a/redis/parser/parser_test.go b/redis/parser/parser_test.go index 0494cdc7..c87f8830 100644 --- a/redis/parser/parser_test.go +++ b/redis/parser/parser_test.go @@ -2,8 +2,8 @@ package parser import ( "bytes" - "github.com/hdt3213/godis/datastruct/utils" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "io" "testing" diff --git a/redis/reply/asserts/assert.go b/redis/reply/asserts/assert.go index 78dd32e9..3edadd72 100644 --- a/redis/reply/asserts/assert.go +++ b/redis/reply/asserts/assert.go @@ -2,8 +2,8 @@ package asserts import ( "fmt" - "github.com/hdt3213/godis/datastruct/utils" "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "runtime" "testing" diff --git a/redis/reply/consts.go b/redis/reply/consts.go index 2ae810f8..7957030e 100644 --- a/redis/reply/consts.go +++ b/redis/reply/consts.go @@ -20,6 +20,12 @@ func (r *OkReply) ToBytes() []byte { return okBytes } +var theOkReply = new(OkReply) + +func MakeOkReply() *OkReply { + return theOkReply +} + var nullBulkBytes = []byte("$-1\r\n") // NullBulkReply is empty string @@ -59,3 +65,19 @@ var noBytes = []byte("") func (r *NoReply) ToBytes() []byte { return noBytes } + +// QueuedReply is +QUEUED +type QueuedReply struct{} + +var queuedBytes = []byte("+QUEUED\r\n") + +// ToBytes marshal redis.Reply +func (r *QueuedReply) ToBytes() []byte { + return queuedBytes +} + +var theQueuedReply = new(QueuedReply) + +func MakeQueuedReply() *QueuedReply { + return theQueuedReply +} diff --git a/redis/server/pubsub_test.go b/redis/server/pubsub_test.go index b49eb742..d1a810a8 100644 --- a/redis/server/pubsub_test.go +++ b/redis/server/pubsub_test.go @@ -14,9 +14,9 @@ func TestPublish(t *testing.T) { channel := utils.RandString(5) msg := utils.RandString(5) conn := &connection.FakeConn{} - pubsub.Subscribe(hub, conn, utils.ToBytesList(channel)) + pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel)) conn.Clean() // clean subscribe success - pubsub.Publish(hub, utils.ToBytesList(channel, msg)) + pubsub.Publish(hub, utils.ToCmdLine(channel, msg)) data := conn.Bytes() ret, err := parser.ParseOne(data) if err != nil { @@ -30,19 +30,19 @@ func TestPublish(t *testing.T) { }) // unsubscribe - pubsub.UnSubscribe(hub, conn, utils.ToBytesList(channel)) + pubsub.UnSubscribe(hub, conn, utils.ToCmdLine(channel)) conn.Clean() - pubsub.Publish(hub, utils.ToBytesList(channel, msg)) + pubsub.Publish(hub, utils.ToCmdLine(channel, msg)) data = conn.Bytes() if len(data) > 0 { t.Error("expect no msg") } // unsubscribe all - pubsub.Subscribe(hub, conn, utils.ToBytesList(channel)) - pubsub.UnSubscribe(hub, conn, utils.ToBytesList()) + pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel)) + pubsub.UnSubscribe(hub, conn, utils.ToCmdLine()) conn.Clean() - pubsub.Publish(hub, utils.ToBytesList(channel, msg)) + pubsub.Publish(hub, utils.ToCmdLine(channel, msg)) data = conn.Bytes() if len(data) > 0 { t.Error("expect no msg") diff --git a/router.go b/router.go index 3f63491a..28de869a 100644 --- a/router.go +++ b/router.go @@ -1,23 +1,27 @@ package godis -import "strings" +import ( + "strings" +) var cmdTable = make(map[string]*command) type command struct { executor ExecFunc - prepare PreFunc // return related keys and rollback command - arity int // allow number of args, arity < 0 means len(args) >= -arity + prepare PreFunc // return related keys command + undo UndoFunc + arity int // allow number of args, arity < 0 means len(args) >= -arity } // RegisterCommand registers a new command // arity means allowed number of cmdArgs, arity < 0 means len(args) >= -arity. // for example: the arity of `get` is 2, `mget` is -2 -func RegisterCommand(name string, executor ExecFunc, prepare PreFunc, arity int) { +func RegisterCommand(name string, executor ExecFunc, prepare PreFunc, rollback UndoFunc, arity int) { name = strings.ToLower(name) cmdTable[name] = &command{ executor: executor, prepare: prepare, + undo: rollback, arity: arity, } } diff --git a/server.go b/server.go index 8e518b30..a01341a7 100644 --- a/server.go +++ b/server.go @@ -41,5 +41,5 @@ func isAuthenticated(c redis.Connection) bool { } func init() { - RegisterCommand("ping", Ping, nil, -1) + RegisterCommand("ping", Ping, noPrepare, nil, -1) } diff --git a/server_test.go b/server_test.go index 291cadcd..bfd8bc2a 100644 --- a/server_test.go +++ b/server_test.go @@ -9,32 +9,32 @@ import ( ) func TestPing(t *testing.T) { - actual := Ping(testDB, utils.ToBytesList()) + actual := Ping(testDB, utils.ToCmdLine()) asserts.AssertStatusReply(t, actual, "PONG") val := utils.RandString(5) - actual = Ping(testDB, utils.ToBytesList(val)) + actual = Ping(testDB, utils.ToCmdLine(val)) asserts.AssertStatusReply(t, actual, val) - actual = Ping(testDB, utils.ToBytesList(val, val)) + actual = Ping(testDB, utils.ToCmdLine(val, val)) asserts.AssertErrReply(t, actual, "ERR wrong number of arguments for 'ping' command") } func TestAuth(t *testing.T) { passwd := utils.RandString(10) c := &connection.FakeConn{} - ret := testDB.Exec(c, utils.ToBytesList("AUTH")) + ret := testDB.Exec(c, utils.ToCmdLine("AUTH")) asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command") - ret = testDB.Exec(c, utils.ToBytesList("AUTH", passwd)) + ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd)) asserts.AssertErrReply(t, ret, "ERR Client sent AUTH, but no password is set") config.Properties.RequirePass = passwd defer func() { config.Properties.RequirePass = "" }() - ret = testDB.Exec(c, utils.ToBytesList("AUTH", passwd+"wrong")) + ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd+"wrong")) asserts.AssertErrReply(t, ret, "ERR invalid password") - ret = testDB.Exec(c, utils.ToBytesList("PING")) + ret = testDB.Exec(c, utils.ToCmdLine("PING")) asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") - ret = testDB.Exec(c, utils.ToBytesList("AUTH", passwd)) + ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd)) asserts.AssertStatusReply(t, ret, "OK") } diff --git a/set.go b/set.go index 0a31213b..5fda13aa 100644 --- a/set.go +++ b/set.go @@ -40,10 +40,6 @@ func execSAdd(db *DB, args [][]byte) redis.Reply { key := string(args[0]) members := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity set, _, errReply := db.getOrInitSet(key) if errReply != nil { @@ -62,9 +58,6 @@ func execSIsMember(db *DB, args [][]byte) redis.Reply { key := string(args[0]) member := string(args[1]) - db.RLock(key) - defer db.RUnLock(key) - // get set set, errReply := db.getAsSet(key) if errReply != nil { @@ -86,10 +79,6 @@ func execSRem(db *DB, args [][]byte) redis.Reply { key := string(args[0]) members := args[1:] - // lock - db.Lock(key) - defer db.UnLock(key) - set, errReply := db.getAsSet(key) if errReply != nil { return errReply @@ -114,9 +103,6 @@ func execSRem(db *DB, args [][]byte) redis.Reply { func execSCard(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.RLock(key) - defer db.RUnLock(key) - // get or init entity set, errReply := db.getAsSet(key) if errReply != nil { @@ -132,10 +118,6 @@ func execSCard(db *DB, args [][]byte) redis.Reply { func execSMembers(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - // lock - db.RLock(key) - defer db.RUnLock(key) - // get or init entity set, errReply := db.getAsSet(key) if errReply != nil { @@ -162,10 +144,6 @@ func execSInter(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - var result *HashSet.Set for _, key := range keys { set, errReply := db.getAsSet(key) @@ -207,10 +185,6 @@ func execSInterStore(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RWLocks([]string{dest}, keys) - defer db.RWUnLocks([]string{dest}, keys) - var result *HashSet.Set for _, key := range keys { set, errReply := db.getAsSet(key) @@ -250,10 +224,6 @@ func execSUnion(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - var result *HashSet.Set for _, key := range keys { set, errReply := db.getAsSet(key) @@ -295,10 +265,6 @@ func execSUnionStore(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RWLocks([]string{dest}, keys) - defer db.RWUnLocks([]string{dest}, keys) - var result *HashSet.Set for _, key := range keys { set, errReply := db.getAsSet(key) @@ -338,10 +304,6 @@ func execSDiff(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RLocks(keys...) - defer db.RUnLocks(keys...) - var result *HashSet.Set for i, key := range keys { set, errReply := db.getAsSet(key) @@ -390,10 +352,6 @@ func execSDiffStore(db *DB, args [][]byte) redis.Reply { keys[i] = string(arg) } - // lock - db.RWLocks([]string{dest}, keys) - defer db.RWUnLocks([]string{dest}, keys) - var result *HashSet.Set for i, key := range keys { set, errReply := db.getAsSet(key) @@ -441,9 +399,6 @@ func execSRandMember(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command") } key := string(args[0]) - // lock - db.RLock(key) - defer db.RUnLock(key) // get or init entity set, errReply := db.getAsSet(key) @@ -482,16 +437,16 @@ func execSRandMember(db *DB, args [][]byte) redis.Reply { } func init() { - RegisterCommand("SAdd", execSAdd, nil, -3) - RegisterCommand("SIsMember", execSIsMember, nil, 3) - RegisterCommand("SRem", execSRem, nil, -3) - RegisterCommand("SCard", execSCard, nil, 2) - RegisterCommand("SMembers", execSMembers, nil, 2) - RegisterCommand("SInter", execSInter, nil, -2) - RegisterCommand("SInterStore", execSInterStore, nil, -3) - RegisterCommand("SUnion", execSUnion, nil, -2) - RegisterCommand("SUnionStore", execSUnionStore, nil, -3) - RegisterCommand("SDiff", execSDiff, nil, -2) - RegisterCommand("SDiffStore", execSDiffStore, nil, -3) - RegisterCommand("SRandMember", execSRandMember, nil, -2) + RegisterCommand("SAdd", execSAdd, writeFirstKey, undoSetChange, -3) + RegisterCommand("SIsMember", execSIsMember, readFirstKey, nil, 3) + RegisterCommand("SRem", execSRem, writeFirstKey, undoSetChange, -3) + RegisterCommand("SCard", execSCard, readFirstKey, nil, 2) + RegisterCommand("SMembers", execSMembers, readFirstKey, nil, 2) + RegisterCommand("SInter", execSInter, prepareSetCalculate, nil, -2) + RegisterCommand("SInterStore", execSInterStore, prepareSetCalculateStore, rollbackFirstKey, -3) + RegisterCommand("SUnion", execSUnion, prepareSetCalculate, nil, -2) + RegisterCommand("SUnionStore", execSUnionStore, prepareSetCalculateStore, rollbackFirstKey, -3) + RegisterCommand("SDiff", execSDiff, prepareSetCalculate, nil, -2) + RegisterCommand("SDiffStore", execSDiffStore, prepareSetCalculateStore, rollbackFirstKey, -3) + RegisterCommand("SRandMember", execSRandMember, readFirstKey, nil, -2) } diff --git a/set_test.go b/set_test.go index 820cace8..f8b67161 100644 --- a/set_test.go +++ b/set_test.go @@ -11,29 +11,29 @@ import ( // basic add get and remove func TestSAdd(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // test sadd key := utils.RandString(10) for i := 0; i < size; i++ { member := strconv.Itoa(i) - result := execSAdd(testDB, utils.ToBytesList(key, member)) + result := testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) asserts.AssertIntReply(t, result, 1) } // test scard - result := execSCard(testDB, utils.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("SCard", key)) asserts.AssertIntReply(t, result, size) // test is member for i := 0; i < size; i++ { member := strconv.Itoa(i) - result := execSIsMember(testDB, utils.ToBytesList(key, member)) + result = testDB.Exec(nil, utils.ToCmdLine("SIsMember", key, member)) asserts.AssertIntReply(t, result, 1) } // test members - result = execSMembers(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("SMembers", key)) multiBulk, ok := result.(*reply.MultiBulkReply) if !ok { t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) @@ -46,25 +46,25 @@ func TestSAdd(t *testing.T) { } func TestSRem(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // mock data key := utils.RandString(10) for i := 0; i < size; i++ { member := strconv.Itoa(i) - execSAdd(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) } for i := 0; i < size; i++ { member := strconv.Itoa(i) - execSRem(testDB, utils.ToBytesList(key, member)) - result := execSIsMember(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("srem", key, member)) + result := testDB.Exec(nil, utils.ToCmdLine("SIsMember", key, member)) asserts.AssertIntReply(t, result, 0) } } func TestSInter(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 step := 10 @@ -75,39 +75,39 @@ func TestSInter(t *testing.T) { keys = append(keys, key) for j := start; j < size+start; j++ { member := strconv.Itoa(j) - execSAdd(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) } start += step } - result := execSInter(testDB, utils.ToBytesList(keys...)) + result := testDB.Exec(nil, utils.ToCmdLine2("sinter", keys...)) asserts.AssertMultiBulkReplySize(t, result, 70) destKey := utils.RandString(10) keysWithDest := []string{destKey} keysWithDest = append(keysWithDest, keys...) - result = execSInterStore(testDB, utils.ToBytesList(keysWithDest...)) + result = testDB.Exec(nil, utils.ToCmdLine2("SInterStore", keysWithDest...)) asserts.AssertIntReply(t, result, 70) // test empty set - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key0 := utils.RandString(10) testDB.Remove(key0) key1 := utils.RandString(10) - execSAdd(testDB, utils.ToBytesList(key1, "a", "b")) + testDB.Exec(nil, utils.ToCmdLine("sadd", key1, "a", "b")) key2 := utils.RandString(10) - execSAdd(testDB, utils.ToBytesList(key2, "1", "2")) - result = execSInter(testDB, utils.ToBytesList(key0, key1, key2)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key1, "1", "2")) + result = testDB.Exec(nil, utils.ToCmdLine("sinter", key0, key1, key2)) asserts.AssertMultiBulkReplySize(t, result, 0) - result = execSInter(testDB, utils.ToBytesList(key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("sinter", key1, key2)) asserts.AssertMultiBulkReplySize(t, result, 0) - result = execSInterStore(testDB, utils.ToBytesList(utils.RandString(10), key0, key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("sinterstore", utils.RandString(10), key0, key1, key2)) asserts.AssertIntReply(t, result, 0) - result = execSInterStore(testDB, utils.ToBytesList(utils.RandString(10), key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("sinterstore", utils.RandString(10), key1, key2)) asserts.AssertIntReply(t, result, 0) } func TestSUnion(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 step := 10 @@ -118,22 +118,22 @@ func TestSUnion(t *testing.T) { keys = append(keys, key) for j := start; j < size+start; j++ { member := strconv.Itoa(j) - execSAdd(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) } start += step } - result := execSUnion(testDB, utils.ToBytesList(keys...)) + result := testDB.Exec(nil, utils.ToCmdLine2("sunion", keys...)) asserts.AssertMultiBulkReplySize(t, result, 130) destKey := utils.RandString(10) keysWithDest := []string{destKey} keysWithDest = append(keysWithDest, keys...) - result = execSUnionStore(testDB, utils.ToBytesList(keysWithDest...)) + result = testDB.Exec(nil, utils.ToCmdLine2("SUnionStore", keysWithDest...)) asserts.AssertIntReply(t, result, 130) } func TestSDiff(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 step := 20 @@ -144,52 +144,52 @@ func TestSDiff(t *testing.T) { keys = append(keys, key) for j := start; j < size+start; j++ { member := strconv.Itoa(j) - execSAdd(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) } start += step } - result := execSDiff(testDB, utils.ToBytesList(keys...)) + result := testDB.Exec(nil, utils.ToCmdLine2("SDiff", keys...)) asserts.AssertMultiBulkReplySize(t, result, step) destKey := utils.RandString(10) keysWithDest := []string{destKey} keysWithDest = append(keysWithDest, keys...) - result = execSDiffStore(testDB, utils.ToBytesList(keysWithDest...)) + result = testDB.Exec(nil, utils.ToCmdLine2("SDiffStore", keysWithDest...)) asserts.AssertIntReply(t, result, step) // test empty set - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key0 := utils.RandString(10) testDB.Remove(key0) key1 := utils.RandString(10) - execSAdd(testDB, utils.ToBytesList(key1, "a", "b")) + testDB.Exec(nil, utils.ToCmdLine("sadd", key1, "a", "b")) key2 := utils.RandString(10) - execSAdd(testDB, utils.ToBytesList(key2, "a", "b")) - result = execSDiff(testDB, utils.ToBytesList(key0, key1, key2)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key2, "a", "b")) + result = testDB.Exec(nil, utils.ToCmdLine("sdiff", key0, key1, key2)) asserts.AssertMultiBulkReplySize(t, result, 0) - result = execSDiff(testDB, utils.ToBytesList(key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("sdiff", key1, key2)) asserts.AssertMultiBulkReplySize(t, result, 0) - result = execSDiffStore(testDB, utils.ToBytesList(utils.RandString(10), key0, key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("SDiffStore", utils.RandString(10), key0, key1, key2)) asserts.AssertIntReply(t, result, 0) - result = execSDiffStore(testDB, utils.ToBytesList(utils.RandString(10), key1, key2)) + result = testDB.Exec(nil, utils.ToCmdLine("SDiffStore", utils.RandString(10), key1, key2)) asserts.AssertIntReply(t, result, 0) } func TestSRandMember(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) for j := 0; j < 100; j++ { member := strconv.Itoa(j) - execSAdd(testDB, utils.ToBytesList(key, member)) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) } - result := execSRandMember(testDB, utils.ToBytesList(key)) + result := testDB.Exec(nil, utils.ToCmdLine("SRandMember", key)) br, ok := result.(*reply.BulkReply) if !ok && len(br.Arg) > 0 { t.Error(fmt.Sprintf("expected bulk reply, actually %s", result.ToBytes())) return } - result = execSRandMember(testDB, utils.ToBytesList(key, "10")) + result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "10")) asserts.AssertMultiBulkReplySize(t, result, 10) multiBulk, ok := result.(*reply.MultiBulkReply) if !ok { @@ -205,12 +205,12 @@ func TestSRandMember(t *testing.T) { return } - result = execSRandMember(testDB, utils.ToBytesList(key, "110")) + result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "110")) asserts.AssertMultiBulkReplySize(t, result, 100) - result = execSRandMember(testDB, utils.ToBytesList(key, "-10")) + result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "-10")) asserts.AssertMultiBulkReplySize(t, result, 10) - result = execSRandMember(testDB, utils.ToBytesList(key, "-110")) + result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "-110")) asserts.AssertMultiBulkReplySize(t, result, 110) } diff --git a/sortedset.go b/sortedset.go index 687c8426..e4d97d9b 100644 --- a/sortedset.go +++ b/sortedset.go @@ -57,10 +57,6 @@ func execZAdd(db *DB, args [][]byte) redis.Reply { } } - // lock - db.Lock(key) - defer db.UnLock(key) - // get or init entity sortedSet, _, errReply := db.getOrInitSortedSet(key) if errReply != nil { @@ -79,14 +75,22 @@ func execZAdd(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(int64(i)) } +func undoZAdd(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + size := (len(args) - 1) / 2 + fields := make([]string, size) + for i := 0; i < size; i++ { + fields[i] = string(args[2*i+2]) + } + return rollbackZSetFields(db, key, fields...) +} + // execZScore gets score of a member in sortedset func execZScore(db *DB, args [][]byte) redis.Reply { // parse args key := string(args[0]) member := string(args[1]) - db.RLock(key) - defer db.RUnLock(key) sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { return errReply @@ -110,8 +114,6 @@ func execZRank(db *DB, args [][]byte) redis.Reply { member := string(args[1]) // get entity - db.RLock(key) - defer db.RUnLock(key) sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { return errReply @@ -134,8 +136,6 @@ func execZRevRank(db *DB, args [][]byte) redis.Reply { member := string(args[1]) // get entity - db.RLock(key) - defer db.RUnLock(key) sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { return errReply @@ -157,8 +157,6 @@ func execZCard(db *DB, args [][]byte) redis.Reply { key := string(args[0]) // get entity - db.RLock(key) - defer db.RUnLock(key) sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { return errReply @@ -221,10 +219,6 @@ func execZRevRange(db *DB, args [][]byte) redis.Reply { } func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool) redis.Reply { - // lock key - db.RLock(key) - defer db.RUnLock(key) - // get data sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -293,9 +287,6 @@ func execZCount(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply(err.Error()) } - db.RLock(key) - defer db.RUnLock(key) - // get data sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -312,10 +303,6 @@ func execZCount(db *DB, args [][]byte) redis.Reply { * param limit: limit < 0 means no limit */ func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool) redis.Reply { - // lock key - db.RLock(key) - defer db.RUnLock(key) - // get data sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -458,9 +445,6 @@ func execZRemRangeByScore(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply(err.Error()) } - db.Lock(key) - defer db.UnLock(key) - // get data sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -489,9 +473,6 @@ func execZRemRangeByRank(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not an integer or out of range") } - db.Lock(key) - defer db.UnLock(key) - // get data sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -541,9 +522,6 @@ func execZRem(db *DB, args [][]byte) redis.Reply { fields[i] = string(v) } - db.Lock(key) - defer db.UnLock(key) - // get entity sortedSet, errReply := db.getAsSortedSet(key) if errReply != nil { @@ -565,6 +543,16 @@ func execZRem(db *DB, args [][]byte) redis.Reply { return reply.MakeIntReply(deleted) } +func undoZRem(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + fields := make([]string, len(args)-1) + fieldArgs := args[1:] + for i, v := range fieldArgs { + fields[i] = string(v) + } + return rollbackZSetFields(db, key, fields...) +} + // execZIncrBy increments the score of a member func execZIncrBy(db *DB, args [][]byte) redis.Reply { key := string(args[0]) @@ -575,9 +563,6 @@ func execZIncrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not a valid float") } - db.Lock(key) - defer db.UnLock(key) - // get or init entity sortedSet, _, errReply := db.getOrInitSortedSet(key) if errReply != nil { @@ -597,19 +582,27 @@ func execZIncrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeBulkReply(bytes) } +func undoZIncr(db *DB, args [][]byte) []CmdLine { + key := string(args[0]) + field := string(args[2]) + return rollbackZSetFields(db, key, field) +} + func init() { - RegisterCommand("ZAdd", execZAdd, nil, -4) - RegisterCommand("ZScore", execZScore, nil, 3) - RegisterCommand("ZIncrBy", execZIncrBy, nil, 4) - RegisterCommand("ZRank", execZRank, nil, 3) - RegisterCommand("ZCount", execZCount, nil, 4) - RegisterCommand("ZRevRank", execZRevRank, nil, 3) - RegisterCommand("ZCard", execZCard, nil, 2) - RegisterCommand("ZRange", execZRange, nil, -4) - RegisterCommand("ZRangeByScore", execZRangeByScore, nil, -4) - RegisterCommand("ZRange", execZRevRange, nil, -4) - RegisterCommand("ZRangeByScore", execZRevRangeByScore, nil, -4) - RegisterCommand("ZRem", execZRem, nil, -3) - RegisterCommand("ZRemRangeByScore", execZRemRangeByScore, nil, 4) - RegisterCommand("ZRemRangeByRank", execZRemRangeByRank, nil, 4) + RegisterCommand("ZAdd", execZAdd, writeFirstKey, undoZAdd, -4) + RegisterCommand("ZScore", execZScore, readFirstKey, nil, 3) + RegisterCommand("ZIncrBy", execZIncrBy, writeFirstKey, undoZIncr, 4) + RegisterCommand("ZRank", execZRank, readFirstKey, nil, 3) + RegisterCommand("ZCount", execZCount, readFirstKey, nil, 4) + RegisterCommand("ZRevRank", execZRevRank, readFirstKey, nil, 3) + RegisterCommand("ZCard", execZCard, readFirstKey, nil, 2) + RegisterCommand("ZRange", execZRange, readFirstKey, nil, -4) + RegisterCommand("ZRangeByScore", execZRangeByScore, readFirstKey, nil, -4) + RegisterCommand("ZRange", execZRange, readFirstKey, nil, -4) + RegisterCommand("ZRevRange", execZRevRange, readFirstKey, nil, -4) + RegisterCommand("ZRangeByScore", execZRangeByScore, readFirstKey, nil, -4) + RegisterCommand("ZRevRangeByScore", execZRevRangeByScore, readFirstKey, nil, -4) + RegisterCommand("ZRem", execZRem, writeFirstKey, undoZRem, -3) + RegisterCommand("ZRemRangeByScore", execZRemRangeByScore, writeFirstKey, rollbackFirstKey, 4) + RegisterCommand("ZRemRangeByRank", execZRemRangeByRank, writeFirstKey, rollbackFirstKey, 4) } diff --git a/sortedset_test.go b/sortedset_test.go index 3475648b..7394d0c8 100644 --- a/sortedset_test.go +++ b/sortedset_test.go @@ -9,7 +9,7 @@ import ( ) func TestZAdd(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 // add new members @@ -22,18 +22,18 @@ func TestZAdd(t *testing.T) { scores[i] = rand.Float64() setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) } - result := execZAdd(testDB, utils.ToBytesList(setArgs...)) + result := testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) asserts.AssertIntReply(t, result, size) // test zscore and zrank for i, member := range members { - result := execZScore(testDB, utils.ToBytesList(key, member)) + result = testDB.Exec(nil, utils.ToCmdLine("ZScore", key, member)) score := strconv.FormatFloat(scores[i], 'f', -1, 64) asserts.AssertBulkReply(t, result, score) } // test zcard - result = execZCard(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("zcard", key)) asserts.AssertIntReply(t, result, size) // update members @@ -42,19 +42,19 @@ func TestZAdd(t *testing.T) { scores[i] = rand.Float64() + 100 setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) } - result = execZAdd(testDB, utils.ToBytesList(setArgs...)) + result = testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) asserts.AssertIntReply(t, result, 0) // return number of new members // test updated score for i, member := range members { - result := execZScore(testDB, utils.ToBytesList(key, member)) + result = testDB.Exec(nil, utils.ToCmdLine("zscore", key, member)) score := strconv.FormatFloat(scores[i], 'f', -1, 64) asserts.AssertBulkReply(t, result, score) } } func TestZRank(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 key := utils.RandString(10) members := make([]string, size) @@ -65,32 +65,31 @@ func TestZRank(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) // test zrank for i, member := range members { - result := execZRank(testDB, utils.ToBytesList(key, member)) + result := testDB.Exec(nil, utils.ToCmdLine("zrank", key, member)) asserts.AssertIntReply(t, result, i) - - result = execZRevRank(testDB, utils.ToBytesList(key, member)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRank", key, member)) asserts.AssertIntReply(t, result, size-i-1) } } func TestZRange(t *testing.T) { // prepare - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 key := utils.RandString(10) members := make([]string, size) scores := make([]int, size) setArgs := []string{key} for i := 0; i < size; i++ { - members[i] = utils.RandString(10) + members[i] = strconv.Itoa(i) //utils.RandString(10) scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) reverseMembers := make([]string, size) for i, v := range members { reverseMembers[size-i-1] = v @@ -98,39 +97,39 @@ func TestZRange(t *testing.T) { start := "0" end := "9" - result := execZRange(testDB, utils.ToBytesList(key, start, end)) + result := testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, members[0:10]) - result = execZRange(testDB, utils.ToBytesList(key, start, end, "WITHSCORES")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end, "WITHSCORES")) asserts.AssertMultiBulkReplySize(t, result, 20) - result = execZRevRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10]) start = "0" end = "200" - result = execZRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, members) - result = execZRevRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, reverseMembers) start = "0" end = "-10" - result = execZRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, members[0:size-10+1]) - result = execZRevRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1]) start = "0" end = "-200" - result = execZRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, members[0:0]) - result = execZRevRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0]) start = "-10" end = "-1" - result = execZRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, members[90:]) - result = execZRevRange(testDB, utils.ToBytesList(key, start, end)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRange", key, start, end)) asserts.AssertMultiBulkReply(t, result, reverseMembers[90:]) } @@ -144,7 +143,7 @@ func reverse(src []string) []string { func TestZRangeByScore(t *testing.T) { // prepare - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 key := utils.RandString(10) members := make([]string, size) @@ -155,49 +154,50 @@ func TestZRangeByScore(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - result := execZAdd(testDB, utils.ToBytesList(setArgs...)) + result := testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) asserts.AssertIntReply(t, result, size) min := "20" max := "30" - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max)) asserts.AssertMultiBulkReply(t, result, members[20:31]) - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max, "WITHSCORES")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max, "WithScores")) asserts.AssertMultiBulkReplySize(t, result, 22) - result = execZRevRangeByScore(testDB, utils.ToBytesList(key, max, min)) + result = execZRevRangeByScore(testDB, utils.ToCmdLine(key, max, min)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByScore", key, max, min)) asserts.AssertMultiBulkReply(t, result, reverse(members[20:31])) min = "-10" max = "10" - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max)) asserts.AssertMultiBulkReply(t, result, members[0:11]) - result = execZRevRangeByScore(testDB, utils.ToBytesList(key, max, min)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByScore", key, max, min)) asserts.AssertMultiBulkReply(t, result, reverse(members[0:11])) min = "90" max = "110" - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max)) asserts.AssertMultiBulkReply(t, result, members[90:]) - result = execZRevRangeByScore(testDB, utils.ToBytesList(key, max, min)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByScore", key, max, min)) asserts.AssertMultiBulkReply(t, result, reverse(members[90:])) min = "(20" max = "(30" - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max)) asserts.AssertMultiBulkReply(t, result, members[21:30]) - result = execZRevRangeByScore(testDB, utils.ToBytesList(key, max, min)) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByScore", key, max, min)) asserts.AssertMultiBulkReply(t, result, reverse(members[21:30])) min = "20" max = "40" - result = execZRangeByScore(testDB, utils.ToBytesList(key, min, max, "LIMIT", "5", "5")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRangeByScore", key, min, max, "LIMIT", "5", "5")) asserts.AssertMultiBulkReply(t, result, members[25:30]) - result = execZRevRangeByScore(testDB, utils.ToBytesList(key, max, min, "LIMIT", "5", "5")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByScore", key, max, min, "LIMIT", "5", "5")) asserts.AssertMultiBulkReply(t, result, reverse(members[31:36])) } func TestZRem(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 key := utils.RandString(10) members := make([]string, size) @@ -208,17 +208,17 @@ func TestZRem(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) args := []string{key} args = append(args, members[0:10]...) - result := execZRem(testDB, utils.ToBytesList(args...)) + result := testDB.Exec(nil, utils.ToCmdLine2("zrem", args...)) asserts.AssertIntReply(t, result, 10) - result = execZCard(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("zcard", key)) asserts.AssertIntReply(t, result, size-10) // test ZRemRangeByRank - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size = 100 key = utils.RandString(10) members = make([]string, size) @@ -229,15 +229,15 @@ func TestZRem(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) - result = execZRemRangeByRank(testDB, utils.ToBytesList(key, "0", "9")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRemRangeByRank", key, "0", "9")) asserts.AssertIntReply(t, result, 10) - result = execZCard(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("zcard", key)) asserts.AssertIntReply(t, result, size-10) // test ZRemRangeByScore - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size = 100 key = utils.RandString(10) members = make([]string, size) @@ -248,17 +248,17 @@ func TestZRem(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) - result = execZRemRangeByScore(testDB, utils.ToBytesList(key, "0", "9")) + result = testDB.Exec(nil, utils.ToCmdLine("ZRemRangeByScore", key, "0", "9")) asserts.AssertIntReply(t, result, 10) - result = execZCard(testDB, utils.ToBytesList(key)) + result = testDB.Exec(nil, utils.ToCmdLine("zcard", key)) asserts.AssertIntReply(t, result, size-10) } func TestZCount(t *testing.T) { // prepare - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 100 key := utils.RandString(10) members := make([]string, size) @@ -269,37 +269,37 @@ func TestZCount(t *testing.T) { scores[i] = i setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) } - execZAdd(testDB, utils.ToBytesList(setArgs...)) + testDB.Exec(nil, utils.ToCmdLine2("zadd", setArgs...)) min := "20" max := "30" - result := execZCount(testDB, utils.ToBytesList(key, min, max)) + result := testDB.Exec(nil, utils.ToCmdLine("zcount", key, min, max)) asserts.AssertIntReply(t, result, 11) min = "-10" max = "10" - result = execZCount(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("zcount", key, min, max)) asserts.AssertIntReply(t, result, 11) min = "90" max = "110" - result = execZCount(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("zcount", key, min, max)) asserts.AssertIntReply(t, result, 10) min = "(20" max = "(30" - result = execZCount(testDB, utils.ToBytesList(key, min, max)) + result = testDB.Exec(nil, utils.ToCmdLine("zcount", key, min, max)) asserts.AssertIntReply(t, result, 9) } func TestZIncrBy(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() key := utils.RandString(10) - result := execZIncrBy(testDB, utils.ToBytesList(key, "10", "a")) + result := testDB.Exec(nil, utils.ToCmdLine("ZIncrBy", key, "10", "a")) asserts.AssertBulkReply(t, result, "10") - result = execZIncrBy(testDB, utils.ToBytesList(key, "10", "a")) + result = testDB.Exec(nil, utils.ToCmdLine("ZIncrBy", key, "10", "a")) asserts.AssertBulkReply(t, result, "20") - result = execZScore(testDB, utils.ToBytesList(key, "a")) + result = testDB.Exec(nil, utils.ToCmdLine("ZScore", key, "a")) asserts.AssertBulkReply(t, result, "20") } diff --git a/string.go b/string.go index a44ee615..769c434c 100644 --- a/string.go +++ b/string.go @@ -173,9 +173,6 @@ func execSetEX(db *DB, args [][]byte) redis.Reply { Data: value, } - db.Lock(key) - defer db.UnLock(key) - db.PutEntity(key, entity) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) db.Expire(key, expireTime) @@ -201,9 +198,6 @@ func execPSetEX(db *DB, args [][]byte) redis.Reply { Data: value, } - db.Lock(key) - defer db.UnLock(key) - db.PutEntity(key, entity) expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond) db.Expire(key, expireTime) @@ -213,6 +207,20 @@ func execPSetEX(db *DB, args [][]byte) redis.Reply { return &reply.OkReply{} } +func prepareMSet(args [][]byte) ([]string, []string) { + size := len(args) / 2 + keys := make([]string, size) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i]) + } + return keys, nil +} + +func undoMSet(db *DB, args [][]byte) []CmdLine { + writeKeys, _ := prepareMSet(args) + return rollbackGivenKeys(db, writeKeys...) +} + // execMSet sets multi key-value in database func execMSet(db *DB, args [][]byte) redis.Reply { if len(args)%2 != 0 { @@ -227,9 +235,6 @@ func execMSet(db *DB, args [][]byte) redis.Reply { values[i] = args[2*i+1] } - db.Locks(keys...) - defer db.UnLocks(keys...) - for i, key := range keys { value := values[i] db.PutEntity(key, &DataEntity{Data: value}) @@ -238,6 +243,14 @@ func execMSet(db *DB, args [][]byte) redis.Reply { return &reply.OkReply{} } +func prepareMGet(args [][]byte) ([]string, []string) { + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } + return nil, keys +} + // execMGet get multi key-value from database func execMGet(db *DB, args [][]byte) redis.Reply { keys := make([]string, len(args)) @@ -277,10 +290,6 @@ func execMSetNX(db *DB, args [][]byte) redis.Reply { values[i] = args[2*i+1] } - // lock keys - db.Locks(keys...) - defer db.UnLocks(keys...) - for _, key := range keys { _, exists := db.GetEntity(key) if exists { @@ -319,9 +328,6 @@ func execGetSet(db *DB, args [][]byte) redis.Reply { func execIncr(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.Lock(key) - defer db.UnLock(key) - bytes, err := db.getAsString(key) if err != nil { return err @@ -353,9 +359,6 @@ func execIncrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not an integer or out of range") } - db.Lock(key) - defer db.UnLock(key) - bytes, errReply := db.getAsString(key) if errReply != nil { return errReply @@ -388,9 +391,6 @@ func execIncrByFloat(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not a valid float") } - db.Lock(key) - defer db.UnLock(key) - bytes, errReply := db.getAsString(key) if errReply != nil { return errReply @@ -418,9 +418,6 @@ func execIncrByFloat(db *DB, args [][]byte) redis.Reply { func execDecr(db *DB, args [][]byte) redis.Reply { key := string(args[0]) - db.Lock(key) - defer db.UnLock(key) - bytes, errReply := db.getAsString(key) if errReply != nil { return errReply @@ -453,9 +450,6 @@ func execDecrBy(db *DB, args [][]byte) redis.Reply { return reply.MakeErrReply("ERR value is not an integer or out of range") } - db.Lock(key) - defer db.UnLock(key) - bytes, errReply := db.getAsString(key) if errReply != nil { return errReply @@ -480,20 +474,18 @@ func execDecrBy(db *DB, args [][]byte) redis.Reply { } func init() { - RegisterCommand("Set", execSet, nil, -3) - RegisterCommand("SetNx", execSetNX, nil, 3) - RegisterCommand("SetEX", execSetEX, nil, 4) - RegisterCommand("PSetEX", execPSetEX, nil, 4) - RegisterCommand("MSet", execMSet, nil, -3) - RegisterCommand("MGet", execMGet, nil, -2) - RegisterCommand("MSetNX", execMSetNX, nil, -3) - RegisterCommand("Get", execGet, nil, 2) - RegisterCommand("MSet", execMSet, nil, -3) - RegisterCommand("GetSet", execGetSet, nil, 3) - RegisterCommand("MSet", execMSet, nil, -3) - RegisterCommand("Incr", execIncr, nil, 2) - RegisterCommand("IncrBy", execIncrBy, nil, 3) - RegisterCommand("IncrByFloat", execIncrByFloat, nil, 3) - RegisterCommand("Decr", execDecr, nil, 2) - RegisterCommand("DecrBy", execDecrBy, nil, 3) + RegisterCommand("Set", execSet, writeFirstKey, rollbackFirstKey, -3) + RegisterCommand("SetNx", execSetNX, writeFirstKey, rollbackFirstKey, 3) + RegisterCommand("SetEX", execSetEX, writeFirstKey, rollbackFirstKey, 4) + RegisterCommand("PSetEX", execPSetEX, writeFirstKey, rollbackFirstKey, 4) + RegisterCommand("MSet", execMSet, prepareMSet, undoMSet, -3) + RegisterCommand("MGet", execMGet, prepareMGet, nil, -2) + RegisterCommand("MSetNX", execMSetNX, prepareMSet, undoMSet, -3) + RegisterCommand("Get", execGet, readFirstKey, nil, 2) + RegisterCommand("GetSet", execGetSet, writeFirstKey, rollbackFirstKey, 3) + RegisterCommand("Incr", execIncr, writeFirstKey, rollbackFirstKey, 2) + RegisterCommand("IncrBy", execIncrBy, writeFirstKey, rollbackFirstKey, 3) + RegisterCommand("IncrByFloat", execIncrByFloat, writeFirstKey, rollbackFirstKey, 3) + RegisterCommand("Decr", execDecr, writeFirstKey, rollbackFirstKey, 2) + RegisterCommand("DecrBy", execDecrBy, writeFirstKey, rollbackFirstKey, 3) } diff --git a/string_test.go b/string_test.go index c6460363..bf3d9cbd 100644 --- a/string_test.go +++ b/string_test.go @@ -2,8 +2,7 @@ package godis import ( "fmt" - "github.com/hdt3213/godis/datastruct/utils" - utils2 "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply/asserts" "strconv" @@ -12,56 +11,71 @@ import ( var testDB = makeTestDB() +func TestSet2(t *testing.T) { + key := utils.RandString(10) + value := utils.RandString(10) + for i := 0; i < 1000; i++ { + testDB.Exec(nil, utils.ToCmdLine("SET", key, value)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) + expected := reply.MakeBulkReply([]byte(value)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) + } + } +} + func TestSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - value := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) // normal set - execSet(testDB, utils2.ToBytesList(key, value)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } // set nx - actual = execSet(testDB, utils2.ToBytesList(key, value, "NX")) + actual = testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "NX")) if _, ok := actual.(*reply.NullBulkReply); !ok { t.Error("expected true actual false") } - execFlushAll(testDB, [][]byte{}) - key = utils2.RandString(10) - value = utils2.RandString(10) - execSet(testDB, utils2.ToBytesList(key, value, "NX")) - actual = execGet(testDB, utils2.ToBytesList(key)) + testDB.Flush() + key = utils.RandString(10) + value = utils.RandString(10) + actual = testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "NX")) + actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected = reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } // set xx - execFlushAll(testDB, [][]byte{}) - key = utils2.RandString(10) - value = utils2.RandString(10) - actual = execSet(testDB, utils2.ToBytesList(key, value, "XX")) + testDB.Flush() + key = utils.RandString(10) + value = utils.RandString(10) + actual = testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "XX")) if _, ok := actual.(*reply.NullBulkReply); !ok { t.Error("expected true actually false ") } - execSet(testDB, utils2.ToBytesList(key, value)) - execSet(testDB, utils2.ToBytesList(key, value, "XX")) - actual = execGet(testDB, utils2.ToBytesList(key)) + execSet(testDB, utils.ToCmdLine(key, value)) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value)) + actual = testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "XX")) + actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value) // set ex testDB.Remove(key) ttl := "1000" - execSet(testDB, utils2.ToBytesList(key, value, "EX", ttl)) - actual = execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "EX", ttl)) + actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value) - actual = execTTL(testDB, utils2.ToBytesList(key)) + actual = execTTL(testDB, utils.ToCmdLine(key)) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) intResult, ok := actual.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) @@ -75,10 +89,10 @@ func TestSet(t *testing.T) { // set px testDB.Remove(key) ttlPx := "1000000" - execSet(testDB, utils2.ToBytesList(key, value, "PX", ttlPx)) - actual = execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "PX", ttlPx)) + actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value) - actual = execTTL(testDB, utils2.ToBytesList(key)) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) intResult, ok = actual.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) @@ -91,17 +105,17 @@ func TestSet(t *testing.T) { } func TestSetNX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - value := utils2.RandString(10) - execSetNX(testDB, utils2.ToBytesList(key, value)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SETNX", key, value)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := reply.MakeBulkReply([]byte(value)) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } - actual = execSetNX(testDB, utils2.ToBytesList(key, value)) + actual = testDB.Exec(nil, utils.ToCmdLine("SETNX", key, value)) expected2 := reply.MakeIntReply(int64(0)) if !utils.BytesEquals(actual.ToBytes(), expected2.ToBytes()) { t.Error("expected: " + string(expected2.ToBytes()) + ", actual: " + string(actual.ToBytes())) @@ -109,15 +123,15 @@ func TestSetNX(t *testing.T) { } func TestSetEX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - value := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) ttl := "1000" - execSetEX(testDB, utils2.ToBytesList(key, ttl, value)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("SETEX", key, ttl, value)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value) - actual = execTTL(testDB, utils2.ToBytesList(key)) + actual = testDB.Exec(nil, utils.ToCmdLine("TTL", key)) intResult, ok := actual.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) @@ -130,15 +144,15 @@ func TestSetEX(t *testing.T) { } func TestPSetEX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - value := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) ttl := "1000000" - execPSetEX(testDB, utils2.ToBytesList(key, ttl, value)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("PSetEx", key, ttl, value)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, value) - actual = execPTTL(testDB, utils2.ToBytesList(key)) + actual = testDB.Exec(nil, utils.ToCmdLine("PTTL", key)) intResult, ok := actual.(*reply.IntReply) if !ok { t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) @@ -151,51 +165,65 @@ func TestPSetEX(t *testing.T) { } func TestMSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 10 keys := make([]string, size) values := make([][]byte, size) - args := make([]string, 0, size*2) + var args []string for i := 0; i < size; i++ { - keys[i] = utils2.RandString(10) - value := utils2.RandString(10) + keys[i] = utils.RandString(10) + value := utils.RandString(10) values[i] = []byte(value) args = append(args, keys[i], value) } - execMSet(testDB, utils2.ToBytesList(args...)) - actual := execMGet(testDB, utils2.ToBytesList(keys...)) + testDB.Exec(nil, utils.ToCmdLine2("MSET", args...)) + actual := testDB.Exec(nil, utils.ToCmdLine2("MGET", keys...)) expected := reply.MakeMultiBulkReply(values) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } + + // test mget with wrong type + key1 := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine2("SET", key1, key1)) + key2 := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine2("LPush", key2, key2)) + actual = testDB.Exec(nil, utils.ToCmdLine2("MGET", key1, key2)) + arr := actual.(*reply.MultiBulkReply) + if string(arr.Args[0]) != key1 { + t.Error("expected: " + key1 + ", actual: " + string(arr.Args[1])) + } + if len(arr.Args[1]) > 0 { + t.Error("expect null, actual: " + string(arr.Args[0])) + } } func TestIncr(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 10 - key := utils2.RandString(10) + key := utils.RandString(10) for i := 0; i < size; i++ { - execIncr(testDB, utils2.ToBytesList(key)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("INCR", key)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } } for i := 0; i < size; i++ { - execIncrBy(testDB, utils2.ToBytesList(key, "-1")) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("INCRBY", key, "-1")) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(size-i-1), 10))) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) } } - execFlushAll(testDB, [][]byte{}) - key = utils2.RandString(10) + testDB.Flush() + key = utils.RandString(10) for i := 0; i < size; i++ { - execIncrBy(testDB, utils2.ToBytesList(key, "1")) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("INCRBY", key, "1")) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := reply.MakeBulkReply([]byte(strconv.FormatInt(int64(i+1), 10))) if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { t.Error("expected: " + string(expected.ToBytes()) + ", actual: " + string(actual.ToBytes())) @@ -203,8 +231,8 @@ func TestIncr(t *testing.T) { } testDB.Remove(key) for i := 0; i < size; i++ { - execIncrByFloat(testDB, utils2.ToBytesList(key, "-1.0")) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("INCRBYFLOAT", key, "-1.0")) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := -i - 1 bulk, ok := actual.(*reply.BulkReply) if !ok { @@ -224,18 +252,18 @@ func TestIncr(t *testing.T) { } func TestDecr(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 10 - key := utils2.RandString(10) + key := utils.RandString(10) for i := 0; i < size; i++ { - execDecr(testDB, utils2.ToBytesList(key)) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("DECR", key)) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) asserts.AssertBulkReply(t, actual, strconv.Itoa(-i-1)) } testDB.Remove(key) for i := 0; i < size; i++ { - execDecrBy(testDB, utils2.ToBytesList(key, "1")) - actual := execGet(testDB, utils2.ToBytesList(key)) + testDB.Exec(nil, utils.ToCmdLine("DECRBY", key, "1")) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) expected := -i - 1 bulk, ok := actual.(*reply.BulkReply) if !ok { @@ -255,35 +283,35 @@ func TestDecr(t *testing.T) { } func TestGetSet(t *testing.T) { - execFlushAll(testDB, [][]byte{}) - key := utils2.RandString(10) - value := utils2.RandString(10) + testDB.Flush() + key := utils.RandString(10) + value := utils.RandString(10) - result := execGetSet(testDB, utils2.ToBytesList(key, value)) - _, ok := result.(*reply.NullBulkReply) + actual := testDB.Exec(nil, utils.ToCmdLine("GETSET", key, value)) + _, ok := actual.(*reply.NullBulkReply) if !ok { - t.Errorf("expect null bulk reply, get: %s", string(result.ToBytes())) + t.Errorf("expect null bulk reply, get: %s", string(actual.ToBytes())) return } - value2 := utils2.RandString(10) - result = execGetSet(testDB, utils2.ToBytesList(key, value2)) - asserts.AssertBulkReply(t, result, value) - result = execGet(testDB, utils2.ToBytesList(key)) - asserts.AssertBulkReply(t, result, value2) + value2 := utils.RandString(10) + actual = testDB.Exec(nil, utils.ToCmdLine("GETSET", key, value2)) + asserts.AssertBulkReply(t, actual, value) + actual = testDB.Exec(nil, utils.ToCmdLine("GET", key)) + asserts.AssertBulkReply(t, actual, value2) } func TestMSetNX(t *testing.T) { - execFlushAll(testDB, [][]byte{}) + testDB.Flush() size := 10 args := make([]string, 0, size*2) for i := 0; i < size; i++ { - str := utils2.RandString(10) + str := utils.RandString(10) args = append(args, str, str) } - result := execMSetNX(testDB, utils2.ToBytesList(args...)) + result := testDB.Exec(nil, utils.ToCmdLine2("MSETNX", args...)) asserts.AssertIntReply(t, result, 1) - result = execMSetNX(testDB, utils2.ToBytesList(args[0:4]...)) + result = testDB.Exec(nil, utils.ToCmdLine2("MSETNX", args[0:4]...)) asserts.AssertIntReply(t, result, 0) }