Skip to content

Commit

Permalink
Propagating IG headers to it's nodes. (kubeflow#2396)
Browse files Browse the repository at this point in the history
* Propagting headers from IG to ISVCs

Signed-off-by: rchauhan4 <rachit_chauhan@intuit.com>
Signed-off-by: rachitchauhann43 <rachitchauhan43@gmail.com>

* Propagating headers fro IG to nodes

Signed-off-by: rchauhan4 <rachit_chauhan@intuit.com>
Signed-off-by: rachitchauhann43 <rachitchauhan43@gmail.com>

* Using global http client instead of new client everytime

Signed-off-by: rachitchauhann43 <rachitchauhan43@gmail.com>

* Making headers propagation configurable

Signed-off-by: rchauhan4 <rachit_chauhan@intuit.com>

* Incorporating review comments

Signed-off-by: rchauhan4 <rachit_chauhan@intuit.com>

* Incorporated review comments

Signed-off-by: rachitchauhan43 <rachitchauhan43@gmail.com>

Signed-off-by: rchauhan4 <rachit_chauhan@intuit.com>
Signed-off-by: rachitchauhann43 <rachitchauhan43@gmail.com>
Signed-off-by: rachitchauhan43 <rachitchauhan43@gmail.com>
Co-authored-by: rchauhan4 <rachit_chauhan@intuit.com>
  • Loading branch information
rachitchauhan43 and rchauhan4 authored Oct 23, 2022
1 parent b45b6d7 commit 9499aba
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 18 deletions.
38 changes: 25 additions & 13 deletions cmd/router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/kserve/kserve/pkg/constants"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/tidwall/gjson"
Expand All @@ -38,8 +40,18 @@ import (

var log = logf.Log.WithName("InferenceGraphRouter")

func callService(serviceUrl string, input []byte) ([]byte, error) {
resp, err := http.Post(serviceUrl, "application/json", bytes.NewBuffer(input))
func callService(serviceUrl string, input []byte, headers http.Header) ([]byte, error) {
req, err := http.NewRequest("POST", serviceUrl, bytes.NewBuffer(input))
for _, h := range headersToPropagate {
if values, ok := headers[h]; ok {
for _, v := range values {
req.Header.Add(h, v)
}
}
}
req.Header.Add("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)

if err != nil {
log.Error(err, "An error has occurred from service", "service", serviceUrl)
return nil, err
Expand Down Expand Up @@ -83,20 +95,19 @@ func timeTrack(start time.Time, name string) {
log.Info("elapsed time", "node", name, "time", elapsed)
}

func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) {
log.Info("current step", "nodeName", nodeName)
func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte, headers http.Header) ([]byte, error) {
defer timeTrack(time.Now(), nodeName)
currentNode := graph.Nodes[nodeName]

if currentNode.RouterType == v1alpha1.Splitter {
return executeStep(pickupRoute(currentNode.Steps), graph, input)
return executeStep(pickupRoute(currentNode.Steps), graph, input, headers)
}
if currentNode.RouterType == v1alpha1.Switch {
route := pickupRouteByCondition(input, currentNode.Steps)
if route == nil {
return input, nil //TODO maybe should fail in this case?
}
return executeStep(route, graph, input)
return executeStep(route, graph, input, headers)
}
if currentNode.RouterType == v1alpha1.Ensemble {
ensembleRes := make([]chan map[string]interface{}, len(currentNode.Steps))
Expand All @@ -106,7 +117,7 @@ func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte)
resultChan := make(chan map[string]interface{})
ensembleRes[i] = resultChan
go func() {
output, err := executeStep(step, graph, input)
output, err := executeStep(step, graph, input, headers)
if err == nil {
var res map[string]interface{}
if err = json.Unmarshal(output, &res); err == nil {
Expand Down Expand Up @@ -151,7 +162,7 @@ func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte)
return responseBytes, nil
}
}
if responseBytes, err = executeStep(step, graph, request); err != nil {
if responseBytes, err = executeStep(step, graph, request, headers); err != nil {
return nil, err
}
}
Expand All @@ -161,19 +172,19 @@ func routeStep(nodeName string, graph v1alpha1.InferenceGraphSpec, input []byte)
return nil, fmt.Errorf("invalid route type: %v", currentNode.RouterType)
}

func executeStep(step *v1alpha1.InferenceStep, graph v1alpha1.InferenceGraphSpec, input []byte) ([]byte, error) {
func executeStep(step *v1alpha1.InferenceStep, graph v1alpha1.InferenceGraphSpec, input []byte, headers http.Header) ([]byte, error) {
if step.NodeName != "" {
// when nodeName is specified make a recursive call for routing to next step
return routeStep(step.NodeName, graph, input)
return routeStep(step.NodeName, graph, input, headers)
}
return callService(step.ServiceURL, input)
return callService(step.ServiceURL, input, headers)
}

var inferenceGraph *v1alpha1.InferenceGraphSpec

func graphHandler(w http.ResponseWriter, req *http.Request) {
inputBytes, _ := ioutil.ReadAll(req.Body)
if response, err := routeStep(v1alpha1.GraphRootNodeName, *inferenceGraph, inputBytes); err != nil {
if response, err := routeStep(v1alpha1.GraphRootNodeName, *inferenceGraph, inputBytes, req.Header); err != nil {
log.Error(err, "failed to process request")
w.WriteHeader(500) //TODO status code tbd
w.Write([]byte(fmt.Sprintf("Failed to process request: %v", err)))
Expand All @@ -183,7 +194,8 @@ func graphHandler(w http.ResponseWriter, req *http.Request) {
}

var (
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
headersToPropagate = strings.Split(os.Getenv(constants.RouterHeadersPropagateEnvVar), ",")
)

func main() {
Expand Down
158 changes: 153 additions & 5 deletions cmd/router/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,17 @@ func TestSimpleModelChainer(t *testing.T) {
},
}
jsonBytes, _ := json.Marshal(input)
res, err := routeStep("root", graphSpec, jsonBytes)
headers := http.Header{
"Authorization": {"Bearer Token"},
}

res, err := routeStep("root", graphSpec, jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedResponse := map[string]interface{}{
"predictions": "2",
}
fmt.Printf("final response:%v", response)
fmt.Printf("final response:%v\n", response)
assert.Equal(t, expectedResponse, response)
}

Expand Down Expand Up @@ -141,7 +145,10 @@ func TestSimpleModelEnsemble(t *testing.T) {
},
}
jsonBytes, _ := json.Marshal(input)
res, err := routeStep("root", graphSpec, jsonBytes)
headers := http.Header{
"Authorization": {"Bearer Token"},
}
res, err := routeStep("root", graphSpec, jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedResponse := map[string]interface{}{
Expand All @@ -152,7 +159,7 @@ func TestSimpleModelEnsemble(t *testing.T) {
"predictions": "2",
},
}
fmt.Printf("final response:%v", response)
fmt.Printf("final response:%v\n", response)
assert.Equal(t, expectedResponse, response)
}

Expand Down Expand Up @@ -317,7 +324,10 @@ func TestInferenceGraphWithCondition(t *testing.T) {
},
}
jsonBytes, _ := json.Marshal(input)
res, err := routeStep("root", graphSpec, jsonBytes)
headers := http.Header{
"Authorization": {"Bearer Token"},
}
res, err := routeStep("root", graphSpec, jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedModel3Response := map[string]interface{}{
Expand Down Expand Up @@ -345,3 +355,141 @@ func TestInferenceGraphWithCondition(t *testing.T) {
assert.Equal(t, expectedModel3Response, response["model3"])
assert.Equal(t, expectedModel4Response, response["model4"])
}

func TestCallServiceWhenNoneHeadersToPropagateIsEmpty(t *testing.T) {
// Start a local HTTP server
model1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, err := ioutil.ReadAll(req.Body)
if err != nil {
return
}
// Putting headers as part of response so that we can assert the headers' presence later
response := make(map[string]interface{})
response["predictions"] = "1"
for _, h := range headersToPropagate {
response[h] = req.Header[h][0]
}
responseBytes, err := json.Marshal(response)
_, err = rw.Write(responseBytes)
}))
model1Url, err := apis.ParseURL(model1.URL)
if err != nil {
t.Fatalf("Failed to parse model url")
}
defer model1.Close()

input := map[string]interface{}{
"instances": []string{
"test",
"test2",
},
}
jsonBytes, _ := json.Marshal(input)
headers := http.Header{
"Authorization": {"Bearer Token"},
"Test-Header-Key": {"Test-Header-Value"},
}
// Propagating no header
headersToPropagate = []string{}
res, err := callService(model1Url.String(), jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedResponse := map[string]interface{}{
"predictions": "1",
}
fmt.Printf("final response:%v\n", response)
assert.Equal(t, expectedResponse, response)
}

func TestCallServiceWhen1HeaderToPropagate(t *testing.T) {
// Start a local HTTP serverq
model1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, err := ioutil.ReadAll(req.Body)
if err != nil {
return
}
// Putting headers as part of response so that we can assert the headers' presence later
response := make(map[string]interface{})
response["predictions"] = "1"
for _, h := range headersToPropagate {
response[h] = req.Header[h][0]
}
responseBytes, err := json.Marshal(response)
_, err = rw.Write(responseBytes)
}))
model1Url, err := apis.ParseURL(model1.URL)
if err != nil {
t.Fatalf("Failed to parse model url")
}
defer model1.Close()

input := map[string]interface{}{
"instances": []string{
"test",
"test2",
},
}
jsonBytes, _ := json.Marshal(input)
headers := http.Header{
"Authorization": {"Bearer Token"},
"Test-Header-Key": {"Test-Header-Value"},
}
// Propagating only 1 header "Test-Header-Key"
headersToPropagate = []string{"Test-Header-Key"}
res, err := callService(model1Url.String(), jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedResponse := map[string]interface{}{
"predictions": "1",
"Test-Header-Key": "Test-Header-Value",
}
fmt.Printf("final response:%v\n", response)
assert.Equal(t, expectedResponse, response)
}

func TestCallServiceWhenMultipleHeadersToPropagate(t *testing.T) {
// Start a local HTTP server
model1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, err := ioutil.ReadAll(req.Body)
if err != nil {
return
}
// Putting headers as part of response so that we can assert the headers' presence later
response := make(map[string]interface{})
response["predictions"] = "1"
for _, h := range headersToPropagate {
response[h] = req.Header[h][0]
}
responseBytes, err := json.Marshal(response)
_, err = rw.Write(responseBytes)
}))
model1Url, err := apis.ParseURL(model1.URL)
if err != nil {
t.Fatalf("Failed to parse model url")
}
defer model1.Close()

input := map[string]interface{}{
"instances": []string{
"test",
"test2",
},
}
jsonBytes, _ := json.Marshal(input)
headers := http.Header{
"Authorization": {"Bearer Token"},
"Test-Header-Key": {"Test-Header-Value"},
}
// Propagating multiple headers "Test-Header-Key"
headersToPropagate = []string{"Test-Header-Key", "Authorization"}
res, err := callService(model1Url.String(), jsonBytes, headers)
var response map[string]interface{}
err = json.Unmarshal(res, &response)
expectedResponse := map[string]interface{}{
"predictions": "1",
"Test-Header-Key": "Test-Header-Value",
"Authorization": "Bearer Token",
}
fmt.Printf("final response:%v\n", response)
assert.Equal(t, expectedResponse, response)
}
5 changes: 5 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ var (
InferenceServiceConfigMapName = "inferenceservice-config"
)

// InferenceGraph Constants
const (
RouterHeadersPropagateEnvVar = "PROPAGATE_HEADERS"
)

// TrainedModel Constants
var (
TrainedModelAllocated = KServeAPIGroupName + "/" + "trainedmodel-allocated"
Expand Down
13 changes: 13 additions & 0 deletions pkg/controller/v1alpha1/inferencegraph/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ type RouterConfig struct {
CpuLimit string `json:"cpuLimit"`
MemoryRequest string `json:"memoryRequest"`
MemoryLimit string `json:"memoryLimit"`
/*
Example of how to add headers in router config:
headers: {
"propagate": [
"Custom-Header1",
"Custom-Header2"
]
}
Note: Making Headers, a map of strings, gives the flexibility to extend it in the future to support adding more
operations on headers. For example: Similar to "propagate" operation, one can add "transform" operation if they
want to transform headers keys or values before passing down to nodes.
*/
Headers map[string][]string `json:"headers"`
}

func getRouterConfigs(configMap *v1.ConfigMap) (*RouterConfig, error) {
Expand Down
13 changes: 13 additions & 0 deletions pkg/controller/v1alpha1/inferencegraph/knative_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
knservingv1 "knative.dev/serving/pkg/apis/serving/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"strings"
)

var log = logf.Log.WithName("GraphKsvcReconciler")
Expand Down Expand Up @@ -170,6 +171,18 @@ func createKnativeService(componentMeta metav1.ObjectMeta, graph *v1alpha1api.In
},
},
}

// Only adding this env variable "PROPAGATE_HEADERS" if router's headers config has the key "propagate"
value, exists := config.Headers["propagate"]
if exists {
service.Spec.ConfigurationSpec.Template.Spec.PodSpec.Containers[0].Env = []v1.EnvVar{
{
Name: constants.RouterHeadersPropagateEnvVar,
Value: strings.Join(value, ","),
},
}
}

//Call setDefaults on desired knative service here to avoid diffs generated because knative defaulter webhook is
//called when creating or updating the knative service
service.SetDefaults(context.TODO())
Expand Down

0 comments on commit 9499aba

Please sign in to comment.