From 93b2b810a04d58c234fa552120cc371e72dc325b Mon Sep 17 00:00:00 2001 From: Harmen Date: Thu, 26 Feb 2015 09:50:56 +0100 Subject: [PATCH] make AUTH auth. --- cmd_connection.go | 21 +++++++++++++ cmd_connection_test.go | 14 ++++++++- cmd_string.go | 6 ++++ miniredis.go | 68 +++++++++++++++++++++++++++++------------- 4 files changed, 88 insertions(+), 21 deletions(-) diff --git a/cmd_connection.go b/cmd_connection.go index 4a4a7284..12b6ad49 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -18,12 +18,33 @@ func commandsConnection(m *Miniredis, srv *redeo.Server) { // PING func (m *Miniredis) cmdPing(out *redeo.Responder, r *redeo.Request) error { + if !m.handleAuth(r.Client(), out) { + return nil + } out.WriteInlineString("PONG") return nil } // AUTH func (m *Miniredis) cmdAuth(out *redeo.Responder, r *redeo.Request) error { + if len(r.Args) != 1 { + setDirty(r.Client()) + return r.WrongNumberOfArgs() + } + pw := r.Args[0] + + m.Lock() + defer m.Unlock() + if m.password == "" { + out.WriteErrorString("ERR Client sent AUTH, but no password is set") + return nil + } + if m.password != pw { + out.WriteErrorString("ERR invalid password") + return nil + } + + setAuthenticated(r.Client()) out.WriteOK() return nil } diff --git a/cmd_connection_test.go b/cmd_connection_test.go index 19d22e33..729a7a9d 100644 --- a/cmd_connection_test.go +++ b/cmd_connection_test.go @@ -13,8 +13,20 @@ func TestAuth(t *testing.T) { c, err := redis.Dial("tcp", s.Addr()) ok(t, err) - // We accept all AUTH _, err = c.Do("AUTH", "foo", "bar") + assert(t, err != nil, "no password set") + + s.RequireAuth("nocomment") + _, err = c.Do("PING", "foo", "bar") + assert(t, err != nil, "need AUTH") + + _, err = c.Do("AUTH", "wrongpasswd") + assert(t, err != nil, "wrong password") + + _, err = c.Do("AUTH", "nocomment") + ok(t, err) + + _, err = c.Do("PING") ok(t, err) } diff --git a/cmd_string.go b/cmd_string.go index bec39f90..d352d2a4 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -38,6 +38,9 @@ func commandsString(m *Miniredis, srv *redeo.Server) { // SET func (m *Miniredis) cmdSet(out *redeo.Responder, r *redeo.Request) error { + if !m.handleAuth(r.Client(), out) { + return nil + } if len(r.Args) < 2 { setDirty(r.Client()) return r.WrongNumberOfArgs() @@ -246,6 +249,9 @@ func (m *Miniredis) cmdMsetnx(out *redeo.Responder, r *redeo.Request) error { // GET func (m *Miniredis) cmdGet(out *redeo.Responder, r *redeo.Request) error { + if !m.handleAuth(r.Client(), out) { + return nil + } if len(r.Args) != 1 { setDirty(r.Client()) return r.WrongNumberOfArgs() diff --git a/miniredis.go b/miniredis.go index 3ec5795e..253aac0f 100644 --- a/miniredis.go +++ b/miniredis.go @@ -45,6 +45,7 @@ type Miniredis struct { sync.Mutex srv *redeo.Server listenAddr string + password string closed chan struct{} listen net.Listener dbs map[int]*RedisDB @@ -62,6 +63,7 @@ type dbKey struct { // connCtx has all state for a single connection. type connCtx struct { selectedDB int // selected DB + authenticated bool // auth enabled and a valid AUTH seen transaction []txCmd // transaction callbacks. Or nil. dirtyTransaction bool // any error during QUEUEing. watch map[dbKey]uint // WATCHed keys. @@ -96,26 +98,6 @@ func Run() (*Miniredis, error) { return m, m.Start() } -// Restart restarts a Close()d server on the same port. Values will be -// preserved. -func (m *Miniredis) Restart() error { - m.Lock() - defer m.Unlock() - - l, err := listen(m.listenAddr) - if err != nil { - return err - } - m.listen = l - - go func() { - m.srv.Serve(m.listen) - m.closed <- struct{}{} - }() - - return nil -} - // Start starts a server. It listens on a random port on localhost. See also // Addr(). func (m *Miniredis) Start() error { @@ -146,6 +128,26 @@ func (m *Miniredis) Start() error { return nil } +// Restart restarts a Close()d server on the same port. Values will be +// preserved. +func (m *Miniredis) Restart() error { + m.Lock() + defer m.Unlock() + + l, err := listen(m.listenAddr) + if err != nil { + return err + } + m.listen = l + + go func() { + m.srv.Serve(m.listen) + m.closed <- struct{}{} + }() + + return nil +} + func listen(addr string) (net.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { @@ -171,6 +173,14 @@ func (m *Miniredis) Close() { m.listen = nil } +// RequireAuth makes every connection need to AUTH first. Disable again by +// setting an empty string. +func (m *Miniredis) RequireAuth(pw string) { + m.Lock() + defer m.Unlock() + m.password = pw +} + // DB returns a DB by ID. func (m *Miniredis) DB(i int) *RedisDB { m.Lock() @@ -233,6 +243,20 @@ func (m *Miniredis) TotalConnectionCount() int { return int(m.srv.Info().TotalConnections()) } +// handleAuth returns false if connection has no access. It sends the reply. +func (m *Miniredis) handleAuth(cl *redeo.Client, out *redeo.Responder) bool { + m.Lock() + defer m.Unlock() + if m.password == "" { + return true + } + if cl.Ctx == nil || !getCtx(cl).authenticated { + out.WriteErrorString("NOAUTH Authentication required.") + return false + } + return true +} + func getCtx(cl *redeo.Client) *connCtx { if cl.Ctx == nil { cl.Ctx = &connCtx{} @@ -281,3 +305,7 @@ func setDirty(cl *redeo.Client) { } getCtx(cl).dirtyTransaction = true } + +func setAuthenticated(cl *redeo.Client) { + getCtx(cl).authenticated = true +}