Skip to content

Commit

Permalink
fix(middleware/csrf): null origin
Browse files Browse the repository at this point in the history
expand tests to check invalid urls in headers
  • Loading branch information
sixcolors committed Mar 10, 2024
1 parent 77fae20 commit db2734e
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/api/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ The CSRF middleware utilizes a set of sentinel errors to handle various scenario
- `ErrTokenInvalid`: Indicates that the CSRF token is invalid.
- `ErrNoReferer`: Indicates that the referer was not supplied.
- `ErrBadReferer`: Indicates that the referer is invalid.
- `ErrBadOrigin`: Indicates that the origin is invalid.

If you use the default error handler, the client will receive a 403 Forbidden error without any additional information.

Expand Down
12 changes: 6 additions & 6 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,18 +262,18 @@ func isFromCookie(extractor any) bool {
// returns nil if the origin header is valid
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
if origin == "" {
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errNoOrigin
}

originURL, err := url.Parse(origin)
if err != nil {
return ErrBadReferer
return ErrBadOrigin

Check warning on line 271 in middleware/csrf/csrf.go

View check run for this annotation

Codecov / codecov/patch

middleware/csrf/csrf.go#L271

Added line #L271 was not covered by tests
}

if originURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isSameSchemeAndDomain(trustedOrigin, origin) {
if isTrustedSchemeAndDomain(trustedOrigin, origin) {
return nil
}
}
Expand All @@ -299,7 +299,7 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {

if refererURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isSameSchemeAndDomain(trustedOrigin, referer) {
if isTrustedSchemeAndDomain(trustedOrigin, referer) {
return nil
}
}
Expand All @@ -309,10 +309,10 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return nil
}

// isSameSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// isTrustedSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// or if the protoDomain is a subdomain of the trustedProtoDomain where trustedProtoDomain
// is prefixed with "https://." or "http://."
func isSameSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
func isTrustedSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
if trustedProtoDomain == protoDomain {
return true
}
Expand Down
84 changes: 84 additions & 0 deletions middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,34 @@ func Test_CSRF_Origin(t *testing.T) {
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())

// Test Correct Origin with wrong port
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.URI().SetScheme("http")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("http")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:3000")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())

// Test Correct Origin with null
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.URI().SetScheme("http")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("http")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "null")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())

// Test Correct Origin with ReverseProxy
ctx.Request.Reset()
ctx.Response.Reset()
Expand Down Expand Up @@ -1246,3 +1274,59 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {

require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}

func Test_InvalidURLHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()

errHandler := func(ctx fiber.Ctx, err error) error {
return ctx.Status(419).Send([]byte(err.Error()))
}

app.Use(New(Config{ErrorHandler: errHandler}))

app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Generate CSRF token
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

// invalid Origin
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.URI().SetScheme("http")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("http")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "Invalid Origin")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 419, ctx.Response.StatusCode())
require.Equal(t, ErrBadOrigin.Error(), string(ctx.Response.Body()))

// invalid Referer
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
ctx.Request.URI().SetScheme("https")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderReferer, "Invalid Referer")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 419, ctx.Response.StatusCode())
require.Equal(t, ErrBadReferer.Error(), string(ctx.Response.Body()))
}

0 comments on commit db2734e

Please sign in to comment.