Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ func contextSave(r *http.Request, key string, val interface{}) *http.Request {
ctx = context.WithValue(ctx, key, val)
return r.WithContext(ctx)
}

func contextClear(r *http.Request) {
// no-op for go1.7+
}
4 changes: 4 additions & 0 deletions context_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ func contextSave(r *http.Request, key string, val interface{}) *http.Request {
context.Set(r, key, val)
return r
}

func contextClear(r *http.Request) {
context.Clear(r)
}
15 changes: 7 additions & 8 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/pkg/errors"

"github.com/gorilla/context"
"github.com/gorilla/securecookie"
)

Expand Down Expand Up @@ -195,15 +194,15 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// as it will no longer match the request token.
realToken, err = generateRandomBytes(tokenLength)
if err != nil {
envError(r, err)
r = envError(r, err)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

// Save the new (real) token in the session store.
err = cs.st.Save(realToken, w)
if err != nil {
envError(r, err)
r = envError(r, err)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}
Expand All @@ -225,13 +224,13 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// otherwise fails to parse.
referer, err := url.Parse(r.Referer())
if err != nil || referer.String() == "" {
envError(r, ErrNoReferer)
r = envError(r, ErrNoReferer)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

if sameOrigin(r.URL, referer) == false {
envError(r, ErrBadReferer)
r = envError(r, ErrBadReferer)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}
Expand All @@ -240,7 +239,7 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If the token returned from the session store is nil for non-idempotent
// ("unsafe") methods, call the error handler.
if realToken == nil {
envError(r, ErrNoToken)
r = envError(r, ErrNoToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}
Expand All @@ -250,7 +249,7 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Compare the request token against the real token
if !compareTokens(requestToken, realToken) {
envError(r, ErrBadToken)
r = envError(r, ErrBadToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}
Expand All @@ -263,7 +262,7 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Call the wrapped handler/router on success.
cs.h.ServeHTTP(w, r)
// Clear the request context after the handler has completed.
context.Clear(r)
contextClear(r)
}

// unauthorizedhandler sets a HTTP 403 Forbidden status and writes the
Expand Down
6 changes: 2 additions & 4 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"html/template"
"net/http"
"net/url"

"github.com/gorilla/context"
)

// Token returns a masked CSRF token ready for passing into HTML template or
Expand Down Expand Up @@ -200,6 +198,6 @@ func contains(vals []string, s string) bool {
}

// envError stores a CSRF error in the request context.
func envError(r *http.Request, err error) {
context.Set(r, errorKey, err)
func envError(r *http.Request, err error) *http.Request {
return contextSave(r, errorKey, err)
}