Skip to content

Commit

Permalink
Merge pull request #197 from bramvbilsen/master
Browse files Browse the repository at this point in the history
Add Opt-In Functionality for Hashed Session Token Storage
  • Loading branch information
alexedwards authored Mar 16, 2024
2 parents d7ab9d9 + 7134b6f commit 7e11d57
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
15 changes: 15 additions & 0 deletions data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scs
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"sort"
Expand Down Expand Up @@ -623,6 +624,11 @@ func generateToken() (string, error) {
return base64.RawURLEncoding.EncodeToString(b), nil
}

func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return base64.RawURLEncoding.EncodeToString(hash[:])
}

type contextKey string

var (
Expand All @@ -638,6 +644,9 @@ func generateContextKey() contextKey {
}

func (s *SessionManager) doStoreDelete(ctx context.Context, token string) (err error) {
if s.HashTokenInStore {
token = hashToken(token)
}
c, ok := s.Store.(interface {
DeleteCtx(context.Context, string) error
})
Expand All @@ -648,6 +657,9 @@ func (s *SessionManager) doStoreDelete(ctx context.Context, token string) (err e
}

func (s *SessionManager) doStoreFind(ctx context.Context, token string) (b []byte, found bool, err error) {
if s.HashTokenInStore {
token = hashToken(token)
}
c, ok := s.Store.(interface {
FindCtx(context.Context, string) ([]byte, bool, error)
})
Expand All @@ -658,6 +670,9 @@ func (s *SessionManager) doStoreFind(ctx context.Context, token string) (b []byt
}

func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []byte, expiry time.Time) (err error) {
if s.HashTokenInStore {
token = hashToken(token)
}
c, ok := s.Store.(interface {
CommitCtx(context.Context, string, []byte, time.Time) error
})
Expand Down
73 changes: 73 additions & 0 deletions data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,50 @@ func TestSessionManager_Load(T *testing.T) {
t.Error("returned context is unexpectedly nil")
}
})

T.Run("with token hashing", func(t *testing.T) {
s := New()
s.HashTokenInStore = true
s.IdleTimeout = time.Hour * 24

expectedToken := "example"
expectedExpiry := time.Now().Add(time.Hour)

initialCtx := context.WithValue(context.Background(), s.contextKey, &sessionData{
deadline: expectedExpiry,
token: expectedToken,
values: map[string]interface{}{
"blah": "blah",
},
mu: sync.Mutex{},
})

actualToken, actualExpiry, err := s.Commit(initialCtx)
if expectedToken != actualToken {
t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken)
}
if expectedExpiry != actualExpiry {
t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry)
}
if err != nil {
t.Errorf("unexpected error returned: %v", err)
}

retrievedCtx, err := s.Load(context.Background(), expectedToken)
if err != nil {
t.Errorf("unexpected error returned: %v", err)
}
retrievedSessionData, ok := retrievedCtx.Value(s.contextKey).(*sessionData)
if !ok {
t.Errorf("unexpected data in retrieved context")
} else if retrievedSessionData.token != expectedToken {
t.Errorf("expected token in context's session data data to equal %v, but received %v", expectedToken, retrievedSessionData.token)
}

if err := s.Destroy(retrievedCtx); err != nil {
t.Errorf("unexpected error returned: %v", err)
}
})
}

func TestSessionManager_Commit(T *testing.T) {
Expand Down Expand Up @@ -320,6 +364,35 @@ func TestSessionManager_Commit(T *testing.T) {
t.Error("expected error not returned")
}
})

T.Run("with token hashing", func(t *testing.T) {
s := New()
s.HashTokenInStore = true
s.IdleTimeout = time.Hour * 24

expectedToken := "example"
expectedExpiry := time.Now().Add(time.Hour)

ctx := context.WithValue(context.Background(), s.contextKey, &sessionData{
deadline: expectedExpiry,
token: expectedToken,
values: map[string]interface{}{
"blah": "blah",
},
mu: sync.Mutex{},
})

actualToken, actualExpiry, err := s.Commit(ctx)
if expectedToken != actualToken {
t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken)
}
if expectedExpiry != actualExpiry {
t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry)
}
if err != nil {
t.Errorf("unexpected error returned: %v", err)
}
})
}

func TestPut(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type SessionManager struct {
// a function which logs the error and returns a customized HTML error page.
ErrorFunc func(http.ResponseWriter, *http.Request, error)

// HashTokenInStore controls whether or not to store the session token or a hashed version in the store.
HashTokenInStore bool

// contextKey is the key used to set and retrieve the session data from a
// context.Context. It's automatically generated to ensure uniqueness.
contextKey contextKey
Expand Down

0 comments on commit 7e11d57

Please sign in to comment.