Skip to content

Commit

Permalink
make AUTH auth.
Browse files Browse the repository at this point in the history
  • Loading branch information
Harmen committed Feb 26, 2015
1 parent 637be1c commit 93b2b81
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 21 deletions.
21 changes: 21 additions & 0 deletions cmd_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,33 @@ func commandsConnection(m *Miniredis, srv *redeo.Server) {

// PING
func (m *Miniredis) cmdPing(out *redeo.Responder, r *redeo.Request) error {
if !m.handleAuth(r.Client(), out) {
return nil
}
out.WriteInlineString("PONG")
return nil
}

// AUTH
func (m *Miniredis) cmdAuth(out *redeo.Responder, r *redeo.Request) error {
if len(r.Args) != 1 {
setDirty(r.Client())
return r.WrongNumberOfArgs()
}
pw := r.Args[0]

m.Lock()
defer m.Unlock()
if m.password == "" {
out.WriteErrorString("ERR Client sent AUTH, but no password is set")
return nil
}
if m.password != pw {
out.WriteErrorString("ERR invalid password")
return nil
}

setAuthenticated(r.Client())
out.WriteOK()
return nil
}
Expand Down
14 changes: 13 additions & 1 deletion cmd_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,20 @@ func TestAuth(t *testing.T) {
c, err := redis.Dial("tcp", s.Addr())
ok(t, err)

// We accept all AUTH
_, err = c.Do("AUTH", "foo", "bar")
assert(t, err != nil, "no password set")

s.RequireAuth("nocomment")
_, err = c.Do("PING", "foo", "bar")
assert(t, err != nil, "need AUTH")

_, err = c.Do("AUTH", "wrongpasswd")
assert(t, err != nil, "wrong password")

_, err = c.Do("AUTH", "nocomment")
ok(t, err)

_, err = c.Do("PING")
ok(t, err)
}

Expand Down
6 changes: 6 additions & 0 deletions cmd_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func commandsString(m *Miniredis, srv *redeo.Server) {

// SET
func (m *Miniredis) cmdSet(out *redeo.Responder, r *redeo.Request) error {
if !m.handleAuth(r.Client(), out) {
return nil
}
if len(r.Args) < 2 {
setDirty(r.Client())
return r.WrongNumberOfArgs()
Expand Down Expand Up @@ -246,6 +249,9 @@ func (m *Miniredis) cmdMsetnx(out *redeo.Responder, r *redeo.Request) error {

// GET
func (m *Miniredis) cmdGet(out *redeo.Responder, r *redeo.Request) error {
if !m.handleAuth(r.Client(), out) {
return nil
}
if len(r.Args) != 1 {
setDirty(r.Client())
return r.WrongNumberOfArgs()
Expand Down
68 changes: 48 additions & 20 deletions miniredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type Miniredis struct {
sync.Mutex
srv *redeo.Server
listenAddr string
password string
closed chan struct{}
listen net.Listener
dbs map[int]*RedisDB
Expand All @@ -62,6 +63,7 @@ type dbKey struct {
// connCtx has all state for a single connection.
type connCtx struct {
selectedDB int // selected DB
authenticated bool // auth enabled and a valid AUTH seen
transaction []txCmd // transaction callbacks. Or nil.
dirtyTransaction bool // any error during QUEUEing.
watch map[dbKey]uint // WATCHed keys.
Expand Down Expand Up @@ -96,26 +98,6 @@ func Run() (*Miniredis, error) {
return m, m.Start()
}

// Restart restarts a Close()d server on the same port. Values will be
// preserved.
func (m *Miniredis) Restart() error {
m.Lock()
defer m.Unlock()

l, err := listen(m.listenAddr)
if err != nil {
return err
}
m.listen = l

go func() {
m.srv.Serve(m.listen)
m.closed <- struct{}{}
}()

return nil
}

// Start starts a server. It listens on a random port on localhost. See also
// Addr().
func (m *Miniredis) Start() error {
Expand Down Expand Up @@ -146,6 +128,26 @@ func (m *Miniredis) Start() error {
return nil
}

// Restart restarts a Close()d server on the same port. Values will be
// preserved.
func (m *Miniredis) Restart() error {
m.Lock()
defer m.Unlock()

l, err := listen(m.listenAddr)
if err != nil {
return err
}
m.listen = l

go func() {
m.srv.Serve(m.listen)
m.closed <- struct{}{}
}()

return nil
}

func listen(addr string) (net.Listener, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
Expand All @@ -171,6 +173,14 @@ func (m *Miniredis) Close() {
m.listen = nil
}

// RequireAuth makes every connection need to AUTH first. Disable again by
// setting an empty string.
func (m *Miniredis) RequireAuth(pw string) {
m.Lock()
defer m.Unlock()
m.password = pw
}

// DB returns a DB by ID.
func (m *Miniredis) DB(i int) *RedisDB {
m.Lock()
Expand Down Expand Up @@ -233,6 +243,20 @@ func (m *Miniredis) TotalConnectionCount() int {
return int(m.srv.Info().TotalConnections())
}

// handleAuth returns false if connection has no access. It sends the reply.
func (m *Miniredis) handleAuth(cl *redeo.Client, out *redeo.Responder) bool {
m.Lock()
defer m.Unlock()
if m.password == "" {
return true
}
if cl.Ctx == nil || !getCtx(cl).authenticated {
out.WriteErrorString("NOAUTH Authentication required.")
return false
}
return true
}

func getCtx(cl *redeo.Client) *connCtx {
if cl.Ctx == nil {
cl.Ctx = &connCtx{}
Expand Down Expand Up @@ -281,3 +305,7 @@ func setDirty(cl *redeo.Client) {
}
getCtx(cl).dirtyTransaction = true
}

func setAuthenticated(cl *redeo.Client) {
getCtx(cl).authenticated = true
}

0 comments on commit 93b2b81

Please sign in to comment.