Skip to content

Commit

Permalink
remove gatekeeper in favor of a gateway package that does the same th…
Browse files Browse the repository at this point in the history
…ing, but with more robust features, including middleware pipelines
  • Loading branch information
patinthehat committed Aug 4, 2023
1 parent e7a91c4 commit 746e775
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 64 deletions.
7 changes: 5 additions & 2 deletions lib/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/emirpasic/gods/stacks/linkedliststack"
"github.com/joho/godotenv"
"github.com/robfig/cron/v3"
"github.com/stackup-app/stackup/lib/gateway"
"github.com/stackup-app/stackup/lib/support"
"github.com/stackup-app/stackup/lib/updater"
"github.com/stackup-app/stackup/lib/utils"
Expand Down Expand Up @@ -42,7 +43,7 @@ type Application struct {
CmdStartCallback CommandCallback
KillCommandCallback CommandCallback
ConfigFilename string
Gatekeeper *Gatekeeper
Gateway *gateway.Gateway
}

func (a *Application) loadWorkflowFile(filename string) *StackupWorkflow {
Expand All @@ -68,7 +69,7 @@ func (a *Application) loadWorkflowFile(filename string) *StackupWorkflow {
}

func (a *Application) init() {
a.Gatekeeper = CreateGatekeeper()
a.Gateway = gateway.New([]string{}, []string{})
a.ConfigFilename = support.FindExistingFile([]string{"stackup.dist.yaml", "stackup.yaml"}, "stackup.yaml")

a.flags = AppFlags{
Expand Down Expand Up @@ -336,6 +337,8 @@ func (a *Application) Run() {
a.handleFlagOptions()

a.Workflow.Initialize()
a.Gateway.SetAllowedDomains(a.Workflow.Settings.Domains.Allowed)

if len(a.Workflow.Settings.DotEnvFiles) > 0 {
godotenv.Load(a.Workflow.Settings.DotEnvFiles...)
}
Expand Down
44 changes: 0 additions & 44 deletions lib/app/gatekeeper.go

This file was deleted.

6 changes: 3 additions & 3 deletions lib/app/scriptNet.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func CreateScriptNetObject(vm *otto.Otto) {
}

func (net *ScriptNet) Fetch(url string) any {
if !App.Gatekeeper.CanAccessUrl(url) {
if !App.Gateway.Allowed(url) {
support.FailureMessageWithXMark("fetch failed: access to " + url + " is not allowed.")
return ""
}
Expand All @@ -26,7 +26,7 @@ func (net *ScriptNet) Fetch(url string) any {
}

func (net *ScriptNet) FetchJson(url string) any {
if !App.Gatekeeper.CanAccessUrl(url) {
if !App.Gateway.Allowed(url) {
support.FailureMessageWithXMark("fetchJson failed: access to " + url + " is not allowed.")
return interface{}(nil)
}
Expand All @@ -37,7 +37,7 @@ func (net *ScriptNet) FetchJson(url string) any {
}

func (net *ScriptNet) DownloadTo(url string, filename string) {
if !App.Gatekeeper.CanAccessUrl(url) {
if !App.Gateway.Allowed(url) {
support.FailureMessageWithXMark("download failed: access to " + url + " is not allowed.")
return
}
Expand Down
30 changes: 15 additions & 15 deletions lib/app/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package app

import (
"fmt"
"net/url"
"os"
"path"
"regexp"
Expand Down Expand Up @@ -163,7 +162,7 @@ func (wi *WorkflowInclude) ValidateChecksum(contents string) (bool, string, erro
}

for _, url := range checksumUrls {
if !App.Gatekeeper.CanAccessUrl(url) {
if !App.Gateway.Allowed(url) {
support.FailureMessageWithXMark("Access to " + url + " is not allowed.")
continue
}
Expand Down Expand Up @@ -395,6 +394,8 @@ func (workflow *StackupWorkflow) Initialize() {
workflow.Settings.Domains.Allowed = []string{"raw.githubusercontent.com", "api.github.com"}
}

App.Gateway.SetAllowedDomains(workflow.Settings.Domains.Allowed)

if workflow.Settings.Cache.TtlMinutes <= 0 {
workflow.Settings.Cache.TtlMinutes = 5
}
Expand All @@ -418,18 +419,17 @@ func (workflow *StackupWorkflow) Initialize() {
}
}

// ensure that the allowed domains are in the correct format, i.e. without a protocol or port
tempDomains := []string{}
for _, domain := range workflow.Settings.Domains.Allowed {
if strings.Contains(domain, "://") {
parsedUrl, _ := url.Parse(domain)
tempDomains = append(tempDomains, parsedUrl.Host)
} else {
tempDomains = append(tempDomains, domain)
}
}

copy(workflow.Settings.Domains.Allowed, tempDomains)
// // ensure that the allowed domains are in the correct format, i.e. without a protocol or port
// tempDomains := []string{}
// for _, domain := range workflow.Settings.Domains.Allowed {
// if strings.Contains(domain, "://") {
// parsedUrl, _ := url.Parse(domain)
// tempDomains = append(tempDomains, parsedUrl.Host)
// } else {
// tempDomains = append(tempDomains, domain)
// }
// }
// copy(workflow.Settings.Domains.Allowed, tempDomains)
// workflow.Settings.Domains.Allowed = tempDomains

// initialize the includes
Expand Down Expand Up @@ -498,7 +498,7 @@ func (workflow *StackupWorkflow) ProcessInclude(include *WorkflowInclude) bool {
}

if include.IsS3Url() || include.IsRemoteUrl() {
if !App.Gatekeeper.CanAccessUrl(include.FullUrl()) {
if !App.Gateway.Allowed(include.FullUrl()) {
support.FailureMessageWithXMark("Access to " + include.FullUrl() + " is not allowed.")
return false
}
Expand Down
174 changes: 174 additions & 0 deletions lib/gateway/gateway.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package gateway

import (
"fmt"
"io"
"net/http"
"net/url"
"strings"

glob "github.com/ryanuber/go-glob"
)

type GatewayUrlRequestMiddleware struct {
Name string
Handler func(g *Gateway, link string) error
}

type Gateway struct {
Enabled bool
AllowedDomains []string
DeniedDomains []string
Middleware []*GatewayUrlRequestMiddleware
}

// New initializes the gateway with deny/allow lists
func New(deniedDomains, allowedDomains []string) *Gateway {
result := Gateway{
Enabled: true,
DeniedDomains: deniedDomains,
AllowedDomains: allowedDomains,
Middleware: []*GatewayUrlRequestMiddleware{},
}

result.Initialize()

return &result
}

func (g *Gateway) Initialize() {
g.normalizeDataArray(&g.DeniedDomains)
g.normalizeDataArray(&g.AllowedDomains)

g.AddMiddleware(&ValidateUrlMiddleware)
g.AddMiddleware(&VerifyFileTypeMiddleware)

g.Enable()
}

func (g *Gateway) SetAllowedDomains(domains []string) {
g.AllowedDomains = domains
g.normalizeDataArray(&g.AllowedDomains)
}

func (g *Gateway) SetDeniedDomains(domains []string) {
g.DeniedDomains = domains
g.normalizeDataArray(&g.DeniedDomains)
}

func (g *Gateway) AddMiddleware(mw *GatewayUrlRequestMiddleware) {
g.Middleware = append(g.Middleware, mw)
}

// The `runUrlRequestPipeline` function is a method of the `Gateway` struct. It iterates over the
// `Middleware` slice of the `Gateway` struct and executes each middleware function in order. Each
// middleware function takes a `Gateway` instance and a URL `link` as parameters and returns an error.
// If any middleware function returns an error, the `runUrlRequestPipeline` function immediately
// returns that error. If all middleware functions are executed successfully, the function returns
// `nil`.
func (g *Gateway) runUrlRequestPipeline(link string) error {
for _, mw := range g.Middleware {
fmt.Printf("running middleware: %s\n", mw.Name)

err := (*mw).Handler(g, link)
if err != nil {
return err
}
}

return nil
}

func (g *Gateway) Allowed(link string) bool {
return g.runUrlRequestPipeline(link) == nil
}

// processes an array of domains and remove any empty strings and extract hostnames from URLs if
// they are present, then copy the result back to the original array so we have an array of only
// hostnames with or without wildcard characters.
func (g *Gateway) normalizeDataArray(arr *[]string) {
tempDomains := []string{}

for _, domain := range *arr {
if len(strings.TrimSpace(domain)) == 0 {
continue
}
if strings.Contains(domain, "://") {
parsedUrl, _ := url.Parse(domain)
domain = parsedUrl.Host
}

tempDomains = append(tempDomains, domain)
}

copy(*arr, tempDomains)
}

func (g *Gateway) Enable() {
g.Enabled = true
}

func (g *Gateway) Disable() {
g.Enabled = false
}

func (g *Gateway) checkArrayForMatch(arr *[]string, s string) bool {
for _, domain := range *arr {
if strings.EqualFold(s, domain) {
return true
}
if strings.Contains(domain, "*") && glob.Glob(domain, s) {
return true
}
}

return false
}

// GetUrl returns the contents of a URL as a string, assuming it
// is allowed by the gateway, otherwise it returns an error.
func (g *Gateway) GetUrl(urlStr string, headers ...string) (string, error) {
err := g.runUrlRequestPipeline(urlStr)
if err != nil {
return "", err
}

// remove the header items that are empty strings:
var tempHeaders []string
for _, header := range headers {
if strings.TrimSpace(header) != "" {
tempHeaders = append(tempHeaders, header)
}
}

req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return "", err
}

// Add headers to the request
for _, header := range tempHeaders {
parts := strings.SplitN(header, ":", 2)
if len(parts) == 2 {
req.Header.Set(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
}
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode >= 400 {
return "", fmt.Errorf("HTTP error: %d", resp.StatusCode)
}

// Read the response body into a byte slice
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

return string(body), nil
}
Loading

0 comments on commit 746e775

Please sign in to comment.