Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 65 additions & 50 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,13 @@ import (
"github.com/openfaas/nats-queue-worker/nats"
)

func makeFunctionURL(req *queue.Request, config *QueueWorkerConfig, path, queryString string) string {
qs := ""
if len(queryString) > 0 {
qs = fmt.Sprintf("?%s", strings.TrimLeft(queryString, "?"))
}
pathVal := "/"
if len(path) > 0 {
pathVal = path
}
functionURL := fmt.Sprintf("http://%s%s:8080%s%s",
req.Function,
config.FunctionSuffix,
pathVal,
qs)

if config.GatewayInvoke {
functionURL = fmt.Sprintf("http://%s:%d/function/%s%s%s",
config.GatewayAddress,
config.GatewayPort,
strings.Trim(req.Function, "/"),
pathVal,
qs)
}

return functionURL
}

func main() {
readConfig := ReadConfig{}
config := readConfig.Read()
config, configErr := readConfig.Read()
if configErr != nil {
panic(configErr)
}

log.SetFlags(0)

hostname, _ := os.Hostname()
Expand Down Expand Up @@ -88,16 +65,20 @@ func main() {

xCallID := req.Header.Get("X-Call-Id")

fmt.Printf("Invoking: %s, %d bytes.\n", req.Function, len(req.Body))
functionURL := makeFunctionURL(&req, &config, req.Path, req.QueryString)
fmt.Printf("Invoking: %s with %d bytes, via: %s\n", req.Function, len(req.Body), functionURL)

if config.DebugPrintBody {
fmt.Println(string(req.Body))
}

functionURL := makeFunctionURL(&req, &config, req.Path, req.QueryString)

start := time.Now()
request, err := http.NewRequest(http.MethodPost, functionURL, bytes.NewReader(req.Body))
if err != nil {
log.Printf("Unable to post message due to invalid URL, error: %s", err.Error())
return
}

defer request.Body.Close()
copyHeaders(request.Header, &req.Header)

Expand All @@ -108,46 +89,48 @@ func main() {

var statusCode int
if err != nil {

statusCode = http.StatusServiceUnavailable
} else {
statusCode = res.StatusCode
}

duration := time.Since(start)

log.Printf("Invoked: %s [%d] in %fs", req.Function, statusCode, duration.Seconds())

if err != nil {
status = http.StatusServiceUnavailable

log.Println(err)
log.Printf("Error invoking %s, error: %s", req.Function, err)

timeTaken := time.Since(started).Seconds()

if req.CallbackURL != nil {
log.Printf("Callback to: %s\n", req.CallbackURL.String())

resultStatusCode, resultErr := postResult(&client,
res,
functionResult,
req.CallbackURL.String(),
xCallID,
status)

if resultErr != nil {
log.Println(resultErr)
log.Printf("Posted callback to: %s - status %d, error: %s\n", req.CallbackURL.String(), http.StatusServiceUnavailable, resultErr.Error())
} else {
log.Printf("Posted result: %d", resultStatusCode)
log.Printf("Posted result to %s - status: %d", req.CallbackURL.String(), resultStatusCode)
}
}

if config.GatewayInvoke == false {
statusCode, reportErr := postReport(&client, req.Function, status, timeTaken, config.GatewayAddress, credentials)
statusCode, reportErr := postReport(&client, req.Function, status, timeTaken, config.GatewayAddressURL(), credentials)
if reportErr != nil {
log.Println(reportErr)
log.Printf("Error posting report: %s\n", reportErr)
} else {
log.Printf("Posting report - %d\n", statusCode)
log.Printf("Posting report to gateway for %s - status: %d\n", req.Function, statusCode)
}
return
}

return
}

if res.Body != nil {
Expand All @@ -157,42 +140,41 @@ func main() {
functionResult = resData

if err != nil {
log.Println(err)
log.Printf("Error reading body for: %s, error: %s", req.Function, err)
}

if config.WriteDebug {
fmt.Println(string(functionResult))
} else {
fmt.Printf("Wrote %d Bytes\n", len(string(functionResult)))
fmt.Printf("%s returned %d bytes\n", req.Function, len(functionResult))
}
}

timeTaken := time.Since(started).Seconds()

fmt.Println(res.Status)

if req.CallbackURL != nil {
log.Printf("Callback to: %s\n", req.CallbackURL.String())

resultStatusCode, resultErr := postResult(&client,
res,
functionResult,
req.CallbackURL.String(),
xCallID,
res.StatusCode)

if resultErr != nil {
log.Println(resultErr)
log.Printf("Error posting to callback-url: %s\n", resultErr)
} else {
log.Printf("Posted result: %d", resultStatusCode)
log.Printf("Posted result for %s to callback-url: %s, status: %d", req.Function, req.CallbackURL.String(), resultStatusCode)
}
}

if config.GatewayInvoke == false {
statusCode, reportErr := postReport(&client, req.Function, res.StatusCode, timeTaken, config.GatewayAddress, credentials)

statusCode, reportErr := postReport(&client, req.Function, res.StatusCode, timeTaken, config.GatewayAddressURL(), credentials)
if reportErr != nil {
log.Println(reportErr)
log.Printf("Error posting report: %s\n", reportErr.Error())
} else {
log.Printf("Posting report - %d\n", statusCode)
log.Printf("Posting report for %s, status: %d\n", req.Function, statusCode)
}
}

Expand Down Expand Up @@ -279,6 +261,10 @@ func postResult(client *http.Client, functionRes *http.Response, result []byte,

request, err := http.NewRequest(http.MethodPost, callbackURL, reader)

if err != nil {
return http.StatusInternalServerError, fmt.Errorf("unable to post result, error: %s", err.Error())
}

if functionRes != nil {
copyHeaders(request.Header, &functionRes.Header)
}
Expand All @@ -302,6 +288,7 @@ func postResult(client *http.Client, functionRes *http.Response, result []byte,
if res.Body != nil {
defer res.Body.Close()
}

return res.StatusCode, nil
}

Expand Down Expand Up @@ -342,3 +329,31 @@ func postReport(client *http.Client, function string, statusCode int, timeTaken

return res.StatusCode, nil
}

func makeFunctionURL(req *queue.Request, config *QueueWorkerConfig, path, queryString string) string {
qs := ""
if len(queryString) > 0 {
qs = fmt.Sprintf("?%s", strings.TrimLeft(queryString, "?"))
}
pathVal := "/"
if len(path) > 0 {
pathVal = path
}

var functionURL string
if config.GatewayInvoke {
functionURL = fmt.Sprintf("http://%s/function/%s%s%s",
config.GatewayAddressURL(),
strings.Trim(req.Function, "/"),
pathVal,
qs)
} else {
functionURL = fmt.Sprintf("http://%s%s:8080%s%s",
req.Function,
config.FunctionSuffix,
pathVal,
qs)
}

return functionURL
}
40 changes: 22 additions & 18 deletions readconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const DefaultMaxReconnect = 120

const DefaultReconnectDelay = time.Second * 2

func (ReadConfig) Read() QueueWorkerConfig {
func (ReadConfig) Read() (QueueWorkerConfig, error) {
cfg := QueueWorkerConfig{
AckWait: time.Second * 30,
MaxInflight: 1,
Expand Down Expand Up @@ -54,16 +54,15 @@ func (ReadConfig) Read() QueueWorkerConfig {
if value, exists := os.LookupEnv("faas_gateway_port"); exists {
val, err := strconv.Atoi(value)
if err != nil {
log.Println("converting faas_gateway_port to int error:", err)
} else {
cfg.GatewayPort = val
return QueueWorkerConfig{}, fmt.Errorf("converting faas_gateway_port %s to int error: %s", value, err)
}

cfg.GatewayPort = val

} else {
cfg.GatewayPort = 8080
}

cfg.GatewayAddress = fmt.Sprintf("%s:%d", cfg.GatewayAddress, cfg.GatewayPort)

if val, exists := os.LookupEnv("faas_function_suffix"); exists {
cfg.FunctionSuffix = val
}
Expand Down Expand Up @@ -138,23 +137,28 @@ func (ReadConfig) Read() QueueWorkerConfig {
cfg.BasicAuth = true
}
}

return cfg
var err error
return cfg, err
}

type QueueWorkerConfig struct {
NatsAddress string
NatsPort int
NatsClusterName string
GatewayAddress string
GatewayPort int
FunctionSuffix string
DebugPrintBody bool
WriteDebug bool
MaxInflight int
AckWait time.Duration
MaxReconnect int
ReconnectDelay time.Duration
GatewayInvoke bool // GatewayInvoke invoke functions through gateway rather than directly
BasicAuth bool

GatewayPort int
FunctionSuffix string
DebugPrintBody bool
WriteDebug bool
MaxInflight int
AckWait time.Duration
MaxReconnect int
ReconnectDelay time.Duration
GatewayInvoke bool // GatewayInvoke invoke functions through gateway rather than directly
BasicAuth bool
}

func (q QueueWorkerConfig) GatewayAddressURL() string {
return fmt.Sprintf("%s:%d", q.GatewayAddress, q.GatewayPort)
}
Loading