diff --git a/cmd_string.go b/cmd_string.go index cec9d483..b8013f28 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -131,16 +131,27 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) + readonly := false if opts.nx { if db.exists(opts.key) { - c.WriteNull() - return + if opts.get { + // special case for SET NX GET + readonly = true + } else { + c.WriteNull() + return + } } } if opts.xx { if !db.exists(opts.key) { - c.WriteNull() - return + if opts.get { + // special case for SET XX GET + readonly = true + } else { + c.WriteNull() + return + } } } if opts.keepttl { @@ -154,14 +165,17 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { return } } + old, existed := db.stringKeys[opts.key] - db.del(opts.key, true) // be sure to remove existing values of other type keys. - // a vanilla SET clears the expire - if opts.ttl >= 0 { // EXAT/PXAT can expire right away - db.stringSet(opts.key, opts.value) - } - if opts.ttl != 0 { - db.ttl[opts.key] = opts.ttl + if !readonly { + db.del(opts.key, true) // be sure to remove existing values of other type keys. + // a vanilla SET clears the expire + if opts.ttl >= 0 { // EXAT/PXAT can expire right away + db.stringSet(opts.key, opts.value) + } + if opts.ttl != 0 { + db.ttl[opts.key] = opts.ttl + } } if opts.get { if !existed { diff --git a/integration/string_test.go b/integration/string_test.go index 84b3cba6..cfe27b4a 100644 --- a/integration/string_test.go +++ b/integration/string_test.go @@ -32,6 +32,13 @@ func TestString(t *testing.T) { c.Do("SET", "gone", "bar", "EXAT", "123") c.Do("EXISTS", "gone") + // SET NX GET + c.Do("SET", "unique", "value1", "NX", "GET") + c.Do("SET", "unique", "value2", "NX", "GET") + c.Do("SET", "unique", "value3", "XX", "GET") + c.Do("SET", "unique", "value4", "XX", "GET") + c.Do("SET", "uniquer", "value5", "XX", "GET") + // Failure cases c.Error("wrong number", "SET") c.Error("wrong number", "SET", "foo")