Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(middleware/csrf): Add support for trusted origins #2910

Merged
merged 10 commits into from
Mar 10, 2024
Prev Previous commit
Next Next commit
chore(middleware/csrf): Sentinel Errors
test(middleware/csrf): improve coverage
  • Loading branch information
sixcolors committed Mar 10, 2024
commit 68f1ee0fc6615221fa1dbd3e1214267517ed8ed9
8 changes: 5 additions & 3 deletions docs/api/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ The CSRF middleware utilizes a set of sentinel errors to handle various scenario

- `ErrTokenNotFound`: Indicates that the CSRF token was not found.
- `ErrTokenInvalid`: Indicates that the CSRF token is invalid.
- `ErrNoReferer`: Indicates that the referer was not supplied.
- `ErrBadReferer`: Indicates that the referer is invalid.
- `ErrBadOrigin`: Indicates that the origin is invalid.
- `ErrRefererNotFound`: Indicates that the referer was not supplied.
- `ErrRefererInvalid`: Indicates that the referer is invalid.
- `ErrRefererNoMatch`: Indicates that the referer does not match host and is not a trusted origin.
- `ErrOriginInvalid`: Indicates that the origin is invalid.
- `ErrOriginNoMatch`: Indicates that the origin does not match host and is not a trusted origin.

If you use the default error handler, the client will receive a 403 Forbidden error without any additional information.

Expand Down
30 changes: 16 additions & 14 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
)

var (
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrNoReferer = errors.New("referer not supplied")
ErrBadReferer = errors.New("referer invalid")
ErrBadOrigin = errors.New("origin invalid")
dummyValue = []byte{'+'}
errNoOrigin = errors.New("origin not supplied")
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrRefererNotFound = errors.New("referer not supplied")
ErrRefererInvalid = errors.New("referer invalid")
ErrRefererNoMatch = errors.New("referer does not match host and is not a trusted origin")
ErrOriginInvalid = errors.New("origin invalid")
ErrOriginNoMatch = errors.New("origin does not match host and is not a trusted origin")
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
dummyValue = []byte{'+'}
)

// Handler for CSRF middleware
Expand Down Expand Up @@ -89,7 +91,7 @@ func New(config ...Config) fiber.Handler {
err := originMatchesHost(c, cfg.TrustedOrigins)

// If there's no origin, enforce a referer check for HTTPS connections.
if errors.Is(err, errNoOrigin) {
if errors.Is(err, errOriginNotFound) {
if c.Scheme() == "https" {
err = refererMatchesHost(c, cfg.TrustedOrigins)
} else {
Expand Down Expand Up @@ -263,12 +265,12 @@ func isFromCookie(extractor any) bool {
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errNoOrigin
return errOriginNotFound
}

originURL, err := url.Parse(origin)
if err != nil {
return ErrBadOrigin
return ErrOriginInvalid
}

if originURL.Host != c.Host() {
Expand All @@ -277,7 +279,7 @@ func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return nil
}
}
return ErrBadOrigin
return ErrOriginNoMatch
}

return nil
Expand All @@ -289,12 +291,12 @@ func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
referer := c.Get(fiber.HeaderReferer)
if referer == "" {
return ErrNoReferer
return ErrRefererNotFound
}

refererURL, err := url.Parse(referer)
if err != nil {
return ErrBadReferer
return ErrRefererInvalid
}

if refererURL.Host != c.Host() {
Expand All @@ -303,7 +305,7 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return nil
}
}
return ErrBadReferer
return ErrRefererNoMatch
}

return nil
Expand Down
115 changes: 108 additions & 7 deletions middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())

// Empty/invalid CSRF token
// Invalid CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
Expand Down Expand Up @@ -598,6 +598,50 @@
require.Equal(t, 200, ctx.Response.StatusCode())
}

func Test_CSRF_Extractor_EmptyString(t *testing.T) {

Check warning on line 601 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

empty-lines: extra empty line at the end of a block (revive)
t.Parallel()
app := fiber.New()

extractor := func(_ fiber.Ctx) (string, error) {
return "", nil
}

errorHandler := func(c fiber.Ctx, err error) error {
return c.Status(403).SendString(err.Error())
}

app.Use(New(Config{
Extractor: extractor,
ErrorHandler: errorHandler,
}))

app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Generate CSRF token
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
ctx.Request.SetBodyString("_csrf=" + token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())
require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body()))

Check failure on line 642 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not `gofumpt`-ed with `-extra` (gofumpt)
}

Check failure on line 643 in middleware/csrf/csrf_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)

func Test_CSRF_Origin(t *testing.T) {
t.Parallel()
app := fiber.New()
Expand Down Expand Up @@ -1101,7 +1145,7 @@
app := fiber.New()

errHandler := func(ctx fiber.Ctx, err error) error {
require.Equal(t, ErrNoReferer, err)
require.Equal(t, ErrRefererNotFound, err)
return ctx.Status(419).Send([]byte("empty CSRF token"))
}

Expand Down Expand Up @@ -1275,7 +1319,7 @@
require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}

func Test_InvalidURLHeaders(t *testing.T) {
func Test_CSRF_InvalidURLHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()

Expand Down Expand Up @@ -1307,12 +1351,12 @@
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("http")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "Invalid Origin")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://[::1]:%38%30/Invalid Origin")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 419, ctx.Response.StatusCode())
require.Equal(t, ErrBadOrigin.Error(), string(ctx.Response.Body()))
require.Equal(t, ErrOriginInvalid.Error(), string(ctx.Response.Body()))

// invalid Referer
ctx.Request.Reset()
Expand All @@ -1323,10 +1367,67 @@
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderReferer, "Invalid Referer")
ctx.Request.Header.Set(fiber.HeaderReferer, "http://[::1]:%38%30/Invalid Referer")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 419, ctx.Response.StatusCode())
require.Equal(t, ErrBadReferer.Error(), string(ctx.Response.Body()))
require.Equal(t, ErrRefererInvalid.Error(), string(ctx.Response.Body()))
}

func Test_CSRF_TokenFromContext(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New())

app.Get("/", func(c fiber.Ctx) error {
token := TokenFromContext(c)
require.NotEmpty(t, token)
return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

func Test_CSRF_FromContextMethods(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New())

app.Get("/", func(c fiber.Ctx) error {
token := TokenFromContext(c)
require.NotEmpty(t, token)

handler := HandlerFromContext(c)
require.NotNil(t, handler)

return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

func Test_CSRF_FromContextMethods_Invalid(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Get("/", func(c fiber.Ctx) error {
token := TokenFromContext(c)
require.Empty(t, token)

handler := HandlerFromContext(c)
require.Nil(t, handler)

return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
Loading