Skip to content

Commit

Permalink
feat: Enhance host extraction from headers (#2292)
Browse files Browse the repository at this point in the history
- Refactor SUBController subs and subJsons methods to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host if unavailable.
- Update html function to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host if unavailable.
- Update DomainValidatorMiddleware to first attempt to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host.

Fixes: #2284

Signed-off-by: Ahmad Thoriq Najahi <najahi@zephyrus.id>
  • Loading branch information
najahiiii authored May 23, 2024
1 parent 5ec1630 commit d070a82
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
26 changes: 24 additions & 2 deletions sub/subController.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,18 @@ func (a *SUBController) initRouter(g *gin.RouterGroup) {

func (a *SUBController) subs(c *gin.Context) {
subId := c.Param("subid")
host, _, _ := net.SplitHostPort(c.Request.Host)
host := c.GetHeader("X-Forwarded-Host")
if host == "" {
host = c.GetHeader("X-Real-IP")
}
if host == "" {
var err error
host, _, err = net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
}
host = host
subs, header, err := a.subService.GetSubs(subId, host)
if err != nil || len(subs) == 0 {
c.String(400, "Error!")
Expand All @@ -79,7 +90,18 @@ func (a *SUBController) subs(c *gin.Context) {

func (a *SUBController) subJsons(c *gin.Context) {
subId := c.Param("subid")
host, _, _ := net.SplitHostPort(c.Request.Host)
host := c.GetHeader("X-Forwarded-Host")
if host == "" {
host = c.GetHeader("X-Real-IP")
}
if host == "" {
var err error
host, _, err = net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
}
host = host
jsonSub, header, err := a.subJsonService.GetJson(subId, host)
if err != nil || len(jsonSub) == 0 {
c.String(400, "Error!")
Expand Down
13 changes: 12 additions & 1 deletion web/controller/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@ func html(c *gin.Context, name string, title string, data gin.H) {
data = gin.H{}
}
data["title"] = title
data["host"], _, _ = net.SplitHostPort(c.Request.Host)
host := c.GetHeader("X-Forwarded-Host")
if host == "" {
host = c.GetHeader("X-Real-IP")
}
if host == "" {
var err error
host, _, err = net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
}
data["host"] = host
data["request_uri"] = c.Request.RequestURI
data["base_path"] = c.GetString("base_path")
c.HTML(http.StatusOK, name, getContext(data))
Expand Down
16 changes: 10 additions & 6 deletions web/middleware/domainValidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import (

func DomainValidatorMiddleware(domain string) gin.HandlerFunc {
return func(c *gin.Context) {
host, _, _ := net.SplitHostPort(c.Request.Host)

if host != domain {
c.AbortWithStatus(http.StatusForbidden)
return
host := c.GetHeader("X-Forwarded-Host")
if host == "" {
host = c.GetHeader("X-Real-IP")
}

if host == "" {
host, _, _ := net.SplitHostPort(c.Request.Host)
if host != domain {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.Next()
}
}
}

0 comments on commit d070a82

Please sign in to comment.